Skip to content
This repository has been archived by the owner on Jul 7, 2023. It is now read-only.

Commit

Permalink
Add an option to score files to t2t_decoder.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 191769234
  • Loading branch information
Lukasz Kaiser authored and Ryan Sepassi committed Apr 5, 2018
1 parent b951c79 commit 6eea0e2
Showing 1 changed file with 72 additions and 0 deletions.
72 changes: 72 additions & 0 deletions tensor2tensor/bin/t2t_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@
# Dependency imports

from tensor2tensor.bin import t2t_trainer
from tensor2tensor.data_generators import text_encoder
from tensor2tensor.utils import decoding
from tensor2tensor.utils import registry
from tensor2tensor.utils import trainer_lib
from tensor2tensor.utils import usr_dir

Expand All @@ -59,6 +61,8 @@
flags.DEFINE_bool("decode_interactive", False,
"Interactive local inference mode.")
flags.DEFINE_integer("decode_shards", 1, "Number of decoding replicas.")
flags.DEFINE_string("score_file", "", "File to score. Each line in the file "
"must be in the format input \t target.")


def create_hparams():
Expand Down Expand Up @@ -96,12 +100,80 @@ def decode(estimator, hparams, decode_hp):
dataset_split="test" if FLAGS.eval_use_test_set else None)


def score_file(filename):
"""Score each line in a file and return the scores."""
# Prepare model.
hparams = create_hparams()
encoders = registry.problem(FLAGS.problems).feature_encoders(FLAGS.data_dir)
has_inputs = "inputs" in encoders

# Prepare features for feeding into the model.
if has_inputs:
inputs_ph = tf.placeholder(dtype=tf.int32) # Just length dimension.
batch_inputs = tf.reshape(inputs_ph, [1, -1, 1, 1]) # Make it 4D.
targets_ph = tf.placeholder(dtype=tf.int32) # Just length dimension.
batch_targets = tf.reshape(targets_ph, [1, -1, 1, 1]) # Make it 4D.
features = {
"inputs": batch_inputs,
"targets": batch_targets,
} if has_inputs else {"targets": batch_targets}

# Prepare the model and the graph when model runs on features.
model = registry.model(FLAGS.model)(hparams, tf.estimator.ModeKeys.EVAL)
_, losses = model(features)
saver = tf.train.Saver()

with tf.Session() as sess:
# Load weights from checkpoint.
ckpts = tf.train.get_checkpoint_state(FLAGS.output_dir)
ckpt = ckpts.model_checkpoint_path
saver.restore(sess, ckpt)
# Run on each line.
results = []
for line in open(filename):
tab_split = line.split("\t")
if len(tab_split) > 2:
raise ValueError("Each line must have at most one tab separator.")
if len(tab_split) == 1:
targets = tab_split[0].strip()
else:
targets = tab_split[1].strip()
inputs = tab_split[0].strip()
# Run encoders and append EOS symbol.
targets_numpy = encoders["targets"].encode(
targets) + [text_encoder.EOS_ID]
if has_inputs:
inputs_numpy = encoders["inputs"].encode(inputs) + [text_encoder.EOS_ID]
# Prepare the feed.
feed = {
inputs_ph: inputs_numpy,
targets_ph: targets_numpy
} if has_inputs else {targets_ph: targets_numpy}
# Get the score.
np_loss = sess.run(losses["training"], feed)
results.append(np_loss)
return results


def main(_):
tf.logging.set_verbosity(tf.logging.INFO)
trainer_lib.set_random_seed(FLAGS.random_seed)
usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
FLAGS.use_tpu = False # decoding not supported on TPU

if FLAGS.score_file:
filename = os.path.expanduser(FLAGS.score_file)
if not tf.gfile.Exists(filename):
raise ValueError("The file to score doesn't exist: %s" % filename)
results = score_file(filename)
if not FLAGS.decode_to_file:
raise ValueError("To score a file, specify --decode_to_file for results.")
write_file = open(os.path.expanduser(FLAGS.decode_to_file), "w")
for score in results:
write_file.write("%.6f\n" % score)
write_file.close()
return

hp = create_hparams()
decode_hp = create_decode_hparams()

Expand Down

0 comments on commit 6eea0e2

Please sign in to comment.