This repository has been archived by the owner on Jul 7, 2023. It is now read-only.

Commit

Merge pull request #162 from vthorsteinsson/ice
Bug fixes in inference and data generation; faster token unescaping
lukaszkaiser authored Jul 18, 2017
2 parents 1368b00 + e2ed8ed commit c91989c
Showing 6 changed files with 30 additions and 36 deletions.
Empty file modified tensor2tensor/bin/t2t-datagen
100644 → 100755
Empty file.
Empty file modified tensor2tensor/bin/t2t-trainer
100644 → 100755
Empty file.
7 changes: 4 additions & 3 deletions tensor2tensor/data_generators/generator_utils.py
100644 → 100755
@@ -329,18 +329,19 @@ def get_or_generate_tabbed_vocab(tmp_dir, source_filename,
     return vocab

   # Use Tokenizer to count the word occurrences.
+  token_counts = defaultdict(int)
   filepath = os.path.join(tmp_dir, source_filename)
   with tf.gfile.GFile(filepath, mode="r") as source_file:
     for line in source_file:
       line = line.strip()
       if line and "\t" in line:
         parts = line.split("\t", maxsplit=1)
         part = parts[index].strip()
-        _ = tokenizer.encode(text_encoder.native_to_unicode(part))
+        for tok in tokenizer.encode(text_encoder.native_to_unicode(part)):
+          token_counts[tok] += 1

   vocab = text_encoder.SubwordTextEncoder.build_to_target_size(
-      vocab_size, tokenizer.token_counts, 1,
-      min(1e3, vocab_size + text_encoder.NUM_RESERVED_TOKENS))
+      vocab_size, token_counts, 1, 1e3)
   vocab.store_to_file(vocab_filepath)
   return vocab

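The change above replaces counts accumulated inside the tokenizer with a local token_counts dict filled while streaming the tab-separated file. A minimal standalone sketch of that pattern, with a plain whitespace split standing in for tokenizer.encode() and a hypothetical file path:

from collections import defaultdict

def count_tabbed_tokens(filepath, index=0):
  """Count token occurrences in column `index` of a tab-separated file."""
  token_counts = defaultdict(int)
  with open(filepath, encoding="utf-8") as source_file:
    for line in source_file:
      line = line.strip()
      if line and "\t" in line:
        parts = line.split("\t", maxsplit=1)
        part = parts[index].strip()
        for tok in part.split():  # stand-in for tokenizer.encode()
          token_counts[tok] += 1
  return token_counts

# e.g. counts = count_tabbed_tokens("train.tsv", index=0)  # hypothetical file
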
48 changes: 19 additions & 29 deletions tensor2tensor/data_generators/text_encoder.py
@@ -24,6 +24,7 @@
 from __future__ import print_function

 from collections import defaultdict
+import re

 # Dependency imports

@@ -225,6 +226,7 @@ class SubwordTextEncoder(TextEncoder):

   def __init__(self, filename=None):
     """Initialize and read from a file, if provided."""
+    self._alphabet = set()
     if filename is not None:
       self._load_from_file(filename)
     super(SubwordTextEncoder, self).__init__(num_reserved_ids=None)
@@ -503,6 +505,12 @@ def _escape_token(self, token):
         ret += u"\\%d;" % ord(c)
     return ret

+  # Regular expression for unescaping token strings
+  # '\u' is converted to '_'
+  # '\\' is converted to '\'
+  # '\213;' is converted to unichr(213)
+  _UNESCAPE_REGEX = re.compile(u'|'.join([r"\\u", r"\\\\", r"\\([0-9]+);"]))
+
   def _unescape_token(self, escaped_token):
     """Inverse of _escape_token().
@@ -511,32 +519,14 @@ def _unescape_token(self, escaped_token):
     Returns:
       token: a unicode string
     """
-    ret = u""
-    escaped_token = escaped_token[:-1]
-    pos = 0
-    while pos < len(escaped_token):
-      c = escaped_token[pos]
-      if c == "\\":
-        pos += 1
-        if pos >= len(escaped_token):
-          break
-        c = escaped_token[pos]
-        if c == u"u":
-          ret += u"_"
-          pos += 1
-        elif c == "\\":
-          ret += u"\\"
-          pos += 1
-        else:
-          semicolon_pos = escaped_token.find(u";", pos)
-          if semicolon_pos == -1:
-            continue
-          try:
-            ret += unichr(int(escaped_token[pos:semicolon_pos]))
-            pos = semicolon_pos + 1
-          except (ValueError, OverflowError) as _:
-            pass
-      else:
-        ret += c
-        pos += 1
-    return ret
+    def match(m):
+      if m.group(1) is not None:
+        # Convert '\213;' to unichr(213)
+        try:
+          return unichr(int(m.group(1)))
+        except (ValueError, OverflowError) as _:
+          return ""
+      # Convert '\u' to '_' and '\\' to '\'
+      return u"_" if m.group(0) == u"\\u" else u"\\"
+    # Cut off the trailing underscore and apply the regex substitution
+    return self._UNESCAPE_REGEX.sub(match, escaped_token[:-1])
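For reference, the new single-pass substitution can be exercised outside the class; a minimal Python 3 sketch (chr() standing in for the Python 2 unichr() used above), assuming the same escape scheme with a trailing '_' delimiter:

import re

_UNESCAPE_REGEX = re.compile(r"\\u|\\\\|\\([0-9]+);")

def unescape_token(escaped_token):
  """Inverse of the escaping scheme sketched above."""
  def match(m):
    if m.group(1) is not None:
      # '\213;' -> chr(213)
      try:
        return chr(int(m.group(1)))
      except (ValueError, OverflowError):
        return ""
    # '\u' -> '_' and '\\' -> '\'
    return "_" if m.group(0) == "\\u" else "\\"
  # Drop the trailing '_' delimiter, then substitute every escape in one pass.
  return _UNESCAPE_REGEX.sub(match, escaped_token[:-1])

# Example: unescape_token("hello\\uworld_") returns "hello_world".
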
2 changes: 1 addition & 1 deletion tensor2tensor/data_generators/tokenizer_test.py
100644 → 100755
@@ -1,3 +1,4 @@
+# -*- coding: utf-8 -*-
 # Copyright 2017 The Tensor2Tensor Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-# coding=utf-8
 """Tests for tensor2tensor.data_generators.tokenizer."""

 from __future__ import absolute_import
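The test-file change above moves the encoding declaration to the top of the file: per PEP 263, Python only honors a coding comment on the first or second line, so a '# coding=utf-8' placed after the license block has no effect. For illustration, a correctly placed declaration in a hypothetical module:

# -*- coding: utf-8 -*-
"""A module whose source contains non-ASCII literals."""
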
9 changes: 6 additions & 3 deletions tensor2tensor/utils/trainer_utils.py
100644 → 100755
@@ -585,6 +585,7 @@ def decode_from_dataset(estimator):
   tf.logging.info("Performing local inference.")
   infer_problems_data = get_datasets_for_mode(hparams.data_dir,
                                               tf.contrib.learn.ModeKeys.INFER)
+
   infer_input_fn = get_input_fn(
       mode=tf.contrib.learn.ModeKeys.INFER,
       hparams=hparams,
@@ -625,9 +626,11 @@ def log_fn(inputs,

   # The function predict() returns an iterable over the network's
   # predictions from the test input. We use it to log inputs and decodes.
-  for j, result in enumerate(result_iter):
-    inputs, targets, outputs = (result["inputs"], result["targets"],
-                                result["outputs"])
+  inputs_iter = result_iter["inputs"]
+  targets_iter = result_iter["targets"]
+  outputs_iter = result_iter["outputs"]
+  for j, result in enumerate(zip(inputs_iter, targets_iter, outputs_iter)):
+    inputs, targets, outputs = result
     if FLAGS.decode_return_beams:
       output_beams = np.split(outputs, FLAGS.decode_beam_size, axis=0)
       for k, beam in enumerate(output_beams):
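The inference fix above treats the prediction result as a mapping from output names to per-example iterators and zips those streams back into per-example triples. A self-contained sketch of that re-grouping, with result_iter replaced by a hypothetical dict of iterators:

import numpy as np

# Hypothetical stand-in for a dict of per-output iterators.
result_iter = {
    "inputs": iter([np.array([1, 2]), np.array([3, 4])]),
    "targets": iter([np.array([5]), np.array([6])]),
    "outputs": iter([np.array([7]), np.array([8])]),
}

inputs_iter = result_iter["inputs"]
targets_iter = result_iter["targets"]
outputs_iter = result_iter["outputs"]
for j, (inputs, targets, outputs) in enumerate(
    zip(inputs_iter, targets_iter, outputs_iter)):
  print(j, inputs, targets, outputs)  # one (inputs, targets, outputs) triple per example
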
