Skip to content
This repository has been archived by the owner on Jul 7, 2023. It is now read-only.

Commit

Permalink
Bug fixes in generator_utils and trainer_utils
Browse files Browse the repository at this point in the history
  • Loading branch information
vthorsteinsson committed Jul 17, 2017
1 parent d8d379c commit e2ed8ed
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
7 changes: 4 additions & 3 deletions tensor2tensor/data_generators/generator_utils.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -324,18 +324,19 @@ def get_or_generate_tabbed_vocab(tmp_dir, source_filename,
return vocab

# Use Tokenizer to count the word occurrences.
token_counts = defaultdict(int)
filepath = os.path.join(tmp_dir, source_filename)
with tf.gfile.GFile(filepath, mode="r") as source_file:
for line in source_file:
line = line.strip()
if line and "\t" in line:
parts = line.split("\t", maxsplit=1)
part = parts[index].strip()
_ = tokenizer.encode(text_encoder.native_to_unicode(part))
for tok in tokenizer.encode(text_encoder.native_to_unicode(part)):
token_counts[tok] += 1

vocab = text_encoder.SubwordTextEncoder.build_to_target_size(
vocab_size, tokenizer.token_counts, 1,
min(1e3, vocab_size + text_encoder.NUM_RESERVED_TOKENS))
vocab_size, token_counts, 1, 1e3)
vocab.store_to_file(vocab_filepath)
return vocab

Expand Down
9 changes: 6 additions & 3 deletions tensor2tensor/utils/trainer_utils.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,7 @@ def decode_from_dataset(estimator):
tf.logging.info("Performing local inference.")
infer_problems_data = get_datasets_for_mode(hparams.data_dir,
tf.contrib.learn.ModeKeys.INFER)

infer_input_fn = get_input_fn(
mode=tf.contrib.learn.ModeKeys.INFER,
hparams=hparams,
Expand Down Expand Up @@ -625,9 +626,11 @@ def log_fn(inputs,

# The function predict() returns an iterable over the network's
# predictions from the test input. We use it to log inputs and decodes.
for j, result in enumerate(result_iter):
inputs, targets, outputs = (result["inputs"], result["targets"],
result["outputs"])
inputs_iter = result_iter["inputs"]
targets_iter = result_iter["targets"]
outputs_iter = result_iter["outputs"]
for j, result in enumerate(zip(inputs_iter, targets_iter, outputs_iter)):
inputs, targets, outputs = result
if FLAGS.decode_return_beams:
output_beams = np.split(outputs, FLAGS.decode_beam_size, axis=0)
for k, beam in enumerate(output_beams):
Expand Down

0 comments on commit e2ed8ed

Please sign in to comment.