Skip to content

Commit

Permalink
simplify
Browse files Browse the repository at this point in the history
  • Loading branch information
danemadsen committed Aug 13, 2024
1 parent db9e393 commit 23935d0
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 15 deletions.
3 changes: 0 additions & 3 deletions include/babylon.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,6 @@ namespace DeepPhonemizer {
std::string pad_token;
std::string end_token;
std::unordered_set<std::string> special_tokens;

int get_start_index(const std::string& language) const;
std::string make_start_token(const std::string& language) const;
};

class Session {
Expand Down
15 changes: 3 additions & 12 deletions src/phonemizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ std::vector<float> softmax(const std::vector<float>& logits) {
for (size_t i = 0; i < logits.size(); ++i) {
probabilities[i] = std::exp(logits[i] - max_logit) / sum;
}

return probabilities;
}

Expand All @@ -33,7 +33,7 @@ namespace DeepPhonemizer {
special_tokens.insert(pad_token);

for (const auto& lang : languages) {
std::string lang_token = make_start_token(lang);
std::string lang_token = "<" + lang + ">";
token_to_idx[lang_token] = token_to_idx.size();
special_tokens.insert(lang_token);
}
Expand Down Expand Up @@ -68,7 +68,7 @@ namespace DeepPhonemizer {
}

if (append_start_end) {
sequence.insert(sequence.begin(), get_start_index(language));
sequence.insert(sequence.begin(), token_to_idx.at("<" + language + ">"));
sequence.push_back(end_index);
}

Expand Down Expand Up @@ -123,15 +123,6 @@ namespace DeepPhonemizer {
return decoded;
}

int SequenceTokenizer::get_start_index(const std::string& language) const {
std::string lang_token = make_start_token(language);
return token_to_idx.at(lang_token);
}

std::string SequenceTokenizer::make_start_token(const std::string& language) const {
return "<" + language + ">";
}

Session::Session(const std::string& model_path, const std::string language, const bool use_punctuation) {
Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "DeepPhonemizer");
env.DisableTelemetryEvents();
Expand Down

0 comments on commit 23935d0

Please sign in to comment.