diff --git a/tensor2tensor/bin/t2t_decoder.py b/tensor2tensor/bin/t2t_decoder.py index 5bd947f93..fd103a6a1 100644 --- a/tensor2tensor/bin/t2t_decoder.py +++ b/tensor2tensor/bin/t2t_decoder.py @@ -37,7 +37,9 @@ # Dependency imports from tensor2tensor.bin import t2t_trainer +from tensor2tensor.data_generators import text_encoder from tensor2tensor.utils import decoding +from tensor2tensor.utils import registry from tensor2tensor.utils import trainer_lib from tensor2tensor.utils import usr_dir @@ -59,6 +61,8 @@ flags.DEFINE_bool("decode_interactive", False, "Interactive local inference mode.") flags.DEFINE_integer("decode_shards", 1, "Number of decoding replicas.") +flags.DEFINE_string("score_file", "", "File to score. Each line in the file " + "must be in the format input \t target.") def create_hparams(): @@ -96,12 +100,80 @@ def decode(estimator, hparams, decode_hp): dataset_split="test" if FLAGS.eval_use_test_set else None) +def score_file(filename): + """Score each line in a file and return the scores.""" + # Prepare model. + hparams = create_hparams() + encoders = registry.problem(FLAGS.problems).feature_encoders(FLAGS.data_dir) + has_inputs = "inputs" in encoders + + # Prepare features for feeding into the model. + if has_inputs: + inputs_ph = tf.placeholder(dtype=tf.int32) # Just length dimension. + batch_inputs = tf.reshape(inputs_ph, [1, -1, 1, 1]) # Make it 4D. + targets_ph = tf.placeholder(dtype=tf.int32) # Just length dimension. + batch_targets = tf.reshape(targets_ph, [1, -1, 1, 1]) # Make it 4D. + features = { + "inputs": batch_inputs, + "targets": batch_targets, + } if has_inputs else {"targets": batch_targets} + + # Prepare the model and the graph when model runs on features. + model = registry.model(FLAGS.model)(hparams, tf.estimator.ModeKeys.EVAL) + _, losses = model(features) + saver = tf.train.Saver() + + with tf.Session() as sess: + # Load weights from checkpoint. + ckpts = tf.train.get_checkpoint_state(FLAGS.output_dir) + ckpt = ckpts.model_checkpoint_path + saver.restore(sess, ckpt) + # Run on each line. + results = [] + for line in open(filename): + tab_split = line.split("\t") + if len(tab_split) > 2: + raise ValueError("Each line must have at most one tab separator.") + if len(tab_split) == 1: + targets = tab_split[0].strip() + else: + targets = tab_split[1].strip() + inputs = tab_split[0].strip() + # Run encoders and append EOS symbol. + targets_numpy = encoders["targets"].encode( + targets) + [text_encoder.EOS_ID] + if has_inputs: + inputs_numpy = encoders["inputs"].encode(inputs) + [text_encoder.EOS_ID] + # Prepare the feed. + feed = { + inputs_ph: inputs_numpy, + targets_ph: targets_numpy + } if has_inputs else {targets_ph: targets_numpy} + # Get the score. + np_loss = sess.run(losses["training"], feed) + results.append(np_loss) + return results + + def main(_): tf.logging.set_verbosity(tf.logging.INFO) trainer_lib.set_random_seed(FLAGS.random_seed) usr_dir.import_usr_dir(FLAGS.t2t_usr_dir) FLAGS.use_tpu = False # decoding not supported on TPU + if FLAGS.score_file: + filename = os.path.expanduser(FLAGS.score_file) + if not tf.gfile.Exists(filename): + raise ValueError("The file to score doesn't exist: %s" % filename) + results = score_file(filename) + if not FLAGS.decode_to_file: + raise ValueError("To score a file, specify --decode_to_file for results.") + write_file = open(os.path.expanduser(FLAGS.decode_to_file), "w") + for score in results: + write_file.write("%.6f\n" % score) + write_file.close() + return + hp = create_hparams() decode_hp = create_decode_hparams()