Commit 5a72e5c
Adapted to upstream tokenizer change
vthorsteinsson committed Jul 14, 2017
1 parent 7bf4936 commit 5a72e5c
Showing 3 changed files with 10 additions and 10 deletions.
9 changes: 4 additions & 5 deletions tensor2tensor/data_generators/generator_utils.py
@@ -300,21 +300,20 @@ def get_or_generate_tabbed_vocab(tmp_dir, source_filename, index, vocab_filename
   vocab = text_encoder.SubwordTextEncoder(vocab_filepath)
   return vocab

-  tokenizer = Tokenizer()
-
   # Use Tokenizer to count the word occurrences.
+  token_counts = defaultdict(int)
   filepath = os.path.join(tmp_dir, source_filename)
   with tf.gfile.GFile(filepath, mode="r") as source_file:
     for line in source_file:
       line = line.strip()
       if line and '\t' in line:
         parts = line.split('\t', maxsplit=1)
         part = parts[index].strip()
-        _ = tokenizer.encode(text_encoder.native_to_unicode(part))
+        for tok in tokenizer.encode(text_encoder.native_to_unicode(part)):
+          token_counts[tok] += 1

   vocab = text_encoder.SubwordTextEncoder.build_to_target_size(
-      vocab_size, tokenizer.token_counts, 1,
-      min(1e3, vocab_size + text_encoder.NUM_RESERVED_TOKENS))
+      vocab_size, token_counts, 1, 1e3)
   vocab.store_to_file(vocab_filepath)
   return vocab
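With the upstream change, the tokenizer module no longer accumulates a token_counts attribute of its own, so this generator now tallies token occurrences itself. A minimal standalone sketch of the new pattern, assuming the upstream module exposes encode() as a plain function that returns a list of unicode token strings:

from collections import defaultdict

from tensor2tensor.data_generators import tokenizer

# Tally token occurrences the way the updated generator does.
token_counts = defaultdict(int)
for tok in tokenizer.encode(u"count these tokens, then count them again"):
  token_counts[tok] += 1
print(dict(token_counts))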

10 changes: 5 additions & 5 deletions tensor2tensor/data_generators/text_encoder.py
@@ -54,7 +54,7 @@
PAD_TOKEN = RESERVED_TOKENS.index(PAD)  # Normally 0
EOS_TOKEN = RESERVED_TOKENS.index(EOS)  # Normally 1

-if six.PY2:
+if PY2:
  RESERVED_TOKENS_BYTES = RESERVED_TOKENS
else:
  RESERVED_TOKENS_BYTES = [bytes(PAD, "ascii"), bytes(EOS, "ascii")]
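The byte-string variants exist so that decode below can join reserved tokens with raw UTF-8 bytes under Python 3. As an illustration, assuming PAD and EOS keep their usual tensor2tensor values of "<pad>" and "<EOS>" (they are defined above this excerpt):

assert RESERVED_TOKENS_BYTES == [b"<pad>", b"<EOS>"]  # assumed PAD/EOS values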
@@ -110,7 +110,7 @@ class ByteTextEncoder(TextEncoder):

  def encode(self, s):
    numres = self._num_reserved_ids
-    if six.PY2:
+    if PY2:
      return [ord(c) + numres for c in s]
    # Python3: explicitly convert to UTF-8
    return [c + numres for c in s.encode("utf-8")]
@@ -124,7 +124,7 @@ def decode(self, ids):
        decoded_ids.append(RESERVED_TOKENS_BYTES[int(id_)])
      else:
        decoded_ids.append(int2byte(id_ - numres))
-    if six.PY2:
+    if PY2:
      return "".join(decoded_ids)
    # Python3: join byte arrays and then decode string
    return b"".join(decoded_ids).decode("utf-8", "replace")
@@ -469,7 +469,7 @@ def store_to_file(self, filename):
f.write("'" + unicode_to_native(subtoken_string) + "'\n")

def _escape_token(self, token):
r"""Escape away underscores and OOV characters and append '_'.
"""Escape away underscores and OOV characters and append '_'.
This allows the token to be experessed as the concatenation of a list
of subtokens from the vocabulary. The underscore acts as a sentinel
@@ -491,7 +491,7 @@ def _escape_token(self, token):
    return ret

  def _unescape_token(self, escaped_token):
-    r"""Inverse of _escape_token().
+    """Inverse of _escape_token().

    Args:
      escaped_token: a unicode string
1 change: 1 addition & 0 deletions tensor2tensor/data_generators/tokenizer.py
File mode changed from 100644 to 100755.
@@ -141,6 +141,7 @@ def read_corpus():
      if corpus_max_lines > 0 and lines_read > corpus_max_lines:
        return docs
    return docs
+
  counts = defaultdict(int)
  for doc in read_corpus():
    for tok in encode(_native_to_unicode(doc)):
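The loop above calls the tokenizer's module-level encode() directly, since the Tokenizer class was removed upstream. A hedged sketch of that module API; the example tokenization follows the module's docstring and is an assumption, not verified output:

from tensor2tensor.data_generators import tokenizer

toks = tokenizer.encode(u"Dude - that's so cool.")
# roughly [u"Dude", u" - ", u"that", u"'", u"s", u"so", u"cool", u"."]
assert tokenizer.decode(toks) == u"Dude - that's so cool."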
