From e2ed8ed3b55f64c05688cb8852f465131140fa2e Mon Sep 17 00:00:00 2001
From: vthorsteinsson
Date: Mon, 17 Jul 2017 17:04:31 +0000
Subject: [PATCH] Bug fixes in generator_utils and trainer_utils

---
 tensor2tensor/data_generators/generator_utils.py | 7 ++++---
 tensor2tensor/utils/trainer_utils.py             | 9 ++++++---
 2 files changed, 10 insertions(+), 6 deletions(-)
 mode change 100644 => 100755 tensor2tensor/data_generators/generator_utils.py
 mode change 100644 => 100755 tensor2tensor/utils/trainer_utils.py

diff --git a/tensor2tensor/data_generators/generator_utils.py b/tensor2tensor/data_generators/generator_utils.py
old mode 100644
new mode 100755
index 890f92c2a..cacad12fc
--- a/tensor2tensor/data_generators/generator_utils.py
+++ b/tensor2tensor/data_generators/generator_utils.py
@@ -324,6 +324,7 @@ def get_or_generate_tabbed_vocab(tmp_dir, source_filename,
     return vocab

   # Use Tokenizer to count the word occurrences.
+  token_counts = defaultdict(int)
   filepath = os.path.join(tmp_dir, source_filename)
   with tf.gfile.GFile(filepath, mode="r") as source_file:
     for line in source_file:
@@ -331,11 +332,11 @@
       if line and "\t" in line:
         parts = line.split("\t", maxsplit=1)
         part = parts[index].strip()
-        _ = tokenizer.encode(text_encoder.native_to_unicode(part))
+        for tok in tokenizer.encode(text_encoder.native_to_unicode(part)):
+          token_counts[tok] += 1

   vocab = text_encoder.SubwordTextEncoder.build_to_target_size(
-      vocab_size, tokenizer.token_counts, 1,
-      min(1e3, vocab_size + text_encoder.NUM_RESERVED_TOKENS))
+      vocab_size, token_counts, 1, 1e3)
   vocab.store_to_file(vocab_filepath)
   return vocab

diff --git a/tensor2tensor/utils/trainer_utils.py b/tensor2tensor/utils/trainer_utils.py
old mode 100644
new mode 100755
index b5894904d..66a01487c
--- a/tensor2tensor/utils/trainer_utils.py
+++ b/tensor2tensor/utils/trainer_utils.py
@@ -585,6 +585,7 @@ def decode_from_dataset(estimator):
   tf.logging.info("Performing local inference.")
   infer_problems_data = get_datasets_for_mode(hparams.data_dir,
                                               tf.contrib.learn.ModeKeys.INFER)
+
   infer_input_fn = get_input_fn(
       mode=tf.contrib.learn.ModeKeys.INFER,
       hparams=hparams,
@@ -625,9 +626,11 @@ def log_fn(inputs,

   # The function predict() returns an iterable over the network's
   # predictions from the test input. We use it to log inputs and decodes.
-  for j, result in enumerate(result_iter):
-    inputs, targets, outputs = (result["inputs"], result["targets"],
-                                result["outputs"])
+  inputs_iter = result_iter["inputs"]
+  targets_iter = result_iter["targets"]
+  outputs_iter = result_iter["outputs"]
+  for j, result in enumerate(zip(inputs_iter, targets_iter, outputs_iter)):
+    inputs, targets, outputs = result
     if FLAGS.decode_return_beams:
       output_beams = np.split(outputs, FLAGS.decode_beam_size, axis=0)
       for k, beam in enumerate(output_beams):
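
Note on the generator_utils.py change: the old code called tokenizer.encode()
only for its side effect of updating the tokenizer module's global
token_counts, then read that shared state back when building the vocabulary.
The fix accumulates counts in a local defaultdict instead, so repeated vocab
generation no longer sees stale counts from earlier runs. The hunk assumes
"from collections import defaultdict" is already present in
generator_utils.py; if it is not, that import must be added too. A minimal
standalone sketch of the counting pattern, with a whitespace split standing
in for tensor2tensor's tokenizer:

    from collections import defaultdict

    def count_tokens(lines, encode=lambda s: s.split()):
      # encode() is a placeholder for tokenizer.encode().
      token_counts = defaultdict(int)
      for line in lines:
        for tok in encode(line.strip()):
          token_counts[tok] += 1
      return token_counts

    print(dict(count_tokens(["a b a", "b c"])))  # {'a': 2, 'b': 2, 'c': 1}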
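
Note on the build_to_target_size() call: the last two arguments (1 and 1e3)
bound a binary search over the minimum occurrence count a subword needs in
order to enter the vocabulary; the search settles on the threshold whose
resulting vocabulary size lands closest to vocab_size. A hedged sketch of
that search, where build_fn is a made-up stand-in for t2t's actual subword
builder, not its real API:

    def search_min_count(target_size, token_counts, min_val, max_val, build_fn):
      # build_fn(token_counts, min_count) -> vocabulary; hypothetical helper.
      best = None
      while min_val <= max_val:
        mid = (min_val + max_val) // 2
        vocab = build_fn(token_counts, mid)
        if best is None or abs(len(vocab) - target_size) < abs(len(best) - target_size):
          best = vocab
        if len(vocab) > target_size:
          min_val = mid + 1  # vocab too large: demand a higher count
        else:
          max_val = mid - 1  # vocab small enough: try a lower threshold
      return best

    # Toy usage: keep tokens whose count reaches the threshold.
    counts = {"a": 50, "b": 20, "c": 5, "d": 1}
    toy_build = lambda tc, m: [t for t, c in tc.items() if c >= m]
    print(search_min_count(2, counts, 1, 1000, toy_build))  # ['a', 'b']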
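
Note on the trainer_utils.py change: the old loop assumed predict() yields
one dict per example, so it indexed each result with ["inputs"]. The fix
suggests that result_iter here is instead a mapping from feature name to a
sequence of per-example values, so the new code pulls the three sequences
out once and zips them back into per-example (inputs, targets, outputs)
tuples. A minimal sketch of the pattern with made-up data (the values are
illustrative only, not t2t output):

    result_iter = {
        "inputs": ["in0", "in1"],
        "targets": ["tg0", "tg1"],
        "outputs": ["out0", "out1"],
    }
    inputs_iter = result_iter["inputs"]
    targets_iter = result_iter["targets"]
    outputs_iter = result_iter["outputs"]
    for j, (inputs, targets, outputs) in enumerate(
        zip(inputs_iter, targets_iter, outputs_iter)):
      print(j, inputs, targets, outputs)  # 0 in0 tg0 out0 / 1 in1 tg1 out1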