Skip to content

Commit

Permalink
Merge pull request #10 from zgornel/latest
Browse files Browse the repository at this point in the history
Allow type specification for hdf5 embeddings loading
  • Loading branch information
zgornel authored Dec 27, 2018
2 parents 9e10bec + 24d5e80 commit e264e0c
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 22 deletions.
18 changes: 9 additions & 9 deletions src/files.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ function load_embeddings(filepath::AbstractString;
conceptnet = _load_hdf5_embeddings(filepath,
max_vocab_size,
keep_words,
languages=languages)
languages=languages,
data_type=data_type)
else
conceptnet = _load_gz_embeddings(filepath,
Noop(),
Expand Down Expand Up @@ -125,11 +126,10 @@ Load the ConceptNetNumberbatch embeddings from a HDF5 file.
function _load_hdf5_embeddings(filepath::S1,
max_vocab_size::Union{Nothing,Int},
keep_words::Vector{S2};
languages::Union{Nothing,
Languages.Language,
Vector{<:Languages.Language}
}=nothing) where
{S1<:AbstractString, S2<:AbstractString}
languages::Union{Nothing, Languages.Language,
Vector{<:Languages.Language}}=nothing,
data_type::Type{E}=Int8) where
{S1<:AbstractString, S2<:AbstractString, E<:Real}
local fuzzy_words
type_word = String
payload = h5open(read, filepath)["mat"]
Expand All @@ -142,7 +142,7 @@ function _load_hdf5_embeddings(filepath::S1,
max_vocab_size,
keep_words)
lang_embs, languages, type_lang, _ =
process_language_argument(languages, type_word, Int8)
process_language_argument(languages, type_word, E)
fuzzy_words = Dict{type_lang, Vector{type_word}}()
no_custom_words = length(keep_words)==0
cnt = 0
Expand All @@ -151,7 +151,7 @@ function _load_hdf5_embeddings(filepath::S1,
if haskey(LANGUAGES, lang) && LANGUAGES[lang] in languages # use only languages mapped in LANGUAGES
_llang = LANGUAGES[lang]
if !haskey(lang_embs, _llang)
push!(lang_embs, _llang=>Dict{type_word, Vector{Int8}}())
push!(lang_embs, _llang=>Dict{type_word, Vector{E}}())
push!(fuzzy_words, _llang=>type_word[])
end
occursin("#", word) && push!(fuzzy_words[_llang], word)
Expand All @@ -163,7 +163,7 @@ function _load_hdf5_embeddings(filepath::S1,
end
end
end
return ConceptNet{type_lang, type_word, Int8}(lang_embs, size(embeddings,1), fuzzy_words)
return ConceptNet{type_lang, type_word, E}(lang_embs, size(embeddings,1), fuzzy_words)
end


Expand Down
34 changes: 21 additions & 13 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ using Languages
using ConceptnetNumberbatch

# Test file with just 2 entries (test purposes only)
const DATA_TYPE = Float64
const CONCEPTNET_TEST_DATA = Dict( # filename => output type
const DATA_TYPE = Float32
const CONCEPTNET_TEST_DATA = Dict(
# filename => output type
(joinpath(string(@__DIR__), "data", "_test_file_en.txt.gz") =>
([Languages.English()],
["####_ish", "####_form", "####_metres"],
Expand All @@ -17,26 +18,29 @@ const CONCEPTNET_TEST_DATA = Dict( # filename => output type

(joinpath(string(@__DIR__), "data", "_test_file.txt") =>
(nothing,
["1_konings", "aaklig", "aak"],
["1_konings", "aaklig", "aak"],
ConceptNet{Languages.Language, String, DATA_TYPE})),

(joinpath(string(@__DIR__), "data", "_test_file.h5") =>
(nothing,
["1", "2", "2d"],
ConceptNet{Languages.Language, String, Int8}))
)
ConceptNet{Languages.Language, String, DATA_TYPE}))
)

@testset "Parser: (no arguments)" begin
for (filename, (languages, _, resulting_type)) in CONCEPTNET_TEST_DATA
conceptnet = load_embeddings(filename, languages=languages);
conceptnet = load_embeddings(filename,
languages=languages,
data_type=DATA_TYPE);
@test conceptnet isa resulting_type
end
end

max_vocab_size=5
@testset "Parser: max_vocab_size=5" begin
for (filename, (languages, _, _)) in CONCEPTNET_TEST_DATA
conceptnet = load_embeddings(filename, max_vocab_size=max_vocab_size,
conceptnet = load_embeddings(filename,
max_vocab_size=max_vocab_size,
languages=languages);
@test length(conceptnet) == max_vocab_size
end
Expand All @@ -45,8 +49,10 @@ end
max_vocab_size=5
@testset "Parser: max_vocab_size=5, 3 keep words" begin
for (filename, (languages, keep_words, _)) in CONCEPTNET_TEST_DATA
conceptnet = load_embeddings(filename, max_vocab_size=max_vocab_size,
keep_words=keep_words, languages=languages)
conceptnet = load_embeddings(filename,
max_vocab_size=max_vocab_size,
keep_words=keep_words,
languages=languages)
@test length(conceptnet) == length(keep_words)
for word in keep_words
@test word in conceptnet
Expand All @@ -63,7 +69,7 @@ end
# Test indexing
idx = 1
@test conceptnet[words[idx]] == conceptnet[:en, words[idx]] ==
conceptnet[Languages.English(), words[idx]]
conceptnet[Languages.English(), words[idx]]

# Test values
embeddings = conceptnet[words]
Expand All @@ -84,14 +90,14 @@ end
@test_throws MethodError conceptnet[words] # type of language is Language, cannot directly search
@test_throws KeyError conceptnet[:en, "word"] # English language not present
@test conceptnet[:nl, words[idx]] ==
conceptnet[Languages.Dutch(), words[idx]]
conceptnet[Languages.Dutch(), words[idx]]

# Test values
for (idx, word) in enumerate(words)
@test_throws KeyError conceptnet[Languages.English(), word]
if word in conceptnet
@test vec(conceptnet[Languages.Dutch(), word]) ==
conceptnet.embeddings[Languages.Dutch()][word]
conceptnet.embeddings[Languages.Dutch()][word]
else
@test iszero(conceptnet[Languages.Dutch(),word])
end
Expand All @@ -112,7 +118,9 @@ end

@testset "Document Embedding" begin
filepath = joinpath(string(@__DIR__), "data", "_test_file_en.txt.gz")
conceptnet = load_embeddings(filepath, languages=[Languages.English()])
conceptnet = load_embeddings(filepath,
languages=[Languages.English()],
data_type=DATA_TYPE)
# Document with no matchable words
doc = "a aaaaa b"
embedded_doc, missed = embed_document(conceptnet,
Expand Down

0 comments on commit e264e0c

Please sign in to comment.