This repository has been archived by the owner on Jul 7, 2023. It is now read-only.

Commit b88c13b

Correct text encoder, MultiModel, other merges.

PiperOrigin-RevId: 161036600

Lukasz Kaiser committed Jul 6, 2017
1 parent 5adadf3 · commit b88c13b

Showing 11 changed files with 796 additions and 199 deletions.

setup.py (2 changes: 1 addition & 1 deletion)

@@ -5,7 +5,7 @@

 setup(
     name='tensor2tensor',
-    version='1.0.9',
+    version='1.0.10',
     description='Tensor2Tensor',
     author='Google Inc.',
     author_email='no-reply@google.com',

tensor2tensor/data_generators/generator_utils.py (17 changes: 10 additions & 7 deletions)

@@ -29,7 +29,7 @@
 from six.moves import xrange  # pylint: disable=redefined-builtin
 import six.moves.urllib_request as urllib  # Imports urllib on Python2, urllib.request on Python3

-from tensor2tensor.data_generators.text_encoder import SubwordTextEncoder
+from tensor2tensor.data_generators import text_encoder
 from tensor2tensor.data_generators.tokenizer import Tokenizer

 import tensorflow as tf

@@ -218,15 +218,18 @@ def gunzip_file(gz_path, new_path):
 ]


-def get_or_generate_vocab(tmp_dir, vocab_filename, vocab_size):
-  """Generate a vocabulary from the datasets listed in _DATA_FILE_URLS."""
+def get_or_generate_vocab(tmp_dir, vocab_filename, vocab_size, sources=None):
+  """Generate a vocabulary from the datasets in sources (_DATA_FILE_URLS)."""
   vocab_filepath = os.path.join(tmp_dir, vocab_filename)
   if os.path.exists(vocab_filepath):
-    vocab = SubwordTextEncoder(vocab_filepath)
+    tf.logging.info("Found vocab file: %s", vocab_filepath)
+    vocab = text_encoder.SubwordTextEncoder(vocab_filepath)
     return vocab

+  sources = sources or _DATA_FILE_URLS
+  tf.logging.info("Generating vocab from: %s", str(sources))
   tokenizer = Tokenizer()
-  for source in _DATA_FILE_URLS:
+  for source in sources:
     url = source[0]
     filename = os.path.basename(url)
     read_type = "r:gz" if "tgz" in filename else "r"

@@ -259,9 +262,9 @@ def get_or_generate_vocab(tmp_dir, vocab_filename, vocab_size):
            break
          line = line.strip()
          file_byte_budget -= len(line)
-          _ = tokenizer.encode(line)
+          _ = tokenizer.encode(text_encoder.native_to_unicode(line))

-  vocab = SubwordTextEncoder.build_to_target_size(
+  vocab = text_encoder.SubwordTextEncoder.build_to_target_size(
       vocab_size, tokenizer.token_counts, 1, 1e3)
   vocab.store_to_file(vocab_filepath)
   return vocab
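
The new optional `sources` argument lets callers build a vocabulary from their own corpora instead of the hard-coded _DATA_FILE_URLS list. A minimal sketch of the new call pattern, assuming each source pairs a download URL with the file names to read after extraction, as the _DATA_FILE_URLS entries do; the URL, file names, and sizes below are placeholders:

from tensor2tensor.data_generators import generator_utils

# Hypothetical corpus description: [archive URL, [files to read from it]].
my_sources = [
    ["http://example.com/parallel-corpus.tgz",
     ["corpus.train.en", "corpus.train.de"]],
]

# Downloads and tokenizes the sources, then builds a SubwordTextEncoder of
# roughly 2**15 subtokens; if vocab.subword already exists in tmp_dir, the
# cached vocabulary is loaded instead.
vocab = generator_utils.get_or_generate_vocab(
    tmp_dir="/tmp/t2t_datagen",
    vocab_filename="vocab.subword",
    vocab_size=2**15,
    sources=my_sources)  # omit `sources` to fall back to _DATA_FILE_URLS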

tensor2tensor/data_generators/text_encoder.py (17 changes: 10 additions & 7 deletions)

@@ -36,10 +36,10 @@


 # Conversion between Unicode and UTF-8, if required (on Python2)
-_native_to_unicode = (lambda s: s.decode("utf-8")) if PY2 else (lambda s: s)
+native_to_unicode = (lambda s: s.decode("utf-8")) if PY2 else (lambda s: s)


-_unicode_to_native = (lambda s: s.encode("utf-8")) if PY2 else (lambda s: s)
+unicode_to_native = (lambda s: s.encode("utf-8")) if PY2 else (lambda s: s)


 # Reserved tokens for things like padding and EOS symbols.

@@ -220,7 +220,7 @@ def encode(self, raw_text):
       a list of integers in the range [0, vocab_size)
     """
     return self._tokens_to_subtokens(self._tokenizer.encode(
-        _native_to_unicode(raw_text)))
+        native_to_unicode(raw_text)))

   def decode(self, subtokens):
     """Converts a sequence of subtoken ids to a native string.

@@ -230,7 +230,7 @@ def decode(self, subtokens):
     Returns:
       a native string
     """
-    return _unicode_to_native(self._tokenizer.decode(
+    return unicode_to_native(self._tokenizer.decode(
         self._subtokens_to_tokens(subtokens)))

   @property
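
Dropping the leading underscores makes the two conversion helpers importable from other modules; generator_utils.py starts using native_to_unicode in this same commit. A minimal sketch of what they do on each Python version, following the lambdas shown above; the example string is arbitrary:

from tensor2tensor.data_generators import text_encoder

# A "native" string: UTF-8 bytes on Python 2 (assuming a UTF-8 source file),
# already-unicode text on Python 3.
line = "Motörhead"

# Python 2: the bytes are decoded from UTF-8 into a unicode object.
# Python 3: the identity function, since str is already unicode.
as_unicode = text_encoder.native_to_unicode(line)

# The inverse mapping returns the original native string.
assert text_encoder.unicode_to_native(as_unicode) == line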

@@ -335,6 +335,9 @@ def bisect(min_val, max_val):
       else:
         other_subtokenizer = bisect(min_val, present_count - 1)

+      if other_subtokenizer is None:
+        return subtokenizer
+
       if (abs(other_subtokenizer.vocab_size - target_size) <
           abs(subtokenizer.vocab_size - target_size)):
         return other_subtokenizer

@@ -449,13 +452,13 @@ def _load_from_file(self, filename):
     subtoken_strings = []
     with tf.gfile.Open(filename) as f:
       for line in f:
-        subtoken_strings.append(_native_to_unicode(line.strip()[1:-1]))
+        subtoken_strings.append(native_to_unicode(line.strip()[1:-1]))
     self._init_from_list(subtoken_strings)

   def store_to_file(self, filename):
     with tf.gfile.Open(filename, "w") as f:
       for subtoken_string in self._all_subtoken_strings:
-        f.write("'" + _unicode_to_native(subtoken_string) + "'\n")
+        f.write("'" + unicode_to_native(subtoken_string) + "'\n")

   def _escape_token(self, token):
     r"""Escape away underscores and OOV characters and append '_'.

@@ -524,7 +527,7 @@ def get_token_counts(cls, text_filepattern, corpus_max_lines):
       with tf.gfile.Open(text_filename) as f:
         for line in f:
           # The tokenizer updates token_counts in encode()
-          tok.encode(_native_to_unicode(line.strip()))
+          tok.encode(native_to_unicode(line.strip()))
           lines_read += 1
           if corpus_max_lines > 0 and lines_read > corpus_max_lines:
             return tok.token_counts
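
Together, these changes keep SubwordTextEncoder unicode-internal while encode/decode accept and return native strings, and build_to_target_size no longer crashes when the bisection over the minimum token count bottoms out. A small sketch of that path, assuming a token_counts dict of the kind Tokenizer produces; the tokens and counts below are made up:

from tensor2tensor.data_generators import text_encoder

# Made-up unicode token counts, standing in for Tokenizer.token_counts.
token_counts = {
    u"low": 5, u"lower": 2, u"newest": 6, u"widest": 3, u" ": 20,
}

# Binary search (the bisect() helper patched above) over the minimum token
# count; with so few tokens the target size is unreachable, which is the
# case the new `if other_subtokenizer is None` guard handles.
encoder = text_encoder.SubwordTextEncoder.build_to_target_size(
    100, token_counts, 1, 1e3)

# encode() converts the native string with native_to_unicode() before
# tokenizing; decode() converts back with unicode_to_native().
ids = encoder.encode("low widest newest")
assert encoder.decode(ids) == "low widest newest"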