Skip to content

Commit

Permalink
Merge pull request #9 from zgornel/latest
Browse files Browse the repository at this point in the history
Latest
  • Loading branch information
zgornel authored Nov 8, 2018
2 parents 8ffc550 + 0d4900d commit 9e10bec
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 84 deletions.
19 changes: 10 additions & 9 deletions src/document_embeddings.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@ end
"""
Retrieves the embedding matrix for a given `document`.
"""
function embed_document(conceptnet::ConceptNet{L,K,V},
function embed_document(conceptnet::ConceptNet{L,K,E},
document::AbstractString;
language=Languages.English(),
keep_size::Bool=true,
compound_word_separator::String="_",
max_compound_word_length::Int=1,
wildcard_matching::Bool=false,
print_matched_words::Bool=false
) where {L<:Language, K, V}
) where {L<:Language, K, E<:Real}
# Split document into tokens and embed
return embed_document(conceptnet,
tokenize_for_conceptnet(document),
Expand All @@ -37,15 +37,15 @@ function embed_document(conceptnet::ConceptNet{L,K,V},
print_matched_words=print_matched_words)
end

function embed_document(conceptnet::ConceptNet{L,K,V},
function embed_document(conceptnet::ConceptNet{L,K,E},
document_tokens::Vector{S};
language=Languages.English(),
keep_size::Bool=true,
compound_word_separator::String="_",
max_compound_word_length::Int=1,
wildcard_matching::Bool=false,
print_matched_words::Bool=false
) where {L<:Language, K, V, S<:AbstractString}
) where {L<:Language, K, E<:Real, S<:AbstractString}
# Initializations
embeddings = conceptnet.embeddings[language]
# Get positions of words that can be used for indexing (found)
Expand Down Expand Up @@ -74,15 +74,15 @@ function embed_document(conceptnet::ConceptNet{L,K,V},
println("Embedded words: $found_words")
println("Mismatched words: $words_not_found")
end
default = zeros(eltype(V), conceptnet.width)
default = zeros(E, conceptnet.width)
_embdoc = get(conceptnet.embeddings[language],
found_words,
default,
conceptnet.fuzzy_words[language],
n=conceptnet.width,
wildcard_matching=wildcard_matching)
if keep_size
embedded_document = hcat(_embdoc, zeros(eltype(V), conceptnet.width,
embedded_document = hcat(_embdoc, zeros(E, conceptnet.width,
length(not_found_positions)))
else
embedded_document = _embdoc
Expand Down Expand Up @@ -127,13 +127,13 @@ end
# ...
# more_complicated,
# complicated]
function token_search(conceptnet::ConceptNet{L,K,V},
function token_search(conceptnet::ConceptNet{L,K,E},
tokens::Vector{S};
language::L=Languages.English(),
separator::String="_",
max_length::Int=3,
wildcard_matching::Bool=false) where
{L<:Language, K, V, S<:AbstractString}
{L<:Language, K, E<:Real, S<:AbstractString}
# Initializations
found = Vector{UnitRange{Int}}()
n = length(tokens)
Expand All @@ -142,7 +142,8 @@ function token_search(conceptnet::ConceptNet{L,K,V},
while i <= n
if j-i+1 <= max_length
token = join(tokens[i:j], separator, separator)
is_match = !isempty(get(conceptnet[language], token, V(),
is_match = !isempty(get(conceptnet[language], token,
Vector{E}(),
conceptnet.fuzzy_words[language],
wildcard_matching=wildcard_matching))
if is_match
Expand Down
65 changes: 32 additions & 33 deletions src/files.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,10 @@ specified, filtering on `languages`.
function load_embeddings(filepath::AbstractString;
max_vocab_size::Union{Nothing,Int}=nothing,
keep_words=String[],
languages::Union{Nothing,
Languages.Language,
languages::Union{Nothing, Languages.Language,
Vector{<:Languages.Language},
Symbol,
Vector{Symbol}
}=nothing)
Symbol, Vector{Symbol}}=nothing,
data_type::Type{E}=Float64) where E<:Real
if languages isa Nothing
languages = unique(collect(values(LANGUAGES)))
elseif languages isa Symbol
Expand All @@ -46,7 +44,8 @@ function load_embeddings(filepath::AbstractString;
GzipDecompressor(),
max_vocab_size,
keep_words,
languages=languages)
languages=languages,
data_type=data_type)
elseif any(endswith.(filepath, [".h5", ".hdf5"]))
conceptnet = _load_hdf5_embeddings(filepath,
max_vocab_size,
Expand All @@ -57,7 +56,8 @@ function load_embeddings(filepath::AbstractString;
Noop(),
max_vocab_size,
keep_words,
languages=languages)
languages=languages,
data_type=data_type)
end
return conceptnet
end
Expand All @@ -70,28 +70,27 @@ function _load_gz_embeddings(filepath::S1,
decompressor::TranscodingStreams.Codec,
max_vocab_size::Union{Nothing,Int},
keep_words::Vector{S2};
languages::Union{Nothing,
Languages.Language,
languages::Union{Nothing, Languages.Language,
Vector{<:Languages.Language}
}=nothing) where
{S1<:AbstractString, S2<:AbstractString}
}=nothing,
data_type::Type{E}=Float64) where
{E<:Real, S1<:AbstractString, S2<:AbstractString}
local lang_embs, _length::Int, _width::Int, type_lang, fuzzy_words
type_word = String
type_vector = Vector{Float64}
open(filepath, "r") do fid
cfid = TranscodingStream(decompressor, fid)
_length, _width = parse.(Int64, split(readline(cfid)))
vocab_size = _get_vocab_size(_length,
max_vocab_size,
keep_words)
lang_embs, languages, type_lang, english_only =
process_language_argument(languages, type_word, type_vector)
process_language_argument(languages, type_word, data_type)
fuzzy_words = Dict{type_lang, Vector{type_word}}()
no_custom_words = length(keep_words)==0
lang = :en
cnt = 0
for (idx, line) in enumerate(eachline(cfid))
word, _ = _parseline(line, word_only=true)
word, _ = _parseline(line, data_type, word_only=true)
if !english_only
_, _, _lang, word = split(word,"/")
lang = Symbol(_lang)
Expand All @@ -100,10 +99,11 @@ function _load_gz_embeddings(filepath::S1,
if lang in keys(LANGUAGES) && LANGUAGES[lang] in languages # use only languages mapped in LANGUAGES
_llang = LANGUAGES[lang]
if !haskey(lang_embs, _llang)
push!(lang_embs, _llang=>Dict{type_word, type_vector}())
push!(lang_embs, _llang=>Dict{type_word,
Vector{data_type}}())
push!(fuzzy_words, _llang=>type_word[])
end
_, embedding = _parseline(line, word_only=false)
_, embedding = _parseline(line, data_type, word_only=false)
occursin("#", word) && push!(fuzzy_words[_llang], word)
push!(lang_embs[_llang], word=>embedding)
cnt+=1
Expand All @@ -115,7 +115,7 @@ function _load_gz_embeddings(filepath::S1,
end
close(cfid)
end
return ConceptNet{type_lang, type_word, type_vector}(lang_embs, _width, fuzzy_words)
return ConceptNet{type_lang, type_word, data_type}(lang_embs, _width, fuzzy_words)
end


Expand All @@ -132,7 +132,6 @@ function _load_hdf5_embeddings(filepath::S1,
{S1<:AbstractString, S2<:AbstractString}
local fuzzy_words
type_word = String
type_vector = Vector{Int8}
payload = h5open(read, filepath)["mat"]
words = map(payload["axis1"]) do val
_, _, lang, word = split(val, "/")
Expand All @@ -143,7 +142,7 @@ function _load_hdf5_embeddings(filepath::S1,
max_vocab_size,
keep_words)
lang_embs, languages, type_lang, _ =
process_language_argument(languages, type_word, type_vector)
process_language_argument(languages, type_word, Int8)
fuzzy_words = Dict{type_lang, Vector{type_word}}()
no_custom_words = length(keep_words)==0
cnt = 0
Expand All @@ -152,7 +151,7 @@ function _load_hdf5_embeddings(filepath::S1,
if haskey(LANGUAGES, lang) && LANGUAGES[lang] in languages # use only languages mapped in LANGUAGES
_llang = LANGUAGES[lang]
if !haskey(lang_embs, _llang)
push!(lang_embs, _llang=>Dict{type_word, type_vector}())
push!(lang_embs, _llang=>Dict{type_word, Vector{Int8}}())
push!(fuzzy_words, _llang=>type_word[])
end
occursin("#", word) && push!(fuzzy_words[_llang], word)
Expand All @@ -164,7 +163,7 @@ function _load_hdf5_embeddings(filepath::S1,
end
end
end
return ConceptNet{type_lang, type_word, type_vector}(lang_embs, size(embeddings,1), fuzzy_words)
return ConceptNet{type_lang, type_word, Int8}(lang_embs, size(embeddings,1), fuzzy_words)
end


Expand All @@ -178,32 +177,32 @@ end
# - a flag specifying whether only English is used or not
"""
    process_language_argument(languages::Nothing, type_word, type_data)

Handle the case where no `languages` filter was supplied: embeddings for all
languages mapped in `LANGUAGES` are to be loaded.

Returns a 4-tuple of:
- an empty embeddings container keyed by `Languages.Language`, with values
  `Dict{type_word, Vector{type_data}}`
- the collection of languages to load
- the language key type (`Languages.Language`)
- `false`, i.e. not restricted to English only
"""
function process_language_argument(languages::Nothing,
                                   type_word::T1,
                                   type_data::T2) where {T1, T2}
    # NOTE(review): iterating `LANGUAGES` directly may yield `key=>value`
    # pairs (elsewhere `values(LANGUAGES)` is used) — confirm the intended
    # element type of this collection.
    return Dict{Languages.Language, Dict{type_word, Vector{type_data}}}(),
           collect(language for language in LANGUAGES),
           Languages.Language, false
end

"""
    process_language_argument(languages::Languages.English, type_word, type_data)

Handle the English-only case.

Returns a 4-tuple of:
- an empty embeddings container keyed by `Languages.English`, with values
  `Dict{type_word, Vector{type_data}}`
- `[languages]`, the single language wrapped in a vector
- the language key type (`Languages.English`)
- `true`, i.e. restricted to English only (enables the fast parsing path)
"""
function process_language_argument(languages::Languages.English,
                                   type_word::T1,
                                   type_data::T2) where {T1, T2}
    return Dict{Languages.English, Dict{type_word, Vector{type_data}}}(), [languages],
           Languages.English, true
end

"""
    process_language_argument(languages::L, type_word, type_data) where L<:Languages.Language

Handle a single non-English language.

Returns a 4-tuple of:
- an empty embeddings container keyed by the concrete language type `L`,
  with values `Dict{type_word, Vector{type_data}}`
- `[languages]`, the single language wrapped in a vector
- the language key type `L`
- `false`, i.e. not restricted to English only
"""
function process_language_argument(languages::L,
                                   type_word::T1,
                                   type_data::T2) where {L<:Languages.Language, T1, T2}
    return Dict{L, Dict{type_word, Vector{type_data}}}(), [languages], L, false
end

"""
    process_language_argument(languages::Vector{L}, type_word, type_data) where L<:Languages.Language

Handle a vector of languages. A one-element vector is delegated to the
single-language method (so a lone `Languages.English()` still takes the
English-only fast path); otherwise an empty container keyed by the common
language type `L` is returned together with the vector itself, `L`, and
`false` (not English-only).
"""
function process_language_argument(languages::Vector{L},
                                   type_word::T1,
                                   type_data::T2) where {L<:Languages.Language, T1, T2}
    if length(languages) == 1
        # Delegate so the specialized single-language methods apply.
        return process_language_argument(languages[1], type_word, type_data)
    else
        return Dict{L, Dict{type_word, Vector{type_data}}}(), languages, L, false
    end
end

Expand Down Expand Up @@ -237,13 +236,13 @@ end
"""
    _parseline(buf, data_type::Type{E}; word_only=false) where E<:Real

Parse a line of text from a ConceptNetNumberbatch delimited file.

The line format is the word followed by its space-separated embedding
components. Returns a `(word, embedding)` tuple where `embedding` is a
`Vector{E}`. When `word_only=true` the numeric components are skipped and
an empty `E[]` is returned in place of the embedding (cheaper when only
the vocabulary is being scanned).
"""
function _parseline(buf, data_type::Type{E}; word_only=false) where E<:Real
    bufvec = split(buf, " ")
    # First field is the word; the rest are the embedding components.
    word = string(popfirst!(bufvec))
    if word_only
        return word, E[]
    else
        embedding = parse.(E, bufvec)
        return word, embedding
    end
end
Loading

0 comments on commit 9e10bec

Please sign in to comment.