diff --git a/tensor2tensor/data_generators/wmt.py b/tensor2tensor/data_generators/wmt.py
index f5c68f3ea..2d43d1739 100755
--- a/tensor2tensor/data_generators/wmt.py
+++ b/tensor2tensor/data_generators/wmt.py
@@ -38,9 +38,8 @@
 FLAGS = tf.flags.FLAGS
 
 
-# End-of-sentence marker (should correspond to the position of EOS in the
-# RESERVED_TOKENS list in text_encoder.py)
-EOS = 1
+# End-of-sentence marker
+EOS = text_encoder.EOS_TOKEN
 
 
 def character_generator(source_path, target_path, character_vocab, eos=None):
@@ -183,7 +182,7 @@ def ende_bpe_token_generator(tmp_dir, train):
   train_path = _get_wmt_ende_dataset(tmp_dir, dataset_path)
   token_path = os.path.join(tmp_dir, "vocab.bpe.32000")
   token_vocab = text_encoder.TokenTextEncoder(vocab_filename=token_path)
-  return token_generator(train_path + ".en", train_path + ".de", token_vocab, 1)
+  return token_generator(train_path + ".en", train_path + ".de", token_vocab, EOS)
 
 
 _ENDE_TRAIN_DATASETS = [