diff --git a/README.md b/README.md index 1fdd7e883..27bb47947 100644 --- a/README.md +++ b/README.md @@ -242,7 +242,7 @@ def transformer_my_very_own_hparams_set(): ```python # In ~/usr/t2t_usr/__init__.py -import my_registrations +from . import my_registrations ``` ``` diff --git a/setup.py b/setup.py index b70966986..00325cff2 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setup( name='tensor2tensor', - version='1.0.13', + version='1.0.14', description='Tensor2Tensor', author='Google Inc.', author_email='no-reply@google.com', diff --git a/tensor2tensor/__init__.py b/tensor2tensor/__init__.py index 27d533abc..eff6a2b14 100644 --- a/tensor2tensor/__init__.py +++ b/tensor2tensor/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensor2tensor/bin/t2t-datagen b/tensor2tensor/bin/t2t-datagen old mode 100755 new mode 100644 index a0d1454a4..44e4b34d3 --- a/tensor2tensor/bin/t2t-datagen +++ b/tensor2tensor/bin/t2t-datagen @@ -1,5 +1,5 @@ #!/usr/bin/env python -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,6 +24,9 @@ takes 2 arguments - input_directory and mode (one of "train" or "dev") - and yields for each training example a dictionary mapping string feature names to lists of {string, int, float}. The generator will be run once for each mode. """ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function import random import tempfile @@ -34,6 +37,7 @@ import numpy as np from tensor2tensor.data_generators import algorithmic from tensor2tensor.data_generators import algorithmic_math +from tensor2tensor.data_generators import all_problems # pylint: disable=unused-import from tensor2tensor.data_generators import audio from tensor2tensor.data_generators import generator_utils from tensor2tensor.data_generators import image @@ -43,6 +47,7 @@ from tensor2tensor.data_generators import snli from tensor2tensor.data_generators import wiki from tensor2tensor.data_generators import wmt from tensor2tensor.data_generators import wsj_parsing +from tensor2tensor.utils import registry import tensorflow as tf @@ -62,12 +67,6 @@ flags.DEFINE_integer("random_seed", 429459, "Random seed to use.") # Mapping from problems that we can generate data for to their generators. 
# pylint: disable=g-long-lambda _SUPPORTED_PROBLEM_GENERATORS = { - "algorithmic_identity_binary40": ( - lambda: algorithmic.identity_generator(2, 40, 100000), - lambda: algorithmic.identity_generator(2, 400, 10000)), - "algorithmic_identity_decimal40": ( - lambda: algorithmic.identity_generator(10, 40, 100000), - lambda: algorithmic.identity_generator(10, 400, 10000)), "algorithmic_shift_decimal40": ( lambda: algorithmic.shift_generator(20, 10, 40, 100000), lambda: algorithmic.shift_generator(20, 10, 80, 10000)), @@ -104,9 +103,9 @@ _SUPPORTED_PROBLEM_GENERATORS = { lambda: algorithmic_math.algebra_inverse(26, 3, 3, 10000)), "ice_parsing_tokens": ( lambda: wmt.tabbed_parsing_token_generator(FLAGS.tmp_dir, - True, "ice", 2**13, 2**8), + True, "ice", 2**13, 2**8), lambda: wmt.tabbed_parsing_token_generator(FLAGS.tmp_dir, - False, "ice", 2**13, 2**8)), + False, "ice", 2**13, 2**8)), "ice_parsing_characters": ( lambda: wmt.tabbed_parsing_character_generator(FLAGS.tmp_dir, True), lambda: wmt.tabbed_parsing_character_generator(FLAGS.tmp_dir, False)), @@ -118,11 +117,6 @@ _SUPPORTED_PROBLEM_GENERATORS = { 2**14, 2**9), lambda: wsj_parsing.parsing_token_generator(FLAGS.tmp_dir, False, 2**14, 2**9)), - "wsj_parsing_tokens_32k": ( - lambda: wsj_parsing.parsing_token_generator(FLAGS.tmp_dir, True, - 2**15, 2**9), - lambda: wsj_parsing.parsing_token_generator(FLAGS.tmp_dir, False, - 2**15, 2**9)), "wmt_enfr_characters": ( lambda: wmt.enfr_character_generator(FLAGS.tmp_dir, True), lambda: wmt.enfr_character_generator(FLAGS.tmp_dir, False)), @@ -140,14 +134,6 @@ _SUPPORTED_PROBLEM_GENERATORS = { "wmt_ende_bpe32k": ( lambda: wmt.ende_bpe_token_generator(FLAGS.tmp_dir, True), lambda: wmt.ende_bpe_token_generator(FLAGS.tmp_dir, False)), - "wmt_ende_tokens_8k": ( - lambda: wmt.ende_wordpiece_token_generator(FLAGS.tmp_dir, True, 2**13), - lambda: wmt.ende_wordpiece_token_generator(FLAGS.tmp_dir, False, 2**13) - ), - "wmt_ende_tokens_32k": ( - lambda: wmt.ende_wordpiece_token_generator(FLAGS.tmp_dir, True, 2**15), - lambda: wmt.ende_wordpiece_token_generator(FLAGS.tmp_dir, False, 2**15) - ), "wmt_zhen_tokens_32k": ( lambda: wmt.zhen_wordpiece_token_generator(FLAGS.tmp_dir, True, 2**15, 2**15), @@ -174,26 +160,9 @@ _SUPPORTED_PROBLEM_GENERATORS = { "image_cifar10_test": ( lambda: image.cifar10_generator(FLAGS.tmp_dir, True, 50000), lambda: image.cifar10_generator(FLAGS.tmp_dir, False, 10000)), - "image_mscoco_characters_tune": ( - lambda: image.mscoco_generator(FLAGS.tmp_dir, True, 70000), - lambda: image.mscoco_generator(FLAGS.tmp_dir, True, 10000, 70000)), "image_mscoco_characters_test": ( lambda: image.mscoco_generator(FLAGS.tmp_dir, True, 80000), lambda: image.mscoco_generator(FLAGS.tmp_dir, False, 40000)), - "image_mscoco_tokens_8k_tune": ( - lambda: image.mscoco_generator( - FLAGS.tmp_dir, - True, - 70000, - vocab_filename="tokens.vocab.%d" % 2**13, - vocab_size=2**13), - lambda: image.mscoco_generator( - FLAGS.tmp_dir, - True, - 10000, - 70000, - vocab_filename="tokens.vocab.%d" % 2**13, - vocab_size=2**13)), "image_mscoco_tokens_8k_test": ( lambda: image.mscoco_generator( FLAGS.tmp_dir, @@ -207,20 +176,6 @@ _SUPPORTED_PROBLEM_GENERATORS = { 40000, vocab_filename="tokens.vocab.%d" % 2**13, vocab_size=2**13)), - "image_mscoco_tokens_32k_tune": ( - lambda: image.mscoco_generator( - FLAGS.tmp_dir, - True, - 70000, - vocab_filename="tokens.vocab.%d" % 2**15, - vocab_size=2**15), - lambda: image.mscoco_generator( - FLAGS.tmp_dir, - True, - 10000, - 70000, - vocab_filename="tokens.vocab.%d" % 2**15, - 
vocab_size=2**15)), "image_mscoco_tokens_32k_test": ( lambda: image.mscoco_generator( FLAGS.tmp_dir, @@ -308,8 +263,6 @@ _SUPPORTED_PROBLEM_GENERATORS = { # pylint: enable=g-long-lambda -UNSHUFFLED_SUFFIX = "-unshuffled" - def set_random_seed(): """Set the random seed from flag everywhere.""" @@ -322,13 +275,15 @@ def main(_): tf.logging.set_verbosity(tf.logging.INFO) # Calculate the list of problems to generate. - problems = list(sorted(_SUPPORTED_PROBLEM_GENERATORS)) + problems = sorted( + list(_SUPPORTED_PROBLEM_GENERATORS) + registry.list_problems()) if FLAGS.problem and FLAGS.problem[-1] == "*": problems = [p for p in problems if p.startswith(FLAGS.problem[:-1])] elif FLAGS.problem: problems = [p for p in problems if p == FLAGS.problem] else: problems = [] + # Remove TIMIT if paths are not given. if not FLAGS.timit_paths: problems = [p for p in problems if "timit" not in p] @@ -340,7 +295,8 @@ def main(_): problems = [p for p in problems if "ende_bpe" not in p] if not problems: - problems_str = "\n * ".join(sorted(_SUPPORTED_PROBLEM_GENERATORS)) + problems_str = "\n * ".join( + sorted(list(_SUPPORTED_PROBLEM_GENERATORS) + registry.list_problems())) error_msg = ("You must specify one of the supported problems to " "generate data for:\n * " + problems_str + "\n") error_msg += ("TIMIT, ende_bpe and parsing need data_sets specified with " @@ -357,40 +313,50 @@ def main(_): for problem in problems: set_random_seed() - training_gen, dev_gen = _SUPPORTED_PROBLEM_GENERATORS[problem] - - if isinstance(dev_gen, int): - # The dev set and test sets are generated as extra shards using the - # training generator. The integer specifies the number of training - # shards. FLAGS.num_shards is ignored. - num_training_shards = dev_gen - tf.logging.info("Generating data for %s.", problem) - all_output_files = generator_utils.combined_data_filenames( - problem + UNSHUFFLED_SUFFIX, FLAGS.data_dir, num_training_shards) - generator_utils.generate_files( - training_gen(), all_output_files, FLAGS.max_cases) + if problem in _SUPPORTED_PROBLEM_GENERATORS: + generate_data_for_problem(problem) else: - # usual case - train data and dev data are generated using separate - # generators. - tf.logging.info("Generating training data for %s.", problem) - train_output_files = generator_utils.train_data_filenames( - problem + UNSHUFFLED_SUFFIX, FLAGS.data_dir, FLAGS.num_shards) - generator_utils.generate_files( - training_gen(), train_output_files, FLAGS.max_cases) - tf.logging.info("Generating development data for %s.", problem) - dev_shards = 10 if "coco" in problem else 1 - dev_output_files = generator_utils.dev_data_filenames( - problem + UNSHUFFLED_SUFFIX, FLAGS.data_dir, dev_shards) - generator_utils.generate_files(dev_gen(), dev_output_files) - all_output_files = train_output_files + dev_output_files + generate_data_for_registered_problem(problem) + + +def generate_data_for_problem(problem): + """Generate data for a problem in _SUPPORTED_PROBLEM_GENERATORS.""" + training_gen, dev_gen = _SUPPORTED_PROBLEM_GENERATORS[problem] + + if isinstance(dev_gen, int): + # The dev set and test sets are generated as extra shards using the + # training generator. The integer specifies the number of training + # shards. FLAGS.num_shards is ignored. 
+ num_training_shards = dev_gen + tf.logging.info("Generating data for %s.", problem) + all_output_files = generator_utils.combined_data_filenames( + problem + generator_utils.UNSHUFFLED_SUFFIX, FLAGS.data_dir, + num_training_shards) + generator_utils.generate_files(training_gen(), all_output_files, + FLAGS.max_cases) + else: + # usual case - train data and dev data are generated using separate + # generators. + tf.logging.info("Generating training data for %s.", problem) + train_output_files = generator_utils.train_data_filenames( + problem + generator_utils.UNSHUFFLED_SUFFIX, FLAGS.data_dir, + FLAGS.num_shards) + generator_utils.generate_files(training_gen(), train_output_files, + FLAGS.max_cases) + tf.logging.info("Generating development data for %s.", problem) + dev_shards = 10 if "coco" in problem else 1 + dev_output_files = generator_utils.dev_data_filenames( + problem + generator_utils.UNSHUFFLED_SUFFIX, FLAGS.data_dir, dev_shards) + generator_utils.generate_files(dev_gen(), dev_output_files) + all_output_files = train_output_files + dev_output_files + + tf.logging.info("Shuffling data...") + generator_utils.shuffle_dataset(all_output_files) + - tf.logging.info("Shuffling data...") - for fname in all_output_files: - records = generator_utils.read_records(fname) - random.shuffle(records) - out_fname = fname.replace(UNSHUFFLED_SUFFIX, "") - generator_utils.write_records(records, out_fname) - tf.gfile.Remove(fname) +def generate_data_for_registered_problem(problem_name): + problem = registry.problem(problem_name) + problem.generate_data(FLAGS.data_dir, FLAGS.tmp_dir) if __name__ == "__main__": diff --git a/tensor2tensor/bin/t2t-make-tf-configs b/tensor2tensor/bin/t2t-make-tf-configs index ae87ffbd8..6a4dc8641 100644 --- a/tensor2tensor/bin/t2t-make-tf-configs +++ b/tensor2tensor/bin/t2t-make-tf-configs @@ -1,5 +1,5 @@ #!/usr/bin/env python -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,13 +17,13 @@ Usage: -`t2t-make-tf-configs --workers="server1:1234" --ps="server3:2134,server4:2334"` +`t2t-make-tf-configs --masters="server1:1234" --ps="server3:2134,server4:2334"` -Outputs 1 line per job to stdout, first the workers, then the parameter servers. +Outputs 1 line per job to stdout, first the masters, then the parameter servers. Each line has the TF_CONFIG, then a tab, then the command line flags for that job. -If there is a single worker, workers will have the `--sync` flag. +If there is a single master, it will have the `--sync` flag. 
""" from __future__ import absolute_import from __future__ import division @@ -38,31 +38,32 @@ import tensorflow as tf flags = tf.flags FLAGS = flags.FLAGS -flags.DEFINE_string("workers", "", "Comma-separated list of worker addresses") +flags.DEFINE_string("masters", "", "Comma-separated list of master addresses") flags.DEFINE_string("ps", "", "Comma-separated list of ps addresses") def main(_): - if not (FLAGS.workers and FLAGS.ps): - raise ValueError("Must provide --workers and --ps") + if not (FLAGS.masters and FLAGS.ps): + raise ValueError("Must provide --masters and --ps") - workers = FLAGS.workers.split(",") + masters = FLAGS.masters.split(",") ps = FLAGS.ps.split(",") - cluster = {"ps": ps, "worker": workers} + cluster = {"ps": ps, "master": masters} - for task_type, jobs in (("worker", workers), ("ps", ps)): + for task_type, jobs in (("master", masters), ("ps", ps)): for idx, job in enumerate(jobs): - if task_type == "worker": + if task_type == "master": cmd_line_flags = " ".join([ "--master=grpc://%s" % job, "--ps_replicas=%d" % len(ps), - "--worker_replicas=%d" % len(workers), + "--worker_replicas=%d" % len(masters), "--worker_gpu=1", "--worker_id=%d" % idx, + "--worker_job='/job:master'", "--ps_gpu=1", "--schedule=train", - "--sync" if len(workers) == 1 else "", + "--sync" if len(masters) == 1 else "", ]) else: cmd_line_flags = " ".join([ diff --git a/tensor2tensor/bin/t2t-trainer b/tensor2tensor/bin/t2t-trainer old mode 100755 new mode 100644 index 92f671826..322957028 --- a/tensor2tensor/bin/t2t-trainer +++ b/tensor2tensor/bin/t2t-trainer @@ -1,5 +1,5 @@ #!/usr/bin/env python -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensor2tensor/data_generators/__init__.py b/tensor2tensor/data_generators/__init__.py index 27d533abc..eff6a2b14 100644 --- a/tensor2tensor/data_generators/__init__.py +++ b/tensor2tensor/data_generators/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensor2tensor/data_generators/algorithmic.py b/tensor2tensor/data_generators/algorithmic.py index 4cd14753b..6ec1f28a0 100644 --- a/tensor2tensor/data_generators/algorithmic.py +++ b/tensor2tensor/data_generators/algorithmic.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -23,12 +23,50 @@ from six.moves import xrange # pylint: disable=redefined-builtin +from tensor2tensor.data_generators import generator_utils as utils +from tensor2tensor.data_generators import problem +from tensor2tensor.utils import registry + + +@registry.register_problem +class AlgorithmicIdentityBinary40(problem.Problem): + """Problem spec for algorithmic binary identity task.""" + + @property + def num_symbols(self): + return 2 + + def generate_data(self, data_dir, _): + utils.generate_dataset_and_shuffle( + identity_generator(self.num_symbols, 40, 100000), + self.training_filepaths(data_dir, 100, shuffled=True), + identity_generator(self.num_symbols, 400, 10000), + self.dev_filepaths(data_dir, 1, shuffled=True), + shuffle=False) + + def hparams(self, defaults, unused_model_hparams): + p = defaults + vocab_size = self.num_symbols + self._encoders["inputs"].num_reserved_ids + p.input_modality = {"inputs": (registry.Modalities.SYMBOL, vocab_size)} + p.target_modality = (registry.Modalities.SYMBOL, vocab_size) + p.input_space_id = problem.SpaceID.DIGIT_0 + p.target_space_id = problem.SpaceID.DIGIT_1 + + +@registry.register_problem +class AlgorithmicIdentityDecimal40(AlgorithmicIdentityBinary40): + """Problem spec for algorithmic decimal identity task.""" + + @property + def num_symbols(self): + return 10 + def identity_generator(nbr_symbols, max_length, nbr_cases): """Generator for the identity (copy) task on sequences of symbols. The length of the sequence is drawn uniformly at random from [1, max_length] - and then symbols are drawn uniformly at random from [2, nbr_symbols] until + and then symbols are drawn uniformly at random from [2, nbr_symbols + 2) until nbr_cases sequences have been produced. Args: @@ -66,8 +104,10 @@ def shift_generator(nbr_symbols, shift, max_length, nbr_cases): for _ in xrange(nbr_cases): l = np.random.randint(max_length) + 1 inputs = [np.random.randint(nbr_symbols - shift) + 2 for _ in xrange(l)] - yield {"inputs": inputs, - "targets": [i + shift for i in inputs] + [1]} # [1] for EOS + yield { + "inputs": inputs, + "targets": [i + shift for i in inputs] + [1] + } # [1] for EOS def reverse_generator(nbr_symbols, max_length, nbr_cases): @@ -89,8 +129,10 @@ def reverse_generator(nbr_symbols, max_length, nbr_cases): for _ in xrange(nbr_cases): l = np.random.randint(max_length) + 1 inputs = [np.random.randint(nbr_symbols) + 2 for _ in xrange(l)] - yield {"inputs": inputs, - "targets": list(reversed(inputs)) + [1]} # [1] for EOS + yield { + "inputs": inputs, + "targets": list(reversed(inputs)) + [1] + } # [1] for EOS def zipf_distribution(nbr_symbols, alpha): @@ -106,7 +148,7 @@ def zipf_distribution(nbr_symbols, alpha): distr_map: list of float, Zipf's distribution over nbr_symbols. """ - tmp = np.power(np.arange(1, nbr_symbols+1), -alpha) + tmp = np.power(np.arange(1, nbr_symbols + 1), -alpha) zeta = np.r_[0.0, np.cumsum(tmp)] return [x / zeta[-1] for x in zeta] @@ -128,11 +170,14 @@ def zipf_random_sample(distr_map, sample_len): # we have made a sanity check to overcome this issue. On the other hand, # t+1 is enough from saving us to generate PAD(0) and EOS(1) which are # reservated symbols. 
- return [t+1 if t > 0 else t+2 for t in np.searchsorted(distr_map, u)] + return [t + 1 if t > 0 else t + 2 for t in np.searchsorted(distr_map, u)] -def reverse_generator_nlplike(nbr_symbols, max_length, nbr_cases, - scale_std_dev=100, alpha=1.5): +def reverse_generator_nlplike(nbr_symbols, + max_length, + nbr_cases, + scale_std_dev=100, + alpha=1.5): """Generator for the reversing nlp-like task on sequences of symbols. The length of the sequence is drawn from a Gaussian(Normal) distribution @@ -157,10 +202,12 @@ def reverse_generator_nlplike(nbr_symbols, max_length, nbr_cases, std_dev = max_length / scale_std_dev distr_map = zipf_distribution(nbr_symbols, alpha) for _ in xrange(nbr_cases): - l = int(abs(np.random.normal(loc=max_length/2, scale=std_dev)) + 1) + l = int(abs(np.random.normal(loc=max_length / 2, scale=std_dev)) + 1) inputs = zipf_random_sample(distr_map, l) - yield {"inputs": inputs, - "targets": list(reversed(inputs)) + [1]} # [1] for EOS + yield { + "inputs": inputs, + "targets": list(reversed(inputs)) + [1] + } # [1] for EOS def lower_endian_to_number(l, base): diff --git a/tensor2tensor/data_generators/algorithmic_math.py b/tensor2tensor/data_generators/algorithmic_math.py index ec3b7670a..e65b47ff0 100644 --- a/tensor2tensor/data_generators/algorithmic_math.py +++ b/tensor2tensor/data_generators/algorithmic_math.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensor2tensor/data_generators/algorithmic_math_test.py b/tensor2tensor/data_generators/algorithmic_math_test.py index 6c4b63054..5f0de29fb 100644 --- a/tensor2tensor/data_generators/algorithmic_math_test.py +++ b/tensor2tensor/data_generators/algorithmic_math_test.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensor2tensor/data_generators/algorithmic_test.py b/tensor2tensor/data_generators/algorithmic_test.py index 70a5d68b8..9961e6173 100644 --- a/tensor2tensor/data_generators/algorithmic_test.py +++ b/tensor2tensor/data_generators/algorithmic_test.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensor2tensor/data_generators/all_problems.py b/tensor2tensor/data_generators/all_problems.py new file mode 100644 index 000000000..364c252a7 --- /dev/null +++ b/tensor2tensor/data_generators/all_problems.py @@ -0,0 +1,31 @@ +# Copyright 2017 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Imports for problem modules.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import +from tensor2tensor.data_generators import algorithmic +from tensor2tensor.data_generators import algorithmic_math +from tensor2tensor.data_generators import audio +from tensor2tensor.data_generators import image +from tensor2tensor.data_generators import lm1b +from tensor2tensor.data_generators import ptb +from tensor2tensor.data_generators import snli +from tensor2tensor.data_generators import wiki +from tensor2tensor.data_generators import wmt +from tensor2tensor.data_generators import wsj_parsing +# pylint: enable=unused-import diff --git a/tensor2tensor/data_generators/audio.py b/tensor2tensor/data_generators/audio.py index 12e0c7b43..81cfde008 100644 --- a/tensor2tensor/data_generators/audio.py +++ b/tensor2tensor/data_generators/audio.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensor2tensor/data_generators/audio_test.py b/tensor2tensor/data_generators/audio_test.py index f1830043f..1c19432c3 100644 --- a/tensor2tensor/data_generators/audio_test.py +++ b/tensor2tensor/data_generators/audio_test.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensor2tensor/data_generators/concatenate_examples.py b/tensor2tensor/data_generators/concatenate_examples.py index b346b6c08..158bc1b59 100644 --- a/tensor2tensor/data_generators/concatenate_examples.py +++ b/tensor2tensor/data_generators/concatenate_examples.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensor2tensor/data_generators/generator_utils.py b/tensor2tensor/data_generators/generator_utils.py old mode 100755 new mode 100644 index f076c10da..890f92c2a --- a/tensor2tensor/data_generators/generator_utils.py +++ b/tensor2tensor/data_generators/generator_utils.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -22,6 +22,7 @@ import gzip import io import os +import random import tarfile # Dependency imports @@ -35,6 +36,8 @@ import tensorflow as tf +UNSHUFFLED_SUFFIX = "-unshuffled" + def to_example(dictionary): """Helper: build tf.Example from (string -> int/float/str list) dictionary.""" @@ -66,7 +69,7 @@ def generate_files_distributed(generator, task_id=0): """generate_files but with a single writer writing to shard task_id.""" assert task_id < num_shards - output_filename = "%s-%.5d-of-%.5d" % (output_name, task_id, num_shards) + output_filename = sharded_name(output_name, task_id, num_shards) output_file = os.path.join(output_dir, output_filename) tf.logging.info("Writing to file %s", output_file) writer = tf.python_io.TFRecordWriter(output_file) @@ -86,14 +89,14 @@ def generate_files_distributed(generator, def _data_filenames(output_name, output_dir, num_shards): - return [os.path.join( - output_dir, "%s-%.5d-of-%.5d" % (output_name, shard, num_shards)) - for shard in xrange(num_shards)] + return [ + os.path.join(output_dir, fname) + for fname in shard_filepath(output_name, num_shards) + ] def train_data_filenames(problem, output_dir, num_shards): - return _data_filenames( - problem + "-train", output_dir, num_shards) + return _data_filenames(problem + "-train", output_dir, num_shards) def dev_data_filenames(problem, output_dir, num_shards): @@ -105,15 +108,22 @@ def test_data_filenames(problem, output_dir, num_shards): def combined_data_filenames(problem, output_dir, num_training_shards): - return ( - train_data_filenames(problem, output_dir, num_training_shards) + - dev_data_filenames(problem, output_dir, 1) + - test_data_filenames(problem, output_dir, 1)) + return (train_data_filenames(problem, output_dir, num_training_shards) + + dev_data_filenames(problem, output_dir, 1) + test_data_filenames( + problem, output_dir, 1)) + + +def sharded_name(base_name, shard, total_shards): + return "%s-%.5d-of-%.5d" % (base_name, shard, total_shards) + +def shard_filepath(fname, num_shards): + return [ + sharded_name(fname, shard, num_shards) for shard in xrange(num_shards) + ] -def generate_files(generator, - output_filenames, - max_cases=None): + +def generate_files(generator, output_filenames, max_cases=None): """Generate cases from a generator and save as TFRecord files. 
Generated cases are transformed to tf.Example protos and saved as TFRecords @@ -172,8 +182,8 @@ def maybe_download(directory, filename, url): if not tf.gfile.Exists(filepath): tf.logging.info("Downloading %s to %s" % (url, filepath)) inprogress_filepath = filepath + ".incomplete" - inprogress_filepath, _ = urllib.urlretrieve(url, inprogress_filepath, - reporthook=download_report_hook) + inprogress_filepath, _ = urllib.urlretrieve( + url, inprogress_filepath, reporthook=download_report_hook) # Print newline to clear the carriage return from the download progress print() tf.gfile.Rename(inprogress_filepath, filepath) @@ -266,8 +276,8 @@ def get_or_generate_vocab(tmp_dir, vocab_filename, vocab_size, sources=None): if ".gz" in lang_file: new_filepath = os.path.join(tmp_dir, lang_file[:-3]) if tf.gfile.Exists(new_filepath): - tf.logging.info("Subdirectory %s already exists, skipping unpacking" - % filepath) + tf.logging.info( + "Subdirectory %s already exists, skipping unpacking" % filepath) else: tf.logging.info("Unpacking subdirectory %s" % filepath) gunzip_file(filepath, new_filepath) @@ -290,30 +300,42 @@ def get_or_generate_vocab(tmp_dir, vocab_filename, vocab_size, sources=None): return vocab -def get_or_generate_tabbed_vocab(tmp_dir, source_filename, index, vocab_filename, vocab_size): - """Generate a vocabulary from the source file. This is assumed to be - a file of source, target pairs, where each line contains a source string - and a target string, separated by a tab ('\t') character. The index - parameter specifies 0 for the source or 1 for the target.""" +def get_or_generate_tabbed_vocab(tmp_dir, source_filename, + index, vocab_filename, vocab_size): + r"""Generate a vocabulary from a tabbed source file. + + The source is a file of source, target pairs, where each line contains + a source string and a target string, separated by a tab ('\t') character. + The index parameter specifies 0 for the source or 1 for the target. + + Args: + tmp_dir: path to the temporary directory. + source_filename: the name of the tab-separated source file. + index: index. + vocab_filename: the name of the vocabulary file. + vocab_size: vocabulary size. + + Returns: + The vocabulary. + """ vocab_filepath = os.path.join(tmp_dir, vocab_filename) if os.path.exists(vocab_filepath): vocab = text_encoder.SubwordTextEncoder(vocab_filepath) return vocab # Use Tokenizer to count the word occurrences. 
- token_counts = defaultdict(int) filepath = os.path.join(tmp_dir, source_filename) with tf.gfile.GFile(filepath, mode="r") as source_file: for line in source_file: line = line.strip() - if line and '\t' in line: - parts = line.split('\t', maxsplit = 1) + if line and "\t" in line: + parts = line.split("\t", maxsplit=1) part = parts[index].strip() - for tok in tokenizer.encode(text_encoder.native_to_unicode(part)): - token_counts[tok] += 1 + _ = tokenizer.encode(text_encoder.native_to_unicode(part)) vocab = text_encoder.SubwordTextEncoder.build_to_target_size( - vocab_size, token_counts, 1, 1e3) + vocab_size, tokenizer.token_counts, 1, + min(1e3, vocab_size + text_encoder.NUM_RESERVED_TOKENS)) vocab.store_to_file(vocab_filepath) return vocab @@ -335,3 +357,24 @@ def write_records(records, out_filename): if count > 0 and count % 100000 == 0: tf.logging.info("write: %d", count) writer.close() + + +def generate_dataset_and_shuffle(train_gen, + train_paths, + dev_gen, + dev_paths, + shuffle=True): + generate_files(train_gen, train_paths) + generate_files(dev_gen, dev_paths) + if shuffle: + shuffle_dataset(train_paths + dev_paths) + + +def shuffle_dataset(filenames): + tf.logging.info("Shuffling data...") + for fname in filenames: + records = read_records(fname) + random.shuffle(records) + out_fname = fname.replace(UNSHUFFLED_SUFFIX, "") + write_records(records, out_fname) + tf.gfile.Remove(fname) diff --git a/tensor2tensor/data_generators/generator_utils_test.py b/tensor2tensor/data_generators/generator_utils_test.py index 320d1a02d..c776d120c 100644 --- a/tensor2tensor/data_generators/generator_utils_test.py +++ b/tensor2tensor/data_generators/generator_utils_test.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensor2tensor/data_generators/image.py b/tensor2tensor/data_generators/image.py index 377bf3e54..0cba1800b 100644 --- a/tensor2tensor/data_generators/image.py +++ b/tensor2tensor/data_generators/image.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -33,6 +33,9 @@ from six.moves import xrange # pylint: disable=redefined-builtin from six.moves import zip # pylint: disable=redefined-builtin from tensor2tensor.data_generators import generator_utils +from tensor2tensor.data_generators import problem +from tensor2tensor.data_generators import text_encoder +from tensor2tensor.utils import registry import tensorflow as tf @@ -300,3 +303,47 @@ def mscoco_generator(tmp_dir, "image/height": [height], "image/width": [width] } + +# French street names dataset. 
+ + +@registry.register_problem +class ImageFSNS(problem.Problem): + """Problem spec for French Street Name recognition.""" + + def generate_data(self, data_dir, tmp_dir): + list_url = ("https://raw.githubusercontent.com/tensorflow/models/master/" + "street/python/fsns_urls.txt") + fsns_urls = generator_utils.maybe_download( + tmp_dir, "fsns_urls.txt", list_url) + fsns_files = [f.strip() for f in open(fsns_urls, "r") + if f.startswith("http://")] + for url in fsns_files: + if "/train/train" in url: + generator_utils.maybe_download( + data_dir, "image_fsns-train" + url[-len("-00100-of-00512"):], url) + elif "/validation/validation" in url: + generator_utils.maybe_download( + data_dir, "image_fsns-dev" + url[-len("-00100-of-00512"):], url) + elif "charset" in url: + generator_utils.maybe_download( + data_dir, "charset_size134.txt", url) + + def hparams(self, defaults, model_hparams): + p = defaults + p.input_modality = {"inputs": (registry.Modalities.IMAGE, None)} + # This vocab file must be present within the data directory. + vocab_filename = os.path.join(model_hparams.data_dir, "charset_size134.txt") + subtokenizer = text_encoder.SubwordTextEncoder(vocab_filename) + p.target_modality = (registry.Modalities.SYMBOL, subtokenizer.vocab_size) + p.vocabulary = { + "inputs": text_encoder.TextEncoder(), + "targets": subtokenizer, + } + p.batch_size_multiplier = 256 + p.max_expected_batch_size_per_shard = 2 + vocab_size = 144 + p.input_modality = {"inputs": (registry.Modalities.SYMBOL, vocab_size)} + p.target_modality = (registry.Modalities.SYMBOL, vocab_size) + p.input_space_id = problem.SpaceID.DIGIT_0 + p.target_space_id = problem.SpaceID.DIGIT_1 diff --git a/tensor2tensor/data_generators/image_test.py b/tensor2tensor/data_generators/image_test.py index c5b4f14be..6c9984265 100644 --- a/tensor2tensor/data_generators/image_test.py +++ b/tensor2tensor/data_generators/image_test.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensor2tensor/data_generators/inspect.py b/tensor2tensor/data_generators/inspect.py index a0da09150..fba3c6492 100644 --- a/tensor2tensor/data_generators/inspect.py +++ b/tensor2tensor/data_generators/inspect.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensor2tensor/data_generators/lm1b.py b/tensor2tensor/data_generators/lm1b.py index 66a3d52a0..78fb001bc 100644 --- a/tensor2tensor/data_generators/lm1b.py +++ b/tensor2tensor/data_generators/lm1b.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensor2tensor/data_generators/problem.py b/tensor2tensor/data_generators/problem.py new file mode 100644 index 000000000..28f4dcb1b --- /dev/null +++ b/tensor2tensor/data_generators/problem.py @@ -0,0 +1,266 @@ +# Copyright 2017 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Base class for problem/dataset definitions.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Dependency imports + +from tensor2tensor.data_generators import generator_utils as utils +from tensor2tensor.data_generators import text_encoder + +import tensorflow as tf + + +class SpaceID(object): + """Input and target space ids. Add more as needed.""" + # Generic / unknown output space (default) + GENERIC = 0 + # Image labels + IMAGE_LABEL = 1 + # English characters + EN_CHR = 2 + # English tokens + EN_TOK = 3 + # English bpe tokens + EN_BPE_TOK = 4 + # French characters + FR_CHR = 5 + # French tokens + FR_TOK = 6 + # German characters + DE_CHR = 7 + # German tokens + DE_TOK = 8 + # German bpe tokens + DE_BPE_TOK = 9 + # Digit cipher lexicon 0 + DIGIT_0 = 10 + # Digit cipher lexicon 1 + DIGIT_1 = 11 + # Audio waveform domain + AUDIO_WAV = 12 + # Audio spectral domain + AUDIO_SPECTRAL = 13 + # Parse characters + PARSE_CHR = 14 + # Parse tokens + PARSE_TOK = 15 + # Chinese tokens + ZH_TOK = 16 + # Icelandic characters + ICE_CHAR = 17 + # Icelandic tokens + ICE_TOK = 18 + # Icelandic parse tokens + ICE_PARSE_TOK = 19 + + +class Problem(object): + """Problem base class. Specifies a T2T problem. + + Problems unify the specification of a problem for data generation, training, + and inference. + + New problems are specified by the following methods: + + Data generation: + * generate_data(data_dir, tmp_dir) + - Generate training and dev datasets into data_dir. + - Additonal files, e.g. vocabulary files, should also be written to + data_dir. + - Downloads and other files can be written to tmp_dir + - If you have a training and dev generator, you can generate the + training and dev datasets with + generator_utils.generate_dataset_and_shuffle. + - Use the self.training_filepaths and self.dev_filepaths functions to + get sharded filenames. If shuffled=False, the filenames will contain + an "unshuffled" suffix; you should then shuffle the data + shard-by-shard with generator_utils.shuffle_dataset. + - Subclasses must override + * dataset_filename() + - Base filename for problem. + - Defaults to registered name (self.name). + + Training: + * hparams(defaults, model_hparams) + - Specify the problem hyperparameters (see _default_hparams) + - Mutate defaults as needed + + Inference: + * feature_encoders(data_dir) + - Return a dict of for encoding and decoding + inference input/output. + - Defaults to TextEncoder for inputs and targets. 
+ """ + + # ============================================================================ + # BEGIN SUBCLASS INTERFACE + # ============================================================================ + + def generate_data(self, data_dir, tmp_dir): + raise NotImplementedError() + + def hparams(self, defaults, model_hparams): + pass + + def dataset_filename(self): + return self.name + + def feature_encoders(self, data_dir): + del data_dir + return { + "inputs": text_encoder.TextEncoder(), + "targets": text_encoder.TextEncoder() + } + + # ============================================================================ + # END SUBCLASS INTERFACE + # ============================================================================ + + def training_filepaths(self, data_dir, num_shards, shuffled): + file_basename = self.dataset_filename() + if not shuffled: + file_basename += utils.UNSHUFFLED_SUFFIX + return utils.train_data_filenames(file_basename, data_dir, num_shards) + + def dev_filepaths(self, data_dir, num_shards, shuffled): + file_basename = self.dataset_filename() + if not shuffled: + file_basename += utils.UNSHUFFLED_SUFFIX + return utils.dev_data_filenames(file_basename, data_dir, num_shards) + + def __init__(self, was_reversed=False, was_copy=False): + """Create a Problem. + + Args: + was_reversed: bool, whether to reverse inputs and targets. + was_copy: bool, whether to copy inputs to targets. Can be composed with + was_reversed so that if both are true, the targets become the inputs, + which are then copied to targets so that the task is targets->targets. + """ + self._was_reversed = was_reversed + self._was_copy = was_copy + self._encoders = None + + def internal_build_encoders(self, data_dir): + self._encoders = self.feature_encoders(data_dir) + + def internal_hparams(self, model_hparams): + """Returns problem_hparams.""" + if self._encoders is None: + self.internal_build_encoders(model_hparams.data_dir) + + hp = _default_hparams() + ret = self.hparams(hp, model_hparams) + if ret is not None: + raise ValueError("The Problem subclass hparams function should mutate " + "the defaults passed in and return None.") + + hp.add_hparam("vocabulary", self._encoders) + hp.add_hparam("was_reversed", self._was_reversed) + hp.add_hparam("was_copy", self._was_copy) + + if self._was_reversed: + _reverse_problem_hparams(hp) + # TODO(rsepassi): Move this into the cifar10 Problem + if "image_cifar10" in self.name: + hp.loss_multiplier = 1. + if self._was_copy: + _copy_problem_hparams(hp) + return hp + + +def _copy_problem_hparams(p_hparams): + """Use input modality, vocab, and space id for target.""" + p = p_hparams + # Duplicate input modality. + p.target_modality = p.input_modality["inputs"] + # Duplicate input vocabulary. + p.vocabulary["targets"] = p.vocabulary["inputs"] + # Duplicate input space ids. + p.target_space_id = p.input_space_id + # Mark that p was reversed. + p.was_copy = True + + +def _reverse_problem_hparams(p_hparams): + """Swap input/output modalities, vocab, and space ids.""" + p = p_hparams + + # Swap modalities. + input_modality = p.input_modality["inputs"] + target_modality = p.target_modality + p.input_modality["inputs"] = target_modality + p.target_modality = input_modality + + # Swap vocabularies. + input_vocabulary = p.vocabulary["inputs"] + target_vocabulary = p.vocabulary["targets"] + p.vocabulary["inputs"] = target_vocabulary + p.vocabulary["targets"] = input_vocabulary + + # Swap input/target space ids. 
+ input_space_id = p.input_space_id + target_space_id = p.target_space_id + p.input_space_id = target_space_id + p.target_space_id = input_space_id + + # Mark that p was reversed. + p.was_reversed = True + + +def _default_hparams(): + """A set of basic model hyperparameters.""" + return tf.contrib.training.HParams( + # Use this parameter to get comparable perplexity numbers with different + # tokenizations. This value should be set to the ratio of the number of + # tokens in the test set according to the tokeization used to the number + # of tokens in the test set in the "official" tokenization. For + # example, if we are using a word-piece based model and we want to + # compute per-word perplexity, then we set loss_multiplier to the number + # of wordpieces per word in the test set. + loss_multiplier=1.0, + + # Use this parameter to allow for larger sequences in the batch. Without + # the use of this parameter, the size of the inner two dimensions will + # be used to judge the sequence length. + batch_size_multiplier=1, + + # To make queues of the right capacity, it's good to know the maximal + # expected batch size, as it can vary a lot. It only affects performance + # of input readers and memory use. The defaults should be safe and fast, + # but decrease if your reader uses a lot of memory and increase if slow. + max_expected_batch_size_per_shard=64, + + # Modalities used to map from input features to a space compatible with + # chosen model architecture. One modality spec (which is a 2-tuple, + # (modality_full_name, vocab_size)) per feature key. modality_full_name + # is a string type:name, e.g. class_label:class_label_2d. Leaving off + # the name uses the default modality for that type (e.g. class_label == + # class_label:default). + input_modality={}, + + # Modality used to map from hidden representation to the target space. + # Specified as a modality spec, a 2-tuple described above. + target_modality=None, + + # Identifiers used to tell the model which input/target space will be + # expected. For example, it can tell that we expect French as characters + # as output, or Spanish as sound. Spaces defined as constants in SpaceID + # class. + input_space_id=SpaceID.GENERIC, + target_space_id=SpaceID.GENERIC) diff --git a/tensor2tensor/data_generators/problem_hparams.py b/tensor2tensor/data_generators/problem_hparams.py old mode 100755 new mode 100644 index 91a685dfe..70b9dada8 --- a/tensor2tensor/data_generators/problem_hparams.py +++ b/tensor2tensor/data_generators/problem_hparams.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -69,7 +69,7 @@ def parse_problem_name(problem_name): if problem_name.endswith("_rev"): base, _, was_copy = parse_problem_name(problem_name[:-4]) return base, True, was_copy - if problem_name.endswith("_copy"): + elif problem_name.endswith("_copy"): base, was_reversed, _ = parse_problem_name(problem_name[:-5]) return base, was_reversed, True return problem_name, False, False @@ -201,7 +201,7 @@ def default_problem_hparams(): # `problem_rev_copy` will copy the targets. 
was_reversed=False, was_copy=False, - ) + ) def test_problem_hparams(unused_model_hparams, input_vocab_size, @@ -456,26 +456,6 @@ def wmt_ende_characters(unused_model_hparams): return p -def wmt_ende_tokens(model_hparams, wrong_vocab_size): - """English to German translation benchmark.""" - p = default_problem_hparams() - # This vocab file must be present within the data directory. - vocab_filename = os.path.join(model_hparams.data_dir, - "tokens.vocab.%d" % wrong_vocab_size) - subtokenizer = text_encoder.SubwordTextEncoder(vocab_filename) - p.input_modality = { - "inputs": (registry.Modalities.SYMBOL, subtokenizer.vocab_size) - } - p.target_modality = (registry.Modalities.SYMBOL, subtokenizer.vocab_size) - p.vocabulary = { - "inputs": subtokenizer, - "targets": subtokenizer, - } - p.input_space_id = 3 - p.target_space_id = 8 - return p - - def wmt_zhen_tokens(model_hparams, wrong_vocab_size): """Chinese to English translation benchmark.""" p = default_problem_hparams() @@ -502,41 +482,9 @@ def wmt_zhen_tokens(model_hparams, wrong_vocab_size): return p -def wmt_ende_v2(model_hparams, vocab_size): - """English to German translation benchmark with separate vocabularies.""" - p = default_problem_hparams() - # These vocab files must be present within the data directory. - source_vocab_filename = os.path.join(model_hparams.data_dir, - "wmt_ende_v2.en.vocab.%d" % vocab_size) - target_vocab_filename = os.path.join(model_hparams.data_dir, - "wmt_ende_v2.de.vocab.%d" % vocab_size) - p.input_modality = {"inputs": (registry.Modalities.SYMBOL, vocab_size)} - p.target_modality = (registry.Modalities.SYMBOL, vocab_size) - p.vocabulary = { - "inputs": text_encoder.SubwordTextEncoder(source_vocab_filename), - "targets": text_encoder.SubwordTextEncoder(target_vocab_filename), - } - p.input_space_id = 3 - p.target_space_id = 8 - return p - - -def wmt_concat(model_hparams, wrong_vocab_size): - """English to German translation benchmark.""" - p = default_problem_hparams() - # This vocab file must be present within the data directory. - vocab_filename = os.path.join(model_hparams.data_dir, - "tokens.vocab.%d" % wrong_vocab_size) - subtokenizer = text_encoder.SubwordTextEncoder(vocab_filename) - vocab_size = subtokenizer.vocab_size - p.input_modality = {} - p.target_modality = (registry.Modalities.SYMBOL, vocab_size) - p.vocabulary = {"targets": subtokenizer} - return p - - def wmt_parsing_characters(model_hparams): """English to parse tree translation benchmark.""" + del model_hparams # Unused. p = default_problem_hparams() p.input_modality = {"inputs": (registry.Modalities.SYMBOL, 256)} p.target_modality = (registry.Modalities.SYMBOL, 256) @@ -579,13 +527,15 @@ def wmt_parsing_tokens(model_hparams, wrong_vocab_size): return p -def wsj_parsing_tokens(model_hparams, prefix, +def wsj_parsing_tokens(model_hparams, + prefix, wrong_source_vocab_size, wrong_target_vocab_size): """English to parse tree translation benchmark. Args: model_hparams: a tf.contrib.training.HParams + prefix: name to use as prefix for vocabulary files. wrong_source_vocab_size: a number used in the filename indicating the approximate vocabulary size. This is not to be confused with the actual vocabulary size. @@ -624,8 +574,12 @@ def ice_parsing_tokens(model_hparams, wrong_source_vocab_size): Args: model_hparams: a tf.contrib.training.HParams + wrong_source_vocab_size: a number used in the filename indicating the + approximate vocabulary size. This is not to be confused with the actual + vocabulary size. 
+ Returns: - a tf.contrib.training.HParams + A tf.contrib.training.HParams object. """ p = default_problem_hparams() # This vocab file must be present within the data directory. @@ -645,8 +599,8 @@ def ice_parsing_tokens(model_hparams, wrong_source_vocab_size): "inputs": source_subtokenizer, "targets": target_subtokenizer, } - p.input_space_id = 18 # Icelandic tokens - p.target_space_id = 19 # Icelandic parse tokens + p.input_space_id = 18 # Icelandic tokens + p.target_space_id = 19 # Icelandic parse tokens return p @@ -747,8 +701,6 @@ def img2img_imagenet(unused_model_hparams): PROBLEM_HPARAMS_MAP = { "algorithmic_addition_binary40": lambda p: algorithmic(4, p), "algorithmic_addition_decimal40": lambda p: algorithmic(12, p), - "algorithmic_identity_binary40": lambda p: algorithmic(4, p), - "algorithmic_identity_decimal40": lambda p: algorithmic(12, p), "algorithmic_multiplication_binary40": lambda p: algorithmic(4, p), "algorithmic_multiplication_decimal40": lambda p: algorithmic(12, p), "algorithmic_reverse_binary40": lambda p: algorithmic(4, p), @@ -767,33 +719,19 @@ def img2img_imagenet(unused_model_hparams): "lm1b_32k": lm1b_32k, "wiki_32k": wiki_32k, "lmptb_10k": lmptb_10k, - "wmt_parsing_characters": wmt_parsing_characters, "ice_parsing_characters": wmt_parsing_characters, "ice_parsing_tokens": lambda p: ice_parsing_tokens(p, 2**13), "wmt_parsing_tokens_8k": lambda p: wmt_parsing_tokens(p, 2**13), - "wsj_parsing_tokens_16k": lambda p: wsj_parsing_tokens(p, "wsj", 2**14, 2**9), - "wsj_parsing_tokens_32k": lambda p: wsj_parsing_tokens(p, "wsj", 2**15, 2**9), + "wsj_parsing_tokens_16k": lambda p: wsj_parsing_tokens( # pylint: disable=g-long-lambda + p, "wsj", 2**14, 2**9), "wmt_enfr_characters": wmt_enfr_characters, "wmt_enfr_tokens_8k": lambda p: wmt_enfr_tokens(p, 2**13), "wmt_enfr_tokens_32k": lambda p: wmt_enfr_tokens(p, 2**15), "wmt_enfr_tokens_32k_shuffled": lambda p: wmt_enfr_tokens(p, 2**15), "wmt_enfr_tokens_32k_combined": lambda p: wmt_enfr_tokens(p, 2**15), "wmt_enfr_tokens_128k": lambda p: wmt_enfr_tokens(p, 2**17), - # bytes per subtoken: 3.267350 - "wmt_ende_concat_8k": lambda p: wmt_concat(p, 2**13), - # bytes per subtoken: 4.236272 - "wmt_ende_concat_32k": lambda p: wmt_concat(p, 2**15), "wmt_ende_characters": wmt_ende_characters, - "wmt_ende_tokens_8k": lambda p: wmt_ende_tokens(p, 2**13), - "wmt_ende_tokens_32k": lambda p: wmt_ende_tokens(p, 2**15), - "wmt_ende_tokens_128k": lambda p: wmt_ende_tokens(p, 2**17), - # bytes per subtoken: 4.59291664162 "wmt_ende_bpe32k": wmt_ende_bpe32k, - "wmt_ende_bpe32k_shuffled": wmt_ende_bpe32k, - "wmt_ende_bpe32k_combined": wmt_ende_bpe32k, - "wmt_ende_bpe32k_160": wmt_ende_bpe32k, - "wmt_ende_v2_32k_combined": lambda p: wmt_ende_v2(p, 2**15), - "wmt_ende_v2_16k_combined": lambda p: wmt_ende_v2(p, 2**14), "wmt_zhen_tokens_32k": lambda p: wmt_zhen_tokens(p, 2**15), "image_cifar10_tune": image_cifar10, "image_cifar10_test": image_cifar10, @@ -801,12 +739,8 @@ def img2img_imagenet(unused_model_hparams): "image_mnist_test": image_mnist, "image_mscoco_characters_tune": image_mscoco_characters, "image_mscoco_characters_test": image_mscoco_characters, - "image_mscoco_tokens_8k_tune": lambda p: image_mscoco_tokens(p, 2**13), "image_mscoco_tokens_8k_test": lambda p: image_mscoco_tokens(p, 2**13), - "image_mscoco_tokens_32k_tune": lambda p: image_mscoco_tokens(p, 2**15), "image_mscoco_tokens_32k_test": lambda p: image_mscoco_tokens(p, 2**15), - "image_mscoco_tokens_128k_tune": lambda p: image_mscoco_tokens(p, 2**17), - 
"image_mscoco_tokens_128k_test": lambda p: image_mscoco_tokens(p, 2**17), "image_imagenet": image_imagenet, "img2img_imagenet": img2img_imagenet, } diff --git a/tensor2tensor/data_generators/problem_hparams_test.py b/tensor2tensor/data_generators/problem_hparams_test.py index d3803396f..ad1f0192d 100644 --- a/tensor2tensor/data_generators/problem_hparams_test.py +++ b/tensor2tensor/data_generators/problem_hparams_test.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensor2tensor/data_generators/ptb.py b/tensor2tensor/data_generators/ptb.py index d4cf42c88..9a7db3a78 100644 --- a/tensor2tensor/data_generators/ptb.py +++ b/tensor2tensor/data_generators/ptb.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensor2tensor/data_generators/snli.py b/tensor2tensor/data_generators/snli.py index 1d3acd356..7322c59ff 100644 --- a/tensor2tensor/data_generators/snli.py +++ b/tensor2tensor/data_generators/snli.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensor2tensor/data_generators/text_encoder.py b/tensor2tensor/data_generators/text_encoder.py old mode 100755 new mode 100644 index 38b78256d..4a5a784c2 --- a/tensor2tensor/data_generators/text_encoder.py +++ b/tensor2tensor/data_generators/text_encoder.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -36,35 +36,47 @@ import tensorflow as tf -# Conversion between Unicode and UTF-8, if required (on Python2) -if PY2: - native_to_unicode = lambda s: s if isinstance(s, unicode) else s.decode("utf-8") - unicode_to_native = lambda s: s.encode("utf-8") -else: - # No conversion required on Python3 - native_to_unicode = lambda s: s - unicode_to_native = lambda s: s - - # Reserved tokens for things like padding and EOS symbols. 
PAD = "" EOS = "" RESERVED_TOKENS = [PAD, EOS] NUM_RESERVED_TOKENS = len(RESERVED_TOKENS) -PAD_TOKEN = RESERVED_TOKENS.index(PAD) # Normally 0 -EOS_TOKEN = RESERVED_TOKENS.index(EOS) # Normally 1 +PAD_TOKEN = RESERVED_TOKENS.index(PAD) # Normally 0 +EOS_TOKEN = RESERVED_TOKENS.index(EOS) # Normally 1 if PY2: RESERVED_TOKENS_BYTES = RESERVED_TOKENS else: RESERVED_TOKENS_BYTES = [bytes(PAD, "ascii"), bytes(EOS, "ascii")] + +def native_to_unicode_py2(s): + """Python 2: transform native string to Unicode.""" + if isinstance(s, unicode): + return s + return s.decode("utf-8") + + +# Conversion between Unicode and UTF-8, if required (on Python2) +if PY2: + native_to_unicode = native_to_unicode_py2 + unicode_to_native = lambda s: s.encode("utf-8") +else: + # No conversion required on Python3 + native_to_unicode = lambda s: s + unicode_to_native = lambda s: s + + class TextEncoder(object): """Base class for converting from ints to/from human readable strings.""" def __init__(self, num_reserved_ids=NUM_RESERVED_TOKENS): self._num_reserved_ids = num_reserved_ids + @property + def num_reserved_ids(self): + return self._num_reserved_ids + def encode(self, s): """Transform a human-readable string into a sequence of int ids. @@ -137,7 +149,8 @@ def vocab_size(self): class TokenTextEncoder(TextEncoder): """Encoder based on a user-supplied vocabulary.""" - def __init__(self, vocab_filename, reverse=False, num_reserved_ids=NUM_RESERVED_TOKENS): + def __init__(self, vocab_filename, reverse=False, + num_reserved_ids=NUM_RESERVED_TOKENS): """Initialize from a file, one token per line.""" super(TokenTextEncoder, self).__init__(num_reserved_ids=num_reserved_ids) self._reverse = reverse diff --git a/tensor2tensor/data_generators/text_encoder_build_subword.py b/tensor2tensor/data_generators/text_encoder_build_subword.py index df8aa73eb..093101c68 100644 --- a/tensor2tensor/data_generators/text_encoder_build_subword.py +++ b/tensor2tensor/data_generators/text_encoder_build_subword.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensor2tensor/data_generators/tokenizer.py b/tensor2tensor/data_generators/tokenizer.py old mode 100755 new mode 100644 index 65fe19334..2b1cf572c --- a/tensor2tensor/data_generators/tokenizer.py +++ b/tensor2tensor/data_generators/tokenizer.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensor2tensor/data_generators/tokenizer_test.py b/tensor2tensor/data_generators/tokenizer_test.py index 404a11396..c279290ed 100644 --- a/tensor2tensor/data_generators/tokenizer_test.py +++ b/tensor2tensor/data_generators/tokenizer_test.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensor2tensor/data_generators/wiki.py b/tensor2tensor/data_generators/wiki.py index 99a9e64e6..8f905aa96 100644 --- a/tensor2tensor/data_generators/wiki.py +++ b/tensor2tensor/data_generators/wiki.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -25,7 +25,6 @@ # Dependency imports import six -from six import PY2 from tensor2tensor.data_generators import generator_utils from tensor2tensor.data_generators import text_encoder from tensor2tensor.data_generators import tokenizer @@ -61,7 +60,7 @@ def page_generator(tmp_dir, max_docs=None): count = 0 corpus_filepath = _maybe_download_corpus(tmp_dir) for line in bz2.BZ2File(corpus_filepath, "r"): - line = unicode(line, "utf-8") if PY2 else line.decode("utf-8") + line = unicode(line, "utf-8") if six.PY2 else line.decode("utf-8") if not doc and line != u" \n": continue doc += line diff --git a/tensor2tensor/data_generators/wmt.py b/tensor2tensor/data_generators/wmt.py old mode 100755 new mode 100644 index 2d43d1739..8edab8ba2 --- a/tensor2tensor/data_generators/wmt.py +++ b/tensor2tensor/data_generators/wmt.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,21 +24,65 @@ # Dependency imports from tensor2tensor.data_generators import generator_utils +from tensor2tensor.data_generators import problem from tensor2tensor.data_generators import text_encoder from tensor2tensor.data_generators import wsj_parsing +from tensor2tensor.utils import registry import tensorflow as tf - tf.flags.DEFINE_string("ende_bpe_path", "", "Path to BPE files in tmp_dir." "Download from https://drive.google.com/open?" "id=0B_bZck-ksdkpM25jRUN2X2UxMm8") - FLAGS = tf.flags.FLAGS -# End-of-sentence marker +@registry.register_problem("wmt_ende_tokens_8k") +class WMTEnDeTokens8k(problem.Problem): + """Problem spec for WMT En-De translation.""" + + @property + def target_vocab_size(self): + return 2**13 # 8192 + + def feature_encoders(self, data_dir): + return _default_wmt_feature_encoders(data_dir, self.target_vocab_size) + + def generate_data(self, data_dir, tmp_dir): + generator_utils.generate_dataset_and_shuffle( + ende_wordpiece_token_generator(tmp_dir, True, self.target_vocab_size), + self.training_filepaths(data_dir, 100, shuffled=False), + ende_wordpiece_token_generator(tmp_dir, False, self.target_vocab_size), + self.dev_filepaths(data_dir, 1, shuffled=False)) + + def hparams(self, defaults, unused_model_hparams): + p = defaults + vocab_size = self._encoders["inputs"].vocab_size + p.input_modality = {"inputs": (registry.Modalities.SYMBOL, vocab_size)} + p.target_modality = (registry.Modalities.SYMBOL, vocab_size) + p.input_space_id = problem.SpaceID.EN_TOK + p.target_space_id = problem.SpaceID.DE_TOK + + +@registry.register_problem("wmt_ende_tokens_32k") +class WMTEnDeTokens32k(WMTEnDeTokens8k): + + @property + def target_vocab_size(self): + return 2**15 # 32768 + + +def _default_wmt_feature_encoders(data_dir, target_vocab_size): + vocab_filename = os.path.join(data_dir, "tokens.vocab.%d" % target_vocab_size) + subtokenizer = text_encoder.SubwordTextEncoder(vocab_filename) + return { + "inputs": subtokenizer, + "targets": subtokenizer, + } + + +# End-of-sentence marker. 
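The `@registry.register_problem` decorator used above (added to `registry.py` later in this diff) registers a `Problem` subclass under its snake-cased name so it can be looked up at runtime. As a hedged sketch of the pattern, and not part of this change, a further variant could override only the vocabulary size, exactly the way `WMTEnDeTokens32k` derives from `WMTEnDeTokens8k`; the 16k name below is hypothetical.

```python
from tensor2tensor.data_generators import wmt
from tensor2tensor.utils import registry


@registry.register_problem("wmt_ende_tokens_16k")  # hypothetical name
class WMTEnDeTokens16k(wmt.WMTEnDeTokens8k):
  """Same En-De problem, but with a 16k wordpiece vocabulary."""

  @property
  def target_vocab_size(self):
    return 2**14  # 16384

# The trainer can then resolve it by name, e.g.
# registry.problem("wmt_ende_tokens_16k")
```

Once registered, the class's `generate_data`, `feature_encoders`, and `hparams` methods supply what previously lived in the flat generator table and `problem_hparams`.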
EOS = text_encoder.EOS_TOKEN @@ -72,12 +116,13 @@ def character_generator(source_path, target_path, character_vocab, eos=None): def tabbed_generator(source_path, source_vocab, target_vocab, eos=None): - """Generator for sequence-to-sequence tasks using tokens derived from - text files where each line contains both a source and a target string. - The two strings are separated by a tab character ('\t'). It yields - dictionaries of "inputs" and "targets" where inputs are characters - from the source lines converted to integers, and targets are - characters from the target lines, also converted to integers. + r"""Generator for sequence-to-sequence tasks using tabbed files. + + Tokens are derived from text files where each line contains both + a source and a target string. The two strings are separated by a tab + character ('\t'). It yields dictionaries of "inputs" and "targets" where + inputs are characters from the source lines converted to integers, and + targets are characters from the target lines, also converted to integers. Args: source_path: path to the file with source and target sentences. @@ -92,11 +137,11 @@ def tabbed_generator(source_path, source_vocab, target_vocab, eos=None): eos_list = [] if eos is None else [eos] with tf.gfile.GFile(source_path, mode="r") as source_file: for line in source_file: - if line and '\t' in line: - parts = line.split('\t', maxsplit = 1) + if line and "\t" in line: + parts = line.split("\t", maxsplit=1) source, target = parts[0].strip(), parts[1].strip() source_ints = source_vocab.encode(source) + eos_list - target_ints = source_vocab.encode(target) + eos_list + target_ints = target_vocab.encode(target) + eos_list yield {"inputs": source_ints, "targets": target_ints} @@ -129,7 +174,8 @@ def token_generator(source_path, target_path, token_vocab, eos=None): source, target = source_file.readline(), target_file.readline() -def bi_vocabs_token_generator(source_path, target_path, +def bi_vocabs_token_generator(source_path, + target_path, source_token_vocab, target_token_vocab, eos=None): @@ -161,6 +207,7 @@ def bi_vocabs_token_generator(source_path, target_path, yield {"inputs": source_ints, "targets": target_ints} source, target = source_file.readline(), target_file.readline() + def _get_wmt_ende_dataset(directory, filename): """Extract the WMT en-de corpus `filename` to directory unless it's there.""" train_path = os.path.join(directory, filename) @@ -182,7 +229,8 @@ def ende_bpe_token_generator(tmp_dir, train): train_path = _get_wmt_ende_dataset(tmp_dir, dataset_path) token_path = os.path.join(tmp_dir, "vocab.bpe.32000") token_vocab = text_encoder.TokenTextEncoder(vocab_filename=token_path) - return token_generator(train_path + ".en", train_path + ".de", token_vocab, EOS) + return token_generator(train_path + ".en", train_path + ".de", token_vocab, + EOS) _ENDE_TRAIN_DATASETS = [ @@ -237,75 +285,54 @@ def ende_bpe_token_generator(tmp_dir, train): ], ] -_ZHEN_TRAIN_DATASETS = [ - [ - "http://data.statmt.org/wmt17/translation-task/training-parallel-nc-v12.tgz", - ("training/news-commentary-v12.zh-en.zh", - "training/news-commentary-v12.zh-en.en") - ] -] +_ZHEN_TRAIN_DATASETS = [[("http://data.statmt.org/wmt17/translation-task/" + "training-parallel-nc-v12.tgz"), + ("training/news-commentary-v12.zh-en.zh", + "training/news-commentary-v12.zh-en.en")]] + +_ZHEN_TEST_DATASETS = [[ + "http://data.statmt.org/wmt17/translation-task/dev.tgz", + ("dev/newsdev2017-zhen-src.zh", "dev/newsdev2017-zhen-ref.en") +]] -_ZHEN_TEST_DATASETS = [ - [ - 
"http://data.statmt.org/wmt17/translation-task/dev.tgz", - ("dev/newsdev2017-zhen-src.zh", - "dev/newsdev2017-zhen-ref.en") - ] -] def _compile_data(tmp_dir, datasets, filename): """Concatenate all `datasets` and save to `filename`.""" filename = os.path.join(tmp_dir, filename) - lang1_lines, lang2_lines = [], [] - for dataset in datasets: - url = dataset[0] - compressed_filename = os.path.basename(url) - compressed_filepath = os.path.join(tmp_dir, compressed_filename) - - lang1_filename, lang2_filename = dataset[1] - lang1_filepath = os.path.join(tmp_dir, lang1_filename) - lang2_filepath = os.path.join(tmp_dir, lang2_filename) - - if not os.path.exists(compressed_filepath): - generator_utils.maybe_download(tmp_dir, compressed_filename, url) - if not os.path.exists(lang1_filepath) or not os.path.exists(lang2_filepath): - mode = "r:gz" if "gz" in compressed_filepath else "r" - with tarfile.open(compressed_filepath, mode) as corpus_tar: - corpus_tar.extractall(tmp_dir) - if ".gz" in lang1_filepath: - new_filepath = lang1_filepath.strip(".gz") - generator_utils.gunzip_file(lang1_filepath, new_filepath) - lang1_filepath = new_filepath - if ".gz" in lang2_filepath: - new_filepath = lang2_filepath.strip(".gz") - generator_utils.gunzip_file(lang2_filepath, new_filepath) - lang2_filepath = new_filepath - with tf.gfile.GFile(lang1_filepath, mode="r") as lang1_file: - with tf.gfile.GFile(lang2_filepath, mode="r") as lang2_file: - lang1_file_lines = lang1_file.readlines() - lang2_file_lines = lang2_file.readlines() - assert len(lang1_file_lines) == len(lang2_file_lines), lang1_filepath - lang1_lines.extend(lang1_file_lines) - lang2_lines.extend(lang2_file_lines) - - write_chunk_size = 10000 - assert len(lang1_lines) == len(lang2_lines) - with tf.gfile.GFile(filename + ".lang1", mode="w") as lang1_file: - i = 0 - while i <= len(lang1_lines): - for line in lang1_lines[i * write_chunk_size:(i + 1) * write_chunk_size]: - lang1_file.write(line) - i += 1 - for line in lang1_lines[i * write_chunk_size:]: - lang1_file.write(line) - with tf.gfile.GFile(filename + ".lang2", mode="w") as lang2_file: - i = 0 - while i <= len(lang2_lines): - for line in lang2_lines[i * write_chunk_size:(i + 1) * write_chunk_size]: - lang2_file.write(line) - i += 1 - for line in lang2_lines[i * write_chunk_size:]: - lang2_file.write(line) + with tf.gfile.GFile(filename + ".lang1", mode="w") as lang1_resfile: + with tf.gfile.GFile(filename + ".lang2", mode="w") as lang2_resfile: + for dataset in datasets: + url = dataset[0] + compressed_filename = os.path.basename(url) + compressed_filepath = os.path.join(tmp_dir, compressed_filename) + + lang1_filename, lang2_filename = dataset[1] + lang1_filepath = os.path.join(tmp_dir, lang1_filename) + lang2_filepath = os.path.join(tmp_dir, lang2_filename) + + if not os.path.exists(compressed_filepath): + generator_utils.maybe_download(tmp_dir, compressed_filename, url) + if not (os.path.exists(lang1_filepath) and + os.path.exists(lang2_filepath)): + mode = "r:gz" if "gz" in compressed_filepath else "r" + with tarfile.open(compressed_filepath, mode) as corpus_tar: + corpus_tar.extractall(tmp_dir) + if ".gz" in lang1_filepath: + new_filepath = lang1_filepath.strip(".gz") + generator_utils.gunzip_file(lang1_filepath, new_filepath) + lang1_filepath = new_filepath + if ".gz" in lang2_filepath: + new_filepath = lang2_filepath.strip(".gz") + generator_utils.gunzip_file(lang2_filepath, new_filepath) + lang2_filepath = new_filepath + with tf.gfile.GFile(lang1_filepath, mode="r") as lang1_file: + with 
tf.gfile.GFile(lang2_filepath, mode="r") as lang2_file: + line1, line2 = lang1_file.readline(), lang2_file.readline() + while line1 or line2: + lang1_resfile.write(line1.strip() + "\n") + lang2_resfile.write(line2.strip() + "\n") + line1, line2 = lang1_file.readline(), lang2_file.readline() + return filename @@ -328,23 +355,22 @@ def ende_character_generator(tmp_dir, train): character_vocab, EOS) -def zhen_wordpiece_token_generator(tmp_dir, train, - source_vocab_size, +def zhen_wordpiece_token_generator(tmp_dir, train, source_vocab_size, target_vocab_size): + """Wordpiece generator for the WMT'17 zh-en dataset.""" datasets = _ZHEN_TRAIN_DATASETS if train else _ZHEN_TEST_DATASETS source_datasets = [[item[0], [item[1][0]]] for item in datasets] target_datasets = [[item[0], [item[1][1]]] for item in datasets] source_vocab = generator_utils.get_or_generate_vocab( - tmp_dir, "tokens.vocab.zh.%d" % source_vocab_size, - source_vocab_size, source_datasets) + tmp_dir, "tokens.vocab.zh.%d" % source_vocab_size, source_vocab_size, + source_datasets) target_vocab = generator_utils.get_or_generate_vocab( - tmp_dir, "tokens.vocab.en.%d" % target_vocab_size, - target_vocab_size, target_datasets) + tmp_dir, "tokens.vocab.en.%d" % target_vocab_size, target_vocab_size, + target_datasets) tag = "train" if train else "dev" data_path = _compile_data(tmp_dir, datasets, "wmt_zhen_tok_%s" % tag) - return bi_vocabs_token_generator(data_path + ".lang1", - data_path + ".lang2", - source_vocab, target_vocab, EOS) + return bi_vocabs_token_generator(data_path + ".lang1", data_path + ".lang2", + source_vocab, target_vocab, EOS) def enfr_wordpiece_token_generator(tmp_dir, train, vocab_size): @@ -367,6 +393,7 @@ def enfr_character_generator(tmp_dir, train): return character_generator(data_path + ".lang1", data_path + ".lang2", character_vocab, EOS) + def parsing_character_generator(tmp_dir, train): character_vocab = text_encoder.ByteTextEncoder() filename = "parsing_%s" % ("train" if train else "dev") @@ -375,25 +402,22 @@ def parsing_character_generator(tmp_dir, train): return character_generator(text_filepath, tags_filepath, character_vocab, EOS) -def tabbed_parsing_token_generator(tmp_dir, train, prefix, source_vocab_size, target_vocab_size): - """Generate source and target data from a single file with source/target pairs - separated by a tab character ('\t')""" +def tabbed_parsing_token_generator(tmp_dir, train, prefix, source_vocab_size, + target_vocab_size): + """Generate source and target data from a single file.""" source_vocab = generator_utils.get_or_generate_tabbed_vocab( tmp_dir, "parsing_train.pairs", 0, - prefix + "_source.tokens.vocab.%d" % source_vocab_size, - source_vocab_size) + prefix + "_source.tokens.vocab.%d" % source_vocab_size, source_vocab_size) target_vocab = generator_utils.get_or_generate_tabbed_vocab( tmp_dir, "parsing_train.pairs", 1, - prefix + "_target.tokens.vocab.%d" % target_vocab_size, - target_vocab_size) + prefix + "_target.tokens.vocab.%d" % target_vocab_size, target_vocab_size) filename = "parsing_%s" % ("train" if train else "dev") pair_filepath = os.path.join(tmp_dir, filename + ".pairs") return tabbed_generator(pair_filepath, source_vocab, target_vocab, EOS) def tabbed_parsing_character_generator(tmp_dir, train): - """Generate source and target data from a single file with source/target pairs - separated by a tab character ('\t')""" + """Generate source and target data from a single file.""" character_vocab = text_encoder.ByteTextEncoder() filename = "parsing_%s" % ("train" if 
train else "dev") pair_filepath = os.path.join(tmp_dir, filename + ".pairs") @@ -405,5 +429,5 @@ def parsing_token_generator(tmp_dir, train, vocab_size): tmp_dir, "tokens.vocab.%d" % vocab_size, vocab_size) filename = "%s_%s.trees" % (FLAGS.parsing_path, "train" if train else "dev") tree_filepath = os.path.join(tmp_dir, filename) - return wsj_parsing.token_generator(tree_filepath, - symbolizer_vocab, symbolizer_vocab, EOS) + return wsj_parsing.token_generator(tree_filepath, symbolizer_vocab, + symbolizer_vocab, EOS) diff --git a/tensor2tensor/data_generators/wmt_test.py b/tensor2tensor/data_generators/wmt_test.py index b6af3cf93..86b88e5b1 100644 --- a/tensor2tensor/data_generators/wmt_test.py +++ b/tensor2tensor/data_generators/wmt_test.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensor2tensor/data_generators/wsj_parsing.py b/tensor2tensor/data_generators/wsj_parsing.py index 756a44954..7734db646 100644 --- a/tensor2tensor/data_generators/wsj_parsing.py +++ b/tensor2tensor/data_generators/wsj_parsing.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensor2tensor/docs/distributed_training.md b/tensor2tensor/docs/distributed_training.md index f41197fc4..9ed9778da 100644 --- a/tensor2tensor/docs/distributed_training.md +++ b/tensor2tensor/docs/distributed_training.md @@ -10,11 +10,11 @@ along with a set of flags. ## `TF_CONFIG` -Both workers and parameter servers must have the `TF_CONFIG` environment +Both masters and parameter servers must have the `TF_CONFIG` environment variable set. The `TF_CONFIG` environment variable is a json-encoded string with the addresses -of the workers and parameter servers (in the `'cluster'` key) and the +of the masters and parameter servers (in the `'cluster'` key) and the identification of the current task (in the `'task'` key). For example: @@ -22,40 +22,42 @@ For example: ``` cluster = { 'ps': ['host1:2222', 'host2:2222'], - 'worker': ['host3:2222', 'host4:2222', 'host5:2222'] + 'master': ['host3:2222', 'host4:2222', 'host5:2222'] } os.environ['TF_CONFIG'] = json.dumps({ 'cluster': cluster, - 'task': {'type': 'worker', 'index': 1} + 'task': {'type': 'master', 'index': 1}, + 'environment': 'cloud', }) ``` ## Command-line flags -The following T2T command-line flags must also be set on the workers for +The following T2T command-line flags must also be set on the masters for distributed training: - `--master=grpc://$ADDRESS` -- `--worker_replicas=$NUM_WORKERS` -- `--worker_gpu=$NUM_GPUS_PER_WORKER` -- `--worker_id=$WORKER_ID` +- `--worker_replicas=$NUM_MASTERS` +- `--worker_gpu=$NUM_GPUS_PER_MASTER` +- `--worker_id=$MASTER_ID` +- `--worker_job='/job:master'` - `--ps_replicas=$NUM_PS` - `--ps_gpu=$NUM_GPUS_PER_PS` - `--schedule=train` - `--sync`, if you want synchronous training, i.e. for there to be a single - master worker coordinating the work across "ps" jobs (yes, the naming is - unfortunate). If not set, then each worker operates independently while - variables are shared on the parameter servers. + master coordinating the work across "ps" jobs. If not set, then each master + operates independently while variables are shared on the parameter servers. 
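To complement the master-side `TF_CONFIG` example above, a parameter server in the same hypothetical cluster would export an analogous value, differing only in its task entry (a sketch; the hostnames are the placeholders carried over from that example):

```python
import json
import os

cluster = {
    'ps': ['host1:2222', 'host2:2222'],
    'master': ['host3:2222', 'host4:2222', 'host5:2222'],
}

# First parameter server in the cluster (host1:2222).
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': cluster,
    'task': {'type': 'ps', 'index': 0},
    'environment': 'cloud',
})
```

On the command line, such a process then only needs the flags noted next for parameter servers.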
-Parameter servers only need `--schedule=run_std_server`. +Parameter servers only need `--master=grpc://$ADDRESS` and +`--schedule=run_std_server`. ## Utility to produce `TF_CONFIG` and flags [`t2t-make-tf-configs`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/bin/t2t-make-tf-configs)) generates the `TF_CONFIG` json strings and the above-mentioned command-line -flags for the workers and parameter servers. +flags for the masters and parameter servers. -Given a set of worker and parameter server addresses, the script outputs, for +Given a set of master and parameter server addresses, the script outputs, for each job, a line with the `TF_CONFIG` environment variable and the command-line flags necessary for distributed training. For each job, you should invoke the `t2t-trainer` with the `TF_CONFIG` value and flags that are output. @@ -66,6 +68,9 @@ For example: TF_CONFIG=$JOB_TF_CONFIG t2t-trainer $JOB_FLAGS --model=transformer ... ``` +Modify the `--worker_gpu` and `--ps_gpu` flags, which specify how many gpus are +on each master and ps, respectively, as needed for your machine/cluster setup. + ## Command-line flags for eval jobs Eval jobs should set the following flags and do not need the `TF_CONFIG` diff --git a/tensor2tensor/models/__init__.py b/tensor2tensor/models/__init__.py index 27d533abc..eff6a2b14 100644 --- a/tensor2tensor/models/__init__.py +++ b/tensor2tensor/models/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensor2tensor/models/attention_lm.py b/tensor2tensor/models/attention_lm.py index 99fbd8232..947dc9306 100644 --- a/tensor2tensor/models/attention_lm.py +++ b/tensor2tensor/models/attention_lm.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -140,7 +140,7 @@ def attention_lm_base(): hparams.optimizer_adam_epsilon = 1e-9 hparams.learning_rate_decay_scheme = "noam" hparams.learning_rate = 0.1 - hparams.learning_rate_warmup_steps = 1000 + hparams.learning_rate_warmup_steps = 2000 hparams.initializer_gain = 1.0 hparams.num_hidden_layers = 6 hparams.initializer = "uniform_unit_scaling" @@ -163,3 +163,22 @@ def attention_lm_base(): hparams.add_hparam("residual_dropout", 0.1) hparams.add_hparam("pos", "timing") # timing, none return hparams + + +@registry.register_hparams +def attention_lm_small(): + """Cheap model. + + on lm1b_32k: + 45M params + 2 steps/sec on [GeForce GTX TITAN X] + + Returns: + an hparams object. + """ + hparams = attention_lm_base() + hparams.num_hidden_layers = 4 + hparams.hidden_size = 512 + hparams.filter_size = 2048 + hparams.residual_dropout = 0.5 + return hparams diff --git a/tensor2tensor/models/attention_lm_moe.py b/tensor2tensor/models/attention_lm_moe.py index b4d27d400..952ff1a71 100644 --- a/tensor2tensor/models/attention_lm_moe.py +++ b/tensor2tensor/models/attention_lm_moe.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -117,9 +117,9 @@ def attention_lm_moe_base(): """Set of hyperparameters. suitable for 1 gpu. 
- on lm1b_16k: - ~337M params - 1.1 steps/sec on [GeForce GTX TITAN X] + on lm1b_32k: + ~229M params + 0.9 steps/sec on [GeForce GTX TITAN X] Returns: a hparams object @@ -133,7 +133,7 @@ def attention_lm_moe_base(): hparams.optimizer_adam_epsilon = 1e-9 hparams.learning_rate_decay_scheme = "noam" hparams.learning_rate = 0.1 - hparams.learning_rate_warmup_steps = 1000 + hparams.learning_rate_warmup_steps = 2000 hparams.initializer_gain = 1.0 hparams.num_hidden_layers = 4 hparams.initializer = "uniform_unit_scaling" @@ -143,14 +143,14 @@ def attention_lm_moe_base(): hparams.num_sampled_classes = 0 hparams.label_smoothing = 0.0 hparams.shared_embedding_and_softmax_weights = int(False) - hparams.add_hparam("filter_size", 2948) # Add new ones like this. + hparams.add_hparam("filter_size", 2048) # Add new ones like this. # comma-separated list of layer numbers. # At each of these layers, we replace the ffn with a mixture of experts. hparams.add_hparam("moe_layers", "2") # If moe_n2 is None, then use a flat MoE with moe_n1 experts. # If moe_n2 is an integer, then use a hierarchical MoE # consisting of moe_n1 groups of moe_n2 experts each. - hparams.add_hparam("moe_n1", 64) + hparams.add_hparam("moe_n1", 32) hparams.add_hparam("moe_n2", 0) hparams.add_hparam("moe_hidden_size", 2048) hparams.add_hparam("moe_loss_coef", 1e-2) @@ -171,9 +171,11 @@ def attention_lm_moe_base(): def attention_lm_moe_small(): """Cheap model for single-gpu training. - on lm1b_16k: - ~295M params - 2 steps/sec on [GeForce GTX TITAN X] + on lm1b_32k: + ~312M params + 1.6 steps/sec on [GeForce GTX TITAN X] + After 50K steps on 8 GPUs (synchronous): + eval_log_ppl_per_token = 3.31 Returns: an hparams object. @@ -188,6 +190,24 @@ def attention_lm_moe_small(): return hparams +@registry.register_hparams +def attention_lm_no_moe_small(): + """Without the mixture of experts (for comparison). + + on lm1b_32k: + ~45M params + 2 steps/sec on [GeForce GTX TITAN X] + After 50K steps on 8 GPUs (synchronous): + eval_log_ppl_per_token = 3.51 + + Returns: + an hparams object. + """ + hparams = attention_lm_moe_small() + hparams.moe_layers = "" + return hparams + + @registry.register_hparams def attention_lm_moe_large(): """Large model for distributed training. @@ -195,6 +215,11 @@ def attention_lm_moe_large(): Over 1B parameters, so requires multi-gpu training due to memory requirements. + on lm1b_32k: + After 45K steps on 8 GPUs (synchronous): + eval_log_ppl_per_token = 3.18 + eval_ppl_per_word = exp(1.107893 * eval_log_ppl_per_token) = 33.9 + Returns: an hparams object. """ diff --git a/tensor2tensor/models/bluenet.py b/tensor2tensor/models/bluenet.py index c0533ee42..95216f43d 100644 --- a/tensor2tensor/models/bluenet.py +++ b/tensor2tensor/models/bluenet.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensor2tensor/models/bluenet_test.py b/tensor2tensor/models/bluenet_test.py index 080c96a3f..b3f18249d 100644 --- a/tensor2tensor/models/bluenet_test.py +++ b/tensor2tensor/models/bluenet_test.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
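The per-word perplexity quoted above for `attention_lm_moe_large` follows from the per-token log-perplexity and the token-to-word factor given in that docstring; a quick arithmetic check (both constants are taken from the docstring, not computed here):

```python
import math

eval_log_ppl_per_token = 3.18
tokens_per_word = 1.107893  # factor used in the docstring formula
eval_ppl_per_word = math.exp(tokens_per_word * eval_log_ppl_per_token)
print(round(eval_ppl_per_word, 1))  # ~33.9
```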
diff --git a/tensor2tensor/models/bytenet.py b/tensor2tensor/models/bytenet.py index 1a82144d6..301626dc2 100644 --- a/tensor2tensor/models/bytenet.py +++ b/tensor2tensor/models/bytenet.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensor2tensor/models/bytenet_test.py b/tensor2tensor/models/bytenet_test.py index 8202d5b74..f1e42669e 100644 --- a/tensor2tensor/models/bytenet_test.py +++ b/tensor2tensor/models/bytenet_test.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensor2tensor/models/common_attention.py b/tensor2tensor/models/common_attention.py index b6a5e09d6..49cd40285 100644 --- a/tensor2tensor/models/common_attention.py +++ b/tensor2tensor/models/common_attention.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -374,7 +374,18 @@ def multihead_attention(query_antecedent, Returns: A Tensor. + + Raises: + ValueError: if the key depth or value depth are not divisible by the + number of attention heads. """ + if total_key_depth % num_heads != 0: + raise ValueError("Key depth (%d) must be divisible by the number of " + "attention heads (%d)." % (total_key_depth, num_heads)) + if total_value_depth % num_heads != 0: + raise ValueError("Value depth (%d) must be divisible by the number of " + "attention heads (%d)." % (total_value_depth, num_heads)) + with tf.variable_scope( name, default_name="multihead_attention", diff --git a/tensor2tensor/models/common_hparams.py b/tensor2tensor/models/common_hparams.py index f48a67c15..f067b724e 100644 --- a/tensor2tensor/models/common_hparams.py +++ b/tensor2tensor/models/common_hparams.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensor2tensor/models/common_layers.py b/tensor2tensor/models/common_layers.py index d38f97fb0..1e7050570 100644 --- a/tensor2tensor/models/common_layers.py +++ b/tensor2tensor/models/common_layers.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensor2tensor/models/common_layers_test.py b/tensor2tensor/models/common_layers_test.py index a87776bfb..3a2fafd8b 100644 --- a/tensor2tensor/models/common_layers_test.py +++ b/tensor2tensor/models/common_layers_test.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensor2tensor/models/lstm.py b/tensor2tensor/models/lstm.py index 998e6756b..c3ae0a01e 100644 --- a/tensor2tensor/models/lstm.py +++ b/tensor2tensor/models/lstm.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. 
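The new `ValueError`s added to `multihead_attention` above guard the reshape into heads: both the key and value depths must split evenly across `num_heads`. A small standalone sketch of the constraint (the depths below are illustrative, not tied to any particular hparams set):

```python
def check_head_split(total_key_depth, total_value_depth, num_heads):
  """Raise if the per-head dimensions would not be integers."""
  if total_key_depth % num_heads != 0:
    raise ValueError("Key depth (%d) must be divisible by the number of "
                     "attention heads (%d)." % (total_key_depth, num_heads))
  if total_value_depth % num_heads != 0:
    raise ValueError("Value depth (%d) must be divisible by the number of "
                     "attention heads (%d)." % (total_value_depth, num_heads))
  return total_key_depth // num_heads, total_value_depth // num_heads


assert check_head_split(512, 512, 8) == (64, 64)  # ok: 64 dims per head
# check_head_split(512, 512, 7) would raise ValueError.
```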
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensor2tensor/models/lstm_test.py b/tensor2tensor/models/lstm_test.py index 4c4c42909..4ddaf6b64 100644 --- a/tensor2tensor/models/lstm_test.py +++ b/tensor2tensor/models/lstm_test.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensor2tensor/models/modalities.py b/tensor2tensor/models/modalities.py index 4e7a7e924..60df80a1c 100644 --- a/tensor2tensor/models/modalities.py +++ b/tensor2tensor/models/modalities.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -465,3 +465,31 @@ def bottom(self, x): def top(self, body_output, _): return body_output + + +@registry.register_image_modality("identity_no_pad") +class IdentityModalityNoPad(modality.Modality): + """Does nothing except making sure that there is no padding in cross-ent.""" + + @property + def targets_dimensionality(self): + return self._vocab_size + + def bottom(self, x): + return tf.to_float(x) + + def top(self, body_output, _): + return body_output + + def top_sharded(self, + sharded_body_output, + sharded_targets, + data_parallelism, + weights_fn=common_layers.weights_all): + # Call the default implementation, but weight 1.0 on 0s by default. + # (Since we're processing images and so have no padding and some pixel 0s.) + return super(IdentityModalityNoPad, self).top_sharded( + sharded_body_output, + sharded_targets, + data_parallelism, + weights_fn=weights_fn) diff --git a/tensor2tensor/models/modalities_test.py b/tensor2tensor/models/modalities_test.py index 090af3aef..118db3847 100644 --- a/tensor2tensor/models/modalities_test.py +++ b/tensor2tensor/models/modalities_test.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensor2tensor/models/models.py b/tensor2tensor/models/models.py index 214aec245..0ca11996e 100644 --- a/tensor2tensor/models/models.py +++ b/tensor2tensor/models/models.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensor2tensor/models/multimodel.py b/tensor2tensor/models/multimodel.py index 26e7469c2..6f12db86d 100644 --- a/tensor2tensor/models/multimodel.py +++ b/tensor2tensor/models/multimodel.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensor2tensor/models/multimodel_test.py b/tensor2tensor/models/multimodel_test.py index dbbd3fa8e..958fac5d7 100644 --- a/tensor2tensor/models/multimodel_test.py +++ b/tensor2tensor/models/multimodel_test.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. 
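`IdentityModalityNoPad` above passes `weights_fn=common_layers.weights_all` to `top_sharded`, so zero-valued targets (e.g. black pixels) still contribute to the cross-entropy instead of being masked out as padding. A standalone numpy sketch of the difference between the two weighting schemes (the real functions operate on TensorFlow tensors):

```python
import numpy as np


def weights_nonzero(targets):
  # Default behavior: id 0 is treated as padding and gets weight 0.
  return (targets != 0).astype(np.float32)


def weights_all(targets):
  # IdentityModalityNoPad behavior: every position counts.
  return np.ones_like(targets, dtype=np.float32)


pixels = np.array([0, 17, 255, 0])
print(weights_nonzero(pixels))  # [0. 1. 1. 0.]
print(weights_all(pixels))      # [1. 1. 1. 1.]
```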
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensor2tensor/models/neural_gpu.py b/tensor2tensor/models/neural_gpu.py index dce0dbc30..30d535098 100644 --- a/tensor2tensor/models/neural_gpu.py +++ b/tensor2tensor/models/neural_gpu.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensor2tensor/models/neural_gpu_test.py b/tensor2tensor/models/neural_gpu_test.py index 3065bb1c4..1dddc1056 100644 --- a/tensor2tensor/models/neural_gpu_test.py +++ b/tensor2tensor/models/neural_gpu_test.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensor2tensor/models/shake_shake.py b/tensor2tensor/models/shake_shake.py index f87eaa335..26d43afb3 100644 --- a/tensor2tensor/models/shake_shake.py +++ b/tensor2tensor/models/shake_shake.py @@ -1,7 +1,25 @@ +# Copyright 2017 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Shake-shake model for CIFAR.""" + from __future__ import absolute_import from __future__ import division from __future__ import print_function +# Dependency imports + from six.moves import xrange # pylint: disable=redefined-builtin from tensor2tensor.models import common_hparams @@ -15,31 +33,29 @@ def shake_shake_block_branch(x, conv_filters, stride): x = tf.nn.relu(x) x = tf.layers.conv2d( - x, conv_filters, (3, 3), strides=(stride, stride), padding='SAME') + x, conv_filters, (3, 3), strides=(stride, stride), padding="SAME") x = tf.layers.batch_normalization(x) x = tf.nn.relu(x) - x = tf.layers.conv2d(x, conv_filters, (3, 3), strides=(1, 1), padding='SAME') + x = tf.layers.conv2d(x, conv_filters, (3, 3), strides=(1, 1), padding="SAME") x = tf.layers.batch_normalization(x) return x def downsampling_residual_branch(x, conv_filters): x = tf.nn.relu(x) - x1 = tf.layers.average_pooling2d(x, pool_size=(1, 1), strides=(2, 2)) - x1 = tf.layers.conv2d(x1, conv_filters / 2, (1, 1), padding='SAME') - + x1 = tf.layers.conv2d(x1, conv_filters / 2, (1, 1), padding="SAME") x2 = tf.pad(x[:, 1:, 1:], [[0, 0], [0, 1], [0, 1], [0, 0]]) x2 = tf.layers.average_pooling2d(x2, pool_size=(1, 1), strides=(2, 2)) - x2 = tf.layers.conv2d(x2, conv_filters / 2, (1, 1), padding='SAME') - + x2 = tf.layers.conv2d(x2, conv_filters / 2, (1, 1), padding="SAME") return tf.concat([x1, x2], axis=3) def shake_shake_block(x, conv_filters, stride, hparams): - with tf.variable_scope('branch_1'): + """A shake-shake block.""" + with tf.variable_scope("branch_1"): branch1 = shake_shake_block_branch(x, conv_filters, stride) - with tf.variable_scope('branch_2'): + with tf.variable_scope("branch_2"): branch2 = shake_shake_block_branch(x, conv_filters, stride) if x.shape[-1] == conv_filters: skip = tf.identity(x) @@ -48,14 +64,14 @@ def shake_shake_block(x, conv_filters, stride, hparams): # TODO(rshin): Use different alpha for each image in batch. if hparams.mode == tf.contrib.learn.ModeKeys.TRAIN: - if hparams.shakeshake_type == 'batch': + if hparams.shakeshake_type == "batch": shaken = common_layers.shakeshake2(branch1, branch2) - elif hparams.shakeshake_type == 'image': + elif hparams.shakeshake_type == "image": shaken = common_layers.shakeshake2_indiv(branch1, branch2) - elif hparams.shakeshake_type == 'equal': + elif hparams.shakeshake_type == "equal": shaken = common_layers.shakeshake2_py(branch1, branch2, equal=True) else: - raise ValueError('Invalid shakeshake_type: {!r}'.format(shaken)) + raise ValueError("Invalid shakeshake_type: {!r}".format(shaken)) else: shaken = common_layers.shakeshake2_py(branch1, branch2, equal=True) shaken.set_shape(branch1.get_shape()) @@ -64,22 +80,22 @@ def shake_shake_block(x, conv_filters, stride, hparams): def shake_shake_stage(x, num_blocks, conv_filters, initial_stride, hparams): - with tf.variable_scope('block_0'): + with tf.variable_scope("block_0"): x = shake_shake_block(x, conv_filters, initial_stride, hparams) for i in xrange(1, num_blocks): - with tf.variable_scope('block_{}'.format(i)): + with tf.variable_scope("block_{}".format(i)): x = shake_shake_block(x, conv_filters, 1, hparams) return x @registry.register_model class ShakeShake(t2t_model.T2TModel): - '''Implements the Shake-Shake architecture. + """Implements the Shake-Shake architecture. From This is intended to match the CIFAR-10 version, and correspond to "Shake-Shake-Batch" in Table 1. 
- ''' + """ def model_fn_body(self, features): hparams = self._hparams @@ -93,14 +109,13 @@ def model_fn_body(self, features): # filters then a batch norm. Instead we will rely on the one in # SmallImageModality, which seems to instead use a layer norm. x = inputs - mode = hparams.mode - with tf.variable_scope('shake_shake_stage_1'): + with tf.variable_scope("shake_shake_stage_1"): x = shake_shake_stage(x, blocks_per_stage, hparams.base_filters, 1, hparams) - with tf.variable_scope('shake_shake_stage_2'): + with tf.variable_scope("shake_shake_stage_2"): x = shake_shake_stage(x, blocks_per_stage, hparams.base_filters * 2, 2, hparams) - with tf.variable_scope('shake_shake_stage_3'): + with tf.variable_scope("shake_shake_stage_3"): x = shake_shake_stage(x, blocks_per_stage, hparams.base_filters * 4, 2, hparams) @@ -117,6 +132,7 @@ def model_fn_body(self, features): @registry.register_hparams def shakeshake_cifar10(): + """Parameters for CIFAR-10.""" hparams = common_hparams.basic_params1() # This leads to effective batch size 128 when number of GPUs is 1 hparams.batch_size = 4096 * 8 @@ -138,6 +154,6 @@ def shakeshake_cifar10(): hparams.weight_decay = 3.0 hparams.optimizer = "Momentum" hparams.optimizer_momentum_momentum = 0.9 - hparams.add_hparam('base_filters', 16) - hparams.add_hparam('shakeshake_type', 'batch') + hparams.add_hparam("base_filters", 16) + hparams.add_hparam("shakeshake_type", "batch") return hparams diff --git a/tensor2tensor/models/slicenet.py b/tensor2tensor/models/slicenet.py index 77659e8ef..43913eab1 100644 --- a/tensor2tensor/models/slicenet.py +++ b/tensor2tensor/models/slicenet.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensor2tensor/models/slicenet_test.py b/tensor2tensor/models/slicenet_test.py index db563b481..911953445 100644 --- a/tensor2tensor/models/slicenet_test.py +++ b/tensor2tensor/models/slicenet_test.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensor2tensor/models/transformer.py b/tensor2tensor/models/transformer.py old mode 100755 new mode 100644 index 73542bd5a..b341d6fe0 --- a/tensor2tensor/models/transformer.py +++ b/tensor2tensor/models/transformer.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
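For context on the `shake_shake_block` edits above: shake-shake regularization combines the two residual branches with a random convex weight during training (via `common_layers.shakeshake2` and its variants) and with a fixed 0.5/0.5 split otherwise. A minimal numpy sketch of the forward combination, not the library op (the real op also randomizes the backward pass separately):

```python
import numpy as np


def shake_shake_combine(skip, branch1, branch2, is_training, rng=np.random):
  """Sketch: skip + alpha * branch1 + (1 - alpha) * branch2."""
  if is_training:
    alpha = rng.uniform(0.0, 1.0)  # fresh random weight each forward pass
  else:
    alpha = 0.5                    # deterministic average at eval time
  return skip + alpha * branch1 + (1.0 - alpha) * branch2


x = np.zeros(4)
b1, b2 = np.ones(4), 2 * np.ones(4)
print(shake_shake_combine(x, b1, b2, is_training=False))  # [1.5 1.5 1.5 1.5]
```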
@@ -353,15 +353,6 @@ def transformer_parsing_base(): return hparams -@registry.register_hparams -def transformer_parsing_ice(): - """Hparams for parsing Icelandic text.""" - hparams = transformer_base_single_gpu() - hparams.batch_size = 4096 - hparams.shared_embedding_and_softmax_weights = int(False) - return hparams - - @registry.register_hparams def transformer_parsing_big(): """HParams for parsing on wsj semi-supervised.""" @@ -375,6 +366,15 @@ def transformer_parsing_big(): return hparams +@registry.register_hparams +def transformer_parsing_ice(): + """Hparams for parsing Icelandic text.""" + hparams = transformer_base_single_gpu() + hparams.batch_size = 4096 + hparams.shared_embedding_and_softmax_weights = int(False) + return hparams + + @registry.register_hparams def transformer_tiny(): hparams = transformer_base() diff --git a/tensor2tensor/models/transformer_alternative.py b/tensor2tensor/models/transformer_alternative.py index e50cba86f..aed074d56 100644 --- a/tensor2tensor/models/transformer_alternative.py +++ b/tensor2tensor/models/transformer_alternative.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensor2tensor/models/transformer_test.py b/tensor2tensor/models/transformer_test.py index 9535558a4..ca099c653 100644 --- a/tensor2tensor/models/transformer_test.py +++ b/tensor2tensor/models/transformer_test.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensor2tensor/models/xception.py b/tensor2tensor/models/xception.py index d28a1628e..d3c5a2690 100644 --- a/tensor2tensor/models/xception.py +++ b/tensor2tensor/models/xception.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensor2tensor/models/xception_test.py b/tensor2tensor/models/xception_test.py index cd158b852..aa5c1c034 100644 --- a/tensor2tensor/models/xception_test.py +++ b/tensor2tensor/models/xception_test.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensor2tensor/utils/__init__.py b/tensor2tensor/utils/__init__.py index 27d533abc..eff6a2b14 100644 --- a/tensor2tensor/utils/__init__.py +++ b/tensor2tensor/utils/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensor2tensor/utils/avg_checkpoints.py b/tensor2tensor/utils/avg_checkpoints.py index 01850aeae..a84750310 100644 --- a/tensor2tensor/utils/avg_checkpoints.py +++ b/tensor2tensor/utils/avg_checkpoints.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
diff --git a/tensor2tensor/utils/beam_search.py b/tensor2tensor/utils/beam_search.py index eacbf467f..3a511907d 100644 --- a/tensor2tensor/utils/beam_search.py +++ b/tensor2tensor/utils/beam_search.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensor2tensor/utils/beam_search_test.py b/tensor2tensor/utils/beam_search_test.py index 33439b41f..e084f1f0e 100644 --- a/tensor2tensor/utils/beam_search_test.py +++ b/tensor2tensor/utils/beam_search_test.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensor2tensor/utils/bleu_hook.py b/tensor2tensor/utils/bleu_hook.py index 012215cff..155b10c72 100644 --- a/tensor2tensor/utils/bleu_hook.py +++ b/tensor2tensor/utils/bleu_hook.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensor2tensor/utils/bleu_hook_test.py b/tensor2tensor/utils/bleu_hook_test.py index 1838affd6..8092ab979 100644 --- a/tensor2tensor/utils/bleu_hook_test.py +++ b/tensor2tensor/utils/bleu_hook_test.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensor2tensor/utils/data_reader.py b/tensor2tensor/utils/data_reader.py index 7b0663cf8..a3e9835ac 100644 --- a/tensor2tensor/utils/data_reader.py +++ b/tensor2tensor/utils/data_reader.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -27,6 +27,7 @@ from tensor2tensor.data_generators import problem_hparams from tensor2tensor.models import common_layers +from tensor2tensor.utils import registry import tensorflow as tf @@ -352,7 +353,10 @@ def get_datasets(problems, data_dir, mode): """Return the location of a dataset for a given mode.""" datasets = [] for problem in problems.split("-"): - problem, _, _ = problem_hparams.parse_problem_name(problem) + try: + problem = registry.problem(problem).dataset_filename() + except ValueError: + problem, _, _ = problem_hparams.parse_problem_name(problem) path = os.path.join(data_dir, problem) if mode == tf.contrib.learn.ModeKeys.TRAIN: datasets.append("%s-train*" % path) diff --git a/tensor2tensor/utils/data_reader_test.py b/tensor2tensor/utils/data_reader_test.py index 7386d3ea0..18507ed06 100644 --- a/tensor2tensor/utils/data_reader_test.py +++ b/tensor2tensor/utils/data_reader_test.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensor2tensor/utils/expert_utils.py b/tensor2tensor/utils/expert_utils.py index 0bd69599d..c3becbfb4 100644 --- a/tensor2tensor/utils/expert_utils.py +++ b/tensor2tensor/utils/expert_utils.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. 
+# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensor2tensor/utils/metrics.py b/tensor2tensor/utils/metrics.py index 97da4cd35..cf66f6af8 100644 --- a/tensor2tensor/utils/metrics.py +++ b/tensor2tensor/utils/metrics.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensor2tensor/utils/metrics_test.py b/tensor2tensor/utils/metrics_test.py index 0472d4f21..de72d797f 100644 --- a/tensor2tensor/utils/metrics_test.py +++ b/tensor2tensor/utils/metrics_test.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensor2tensor/utils/modality.py b/tensor2tensor/utils/modality.py index 856c1a97f..3ac6153b7 100644 --- a/tensor2tensor/utils/modality.py +++ b/tensor2tensor/utils/modality.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensor2tensor/utils/registry.py b/tensor2tensor/utils/registry.py index 6c04cf22d..5a8823510 100644 --- a/tensor2tensor/utils/registry.py +++ b/tensor2tensor/utils/registry.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -54,6 +54,7 @@ class MyModel(T2TModel): _MODELS = {} _HPARAMS = {} _RANGED_HPARAMS = {} +_PROBLEMS = {} class Modalities(object): @@ -184,6 +185,63 @@ def list_ranged_hparams(): return list(_RANGED_HPARAMS) +def register_problem(name=None): + """Register a Problem. name defaults to cls name snake-cased.""" + + def decorator(p_cls, registration_name=None): + """Registers & returns p_cls with registration_name or default name.""" + p_name = registration_name or _default_name(p_cls) + if p_name in _PROBLEMS: + raise ValueError("Problem %s already registered." % p_name) + + _PROBLEMS[p_name] = p_cls + p_cls.name = p_name + return p_cls + + # Handle if decorator was used without parens + if callable(name): + p_cls = name + return decorator(p_cls, registration_name=_default_name(p_cls)) + + return lambda p_cls: decorator(p_cls, name) + + +def problem(name): + """Retrieve a problem by name.""" + + def parse_problem_name(problem_name): + """Determines if problem_name specifies a copy and/or reversal. + + Args: + problem_name: A string containing a single problem name from + FLAGS.problems. + + Returns: + base_name: A string with the base problem name. + was_reversed: A boolean. + was_copy: A boolean. + """ + # Recursively strip tags until we reach a base name. 
+ if len(problem_name) > 4 and problem_name[-4:] == "_rev": + base, _, was_copy = parse_problem_name(problem_name[:-4]) + return base, True, was_copy + elif len(problem_name) > 5 and problem_name[-5:] == "_copy": + base, was_reversed, _ = parse_problem_name(problem_name[:-5]) + return base, was_reversed, True + else: + return problem_name, False, False + + base_name, was_reversed, was_copy = parse_problem_name(name) + + if base_name not in _PROBLEMS: + raise ValueError("Problem %s never registered." % name) + return _PROBLEMS[base_name](was_reversed, was_copy) + + +def list_problems(): + return list(_PROBLEMS) + + def _internal_get_modality(name, mod_collection, collection_str): if name is None: name = "default" @@ -345,11 +403,16 @@ def help_string(): RangedHParams: %s Modalities: %s + + Problems: %s """ - m, rhp, mod = [ + m, rhp, mod, probs = [ sorted(entries) - for entries in [list_models(), - list_ranged_hparams(), - list_modalities()] + for entries in [ + list_models(), + list_ranged_hparams(), + list_modalities(), + list_problems() + ] ] - return help_str % (m, _hparams_help_string(), rhp, mod) + return help_str % (m, _hparams_help_string(), rhp, mod, probs) diff --git a/tensor2tensor/utils/registry_test.py b/tensor2tensor/utils/registry_test.py index 84903b141..1f4436b0c 100644 --- a/tensor2tensor/utils/registry_test.py +++ b/tensor2tensor/utils/registry_test.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensor2tensor/utils/t2t_model.py b/tensor2tensor/utils/t2t_model.py index a991d3614..2a271afbf 100644 --- a/tensor2tensor/utils/t2t_model.py +++ b/tensor2tensor/utils/t2t_model.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -115,6 +115,9 @@ def _create_modalities(self, problem_hparams, hparams): input_modality = {} for f, modality_spec in six.iteritems(problem_hparams.input_modality): if isinstance(modality_spec, modality.Modality): + # This function has been previously run (e.g. for training and now is + # being called for eval) and the modalities have already been + # constructed. Return. return if f in input_modality_overrides: _warn_changed_modality_type(input_modality_overrides[f], diff --git a/tensor2tensor/utils/trainer_utils.py b/tensor2tensor/utils/trainer_utils.py old mode 100755 new mode 100644 index 6b9f66c92..b5894904d --- a/tensor2tensor/utils/trainer_utils.py +++ b/tensor2tensor/utils/trainer_utils.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
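`registry.problem` above strips `_rev` and `_copy` suffixes recursively before looking the base name up, so reversed and copied variants never need their own registration. A standalone re-implementation of just that parsing step, for illustration (the registry version additionally guards on name length):

```python
def parse_problem_name(name):
  """Strip trailing _rev/_copy tags; return (base_name, was_reversed, was_copy)."""
  if name.endswith("_rev"):
    base, _, was_copy = parse_problem_name(name[:-4])
    return base, True, was_copy
  if name.endswith("_copy"):
    base, was_reversed, _ = parse_problem_name(name[:-5])
    return base, was_reversed, True
  return name, False, False


assert parse_problem_name("wmt_ende_tokens_8k") == ("wmt_ende_tokens_8k", False, False)
assert parse_problem_name("wmt_ende_tokens_8k_rev") == ("wmt_ende_tokens_8k", True, False)
assert parse_problem_name("wmt_ende_tokens_8k_rev_copy") == ("wmt_ende_tokens_8k", True, True)
```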
@@ -32,17 +32,19 @@ from six.moves import xrange # pylint: enable=redefined-builtin +from tensor2tensor.data_generators import all_problems # pylint: disable=unused-import from tensor2tensor.data_generators import problem_hparams -from tensor2tensor.data_generators.text_encoder import EOS_TOKEN +from tensor2tensor.data_generators import text_encoder from tensor2tensor.models import models # pylint: disable=unused-import from tensor2tensor.utils import data_reader from tensor2tensor.utils import expert_utils as eu from tensor2tensor.utils import metrics from tensor2tensor.utils import registry +from tensor2tensor.utils import yellowfin + import tensorflow as tf from tensorflow.contrib.learn.python.learn import learn_runner from tensorflow.python.ops import init_ops -from tensor2tensor.utils.yellowfin import YellowFinOptimizer # Number of samples to draw for an image input (in such cases as captioning) IMAGE_DECODE_LENGTH = 100 @@ -91,6 +93,8 @@ flags.DEFINE_integer("worker_gpu", 1, "How many GPUs to use.") flags.DEFINE_integer("worker_replicas", 1, "How many workers to use.") flags.DEFINE_integer("worker_id", 0, "Which worker task are we.") +flags.DEFINE_float("worker_gpu_memory_fraction", 1., + "Fraction of GPU memory to allocate.") flags.DEFINE_integer("ps_gpu", 0, "How many GPUs to use per ps.") flags.DEFINE_string("gpu_order", "", "Optional order for daisy-chaining gpus." " e.g. \"1 3 2 4\"") @@ -102,7 +106,6 @@ "In inference, use last position only for speedup.") flags.DEFINE_bool("decode_interactive", False, "Interactive local inference mode.") -flags.DEFINE_bool("decode_endless", False, "Run decoding endlessly. Temporary.") flags.DEFINE_bool("decode_save_images", False, "Save inference input images.") flags.DEFINE_string("decode_from_file", None, "Path to decode file") flags.DEFINE_string("decode_to_file", None, "Path to inference output file") @@ -121,12 +124,12 @@ def _save_until_eos(hyp): - """ Strips everything after the first token, which is normally 1 """ + """Strips everything after the first token, which is normally 1.""" try: - index = list(hyp).index(EOS_TOKEN) + index = list(hyp).index(text_encoder.EOS_TOKEN) return hyp[0:index] except ValueError: - # No EOS_TOKEN: return the array as-is + # No EOS_TOKEN: return the array as-is. 
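`_save_until_eos` truncates a decoded id array at the first EOS id (normally 1, per `text_encoder`) and returns the array unchanged when no EOS is present. A quick numpy illustration of both paths, as a standalone sketch:

```python
import numpy as np

EOS_TOKEN = 1  # as in text_encoder


def save_until_eos(hyp):
  """Strip everything from the first EOS onwards; pass through if absent."""
  try:
    index = list(hyp).index(EOS_TOKEN)
    return hyp[0:index]
  except ValueError:
    return hyp


print(save_until_eos(np.array([5, 7, 1, 3, 3])))  # [5 7]
print(save_until_eos(np.array([5, 7, 3])))        # [5 7 3] (no EOS)
```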
return hyp @@ -187,6 +190,7 @@ def create_experiment_components(hparams, output_dir, data_dir, model_name): config=tf.contrib.learn.RunConfig( master=FLAGS.master, model_dir=output_dir, + gpu_memory_fraction=FLAGS.worker_gpu_memory_fraction, session_config=session_config(), keep_checkpoint_max=FLAGS.keep_checkpoint_max)) # Store the hparams in the estimator as well @@ -221,10 +225,15 @@ def create_hparams(params_id, data_dir): hparams = hparams.parse(FLAGS.hparams) # Add hparams for the problems - hparams.problems = [ - problem_hparams.problem_hparams(problem, hparams) - for problem in FLAGS.problems.split("-") - ] + hparams.problems = [] + for problem_name in FLAGS.problems.split("-"): + try: + problem = registry.problem(problem_name) + p_hparams = problem.internal_hparams(hparams) + except ValueError: + p_hparams = problem_hparams.problem_hparams(problem_name, hparams) + + hparams.problems.append(p_hparams) return hparams @@ -280,6 +289,7 @@ def session_config(): """The TensorFlow Session config to use.""" graph_options = tf.GraphOptions(optimizer_options=tf.OptimizerOptions( opt_level=tf.OptimizerOptions.L1, do_function_inlining=False)) + if FLAGS.experimental_optimize_placement: rewrite_options = tf.RewriterConfig(optimize_tensor_layout=True) rewrite_options.optimizers.append("pruning") @@ -287,8 +297,13 @@ def session_config(): rewrite_options.optimizers.append("layout") graph_options = tf.GraphOptions( rewrite_options=rewrite_options, infer_shapes=True) - config = tf.ConfigProto( - allow_soft_placement=True, graph_options=graph_options) + + gpu_options = tf.GPUOptions( + per_process_gpu_memory_fraction=FLAGS.worker_gpu_memory_fraction) + + config = tf.ConfigProto(allow_soft_placement=True, + graph_options=graph_options, + gpu_options=gpu_options) return config @@ -357,7 +372,6 @@ def learning_rate_decay(): lambda: decay, name="learning_rate_decay_warump_cond") - def model_fn(features, targets, mode): """Creates the prediction, loss, and train ops. @@ -578,7 +592,7 @@ def decode_from_dataset(estimator): num_datashards=data_parallelism().n, fixed_problem=i) result_iter = estimator.predict( - input_fn=infer_input_fn, as_iterable=FLAGS.decode_endless) + input_fn=infer_input_fn, as_iterable=False) def log_fn(inputs, targets, @@ -610,11 +624,7 @@ def log_fn(inputs, target_file.write(decoded_targets + "\n") # The function predict() returns an iterable over the network's - # predictions from the test input. if FLAGS.decode_endless is set, it will - # decode over the dev set endlessly, looping over it. We use the returned - # iterator to log inputs and decodes. - if FLAGS.decode_endless: - tf.logging.info("Warning: Decoding endlessly") + # predictions from the test input. We use it to log inputs and decodes. 
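Both `data_reader.get_datasets` and `create_hparams` now try the new problem registry first and fall back to the legacy `problem_hparams` table only when the name is unknown, so registered `Problem` classes and old-style problem names keep working side by side. A hedged sketch of that lookup order, with error handling simplified:

```python
from tensor2tensor.data_generators import problem_hparams
from tensor2tensor.utils import registry


def problem_hparams_for(problem_name, hparams):
  """New registry first, legacy table as fallback (mirrors create_hparams)."""
  try:
    return registry.problem(problem_name).internal_hparams(hparams)
  except ValueError:
    return problem_hparams.problem_hparams(problem_name, hparams)
```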
for j, result in enumerate(result_iter): inputs, targets, outputs = (result["inputs"], result["targets"], result["outputs"]) @@ -650,7 +660,8 @@ def log_fn(inputs, outputs): decoded_inputs = inputs_vocab.decode(_save_until_eos(inputs.flatten())) tf.logging.info("Inference results INPUT: %s" % decoded_inputs) - decoded_outputs = targets_vocab.decode(_save_until_eos(outputs.flatten())) + decoded_outputs = targets_vocab.decode( + _save_until_eos(outputs.flatten())) tf.logging.info("Inference results OUTPUT: %s" % decoded_outputs) return decoded_outputs @@ -701,13 +712,14 @@ def decode_interactively(estimator): scores = np.split(result["scores"], FLAGS.decode_beam_size, axis=0) for k, beam in enumerate(beams): tf.logging.info("BEAM %d:" % k) + beam_string = targets_vocab.decode(_save_until_eos(beam.flatten())) if scores is not None: - tf.logging.info("%s\tScore:%f" % - (targets_vocab.decode(_save_until_eos(beam.flatten())), scores[k])) + tf.logging.info("%s\tScore:%f" % (beam_string, scores[k])) else: - tf.logging.info(targets_vocab.decode(_save_until_eos(beam.flatten()))) + tf.logging.info(beam_string) else: - tf.logging.info(targets_vocab.decode(_save_until_eos(result["outputs"].flatten()))) + tf.logging.info(targets_vocab.decode(_save_until_eos( + result["outputs"].flatten()))) def _decode_batch_input_fn(problem_id, num_decode_batches, sorted_inputs, @@ -717,13 +729,13 @@ def _decode_batch_input_fn(problem_id, num_decode_batches, sorted_inputs, # you'll see it in the first batch sorted_inputs.reverse() for b in range(num_decode_batches): - tf.logging.info("Deocding batch %d" % b) + tf.logging.info("Decoding batch %d" % b) batch_length = 0 batch_inputs = [] for inputs in sorted_inputs[b * FLAGS.decode_batch_size: - (b + 1) * FLAGS.decode_batch_size]: + (b + 1) * FLAGS.decode_batch_size]: input_ids = vocabulary.encode(inputs) - input_ids.append(EOS_TOKEN) + input_ids.append(text_encoder.EOS_TOKEN) batch_inputs.append(input_ids) if len(input_ids) > batch_length: batch_length = len(input_ids) @@ -816,7 +828,7 @@ def _interactive_input_fn(hparams): if input_type == "text": input_ids = vocabulary.encode(input_string) if has_input: - input_ids.append(EOS_TOKEN) + input_ids.append(text_encoder.EOS_TOKEN) x = [num_samples, decode_length, len(input_ids)] + input_ids assert len(x) < const_array_size x += [0] * (const_array_size - len(x)) @@ -1137,7 +1149,7 @@ def __init__(self, optimizer_name, lr, hparams): lr, momentum=hparams.optimizer_momentum_momentum) elif optimizer_name == "YellowFin": tf.logging.info("Init YellowFin Optimizer.") - self._opt = YellowFinOptimizer( + self._opt = yellowfin.YellowFinOptimizer( learning_rate=lr, momentum=hparams.optimizer_momentum_momentum) else: self._opt = tf.contrib.layers.OPTIMIZER_CLS_NAMES[optimizer_name](lr) diff --git a/tensor2tensor/utils/trainer_utils_test.py b/tensor2tensor/utils/trainer_utils_test.py index d621b6fbc..3ed86952b 100644 --- a/tensor2tensor/utils/trainer_utils_test.py +++ b/tensor2tensor/utils/trainer_utils_test.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
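`_decode_batch_input_fn` above encodes each input line, appends the EOS id, and pads every batch up to its longest sequence. A framework-free sketch of that batching scheme, with a made-up vocabulary and ids used purely for illustration:

```python
EOS_ID = 1   # made-up ids, for illustration only
PAD_ID = 0
VOCAB = {"hello": 2, "world": 3}

def encode(line):
  return [VOCAB.get(token, PAD_ID) for token in line.split()]

def make_batches(lines, batch_size):
  batches = []
  for start in range(0, len(lines), batch_size):
    batch = [encode(line) + [EOS_ID] for line in lines[start:start + batch_size]]
    longest = max(len(ids) for ids in batch)
    # Pad every example up to the longest sequence in its batch.
    batches.append([ids + [PAD_ID] * (longest - len(ids)) for ids in batch])
  return batches

print(make_batches(["hello world", "world"], batch_size=2))
```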
@@ -22,6 +22,7 @@ from tensor2tensor.data_generators import algorithmic from tensor2tensor.data_generators import generator_utils +from tensor2tensor.models import transformer from tensor2tensor.utils import registry from tensor2tensor.utils import trainer_utils as utils # pylint: disable=unused-import @@ -30,20 +31,38 @@ FLAGS = tf.flags.FLAGS +@registry.register_problem +class TinyAlgo(algorithmic.AlgorithmicIdentityBinary40): + + def generate_data(self, data_dir, _): + generator_utils.generate_files( + algorithmic.identity_generator(self.num_symbols, 40, 100000), + self.training_filepaths(data_dir, 1, shuffled=True), 100) + generator_utils.generate_files( + algorithmic.identity_generator(self.num_symbols, 400, 10000), + self.dev_filepaths(data_dir, 1, shuffled=True), 100) + + +@registry.register_hparams +def transformer_test(): + hparams = transformer.transformer_base() + hparams.batch_size = 10 + hparams.hidden_size = 10 + hparams.num_hidden_layers = 1 + hparams.num_heads = 2 + hparams.max_length = 16 + return hparams + + class TrainerUtilsTest(tf.test.TestCase): @classmethod def setUpClass(cls): # Generate a small test dataset - FLAGS.problems = "algorithmic_addition_binary40" + FLAGS.problems = "tiny_algo" TrainerUtilsTest.data_dir = tf.test.get_temp_dir() - gen = algorithmic.identity_generator(2, 10, 300) - train_filenames = generator_utils.train_data_filenames( - FLAGS.problems, TrainerUtilsTest.data_dir, 1) - dev_filenames = generator_utils.dev_data_filenames( - FLAGS.problems, TrainerUtilsTest.data_dir, 1) - generator_utils.generate_files(gen, train_filenames, 100) - generator_utils.generate_files(gen, dev_filenames, 100) + registry.problem(FLAGS.problems).generate_data(TrainerUtilsTest.data_dir, + None) def testModelsImported(self): models = registry.list_models() @@ -55,10 +74,7 @@ def testHParamsImported(self): def testSingleStep(self): model_name = "transformer" - FLAGS.hparams_set = "transformer_base" - # Shrink the test model down - FLAGS.hparams = ("batch_size=10,hidden_size=10,num_heads=2,max_length=16," - "num_hidden_layers=1") + FLAGS.hparams_set = "transformer_test" exp = utils.create_experiment( output_dir=tf.test.get_temp_dir(), data_dir=TrainerUtilsTest.data_dir, diff --git a/tensor2tensor/utils/yellowfin.py b/tensor2tensor/utils/yellowfin.py index b5cedf21b..6bbe31bf6 100644 --- a/tensor2tensor/utils/yellowfin.py +++ b/tensor2tensor/utils/yellowfin.py @@ -1,35 +1,27 @@ -# MIT License +# Copyright 2017 The Tensor2Tensor Authors. # -# Copyright (c) 2017 JianGoForIt +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: +# http://www.apache.org/licenses/LICENSE-2.0 # -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. 
-# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -"""YellowFin for TensorFlow.""" +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""YellowFin for TensorFlow. Thanks Jian Zhang: zjian [@] stanford [.] edu.""" + +from __future__ import absolute_import +from __future__ import division from __future__ import print_function +# Dependency imports + import numpy as np -from math import ceil, floor import tensorflow as tf -from tensorflow.python.training import momentum -from tensorflow.python.ops import variable_scope -from tensorflow.python.ops import variables -from tensorflow.python.ops import state_ops from tensorflow.python.framework import ops @@ -41,8 +33,8 @@ class YellowFinOptimizer(tf.train.Optimizer): """Optimizer that implements the YellowFin algorithm. - See [Zhang et. al., 2017](https://arxiv.org/abs/1706.03471) - ([pdf](https://arxiv.org/pdf/1706.03471.pdf)). + + See [Zhang et. al., 2017](https://arxiv.org/abs/1706.03471) for details. """ def __init__(self, @@ -87,7 +79,7 @@ def __init__(self, self._lr = learning_rate self._mu = momentum - # Set lr and mu tensor + # Set lr and mu tensor. self._lr_var = tf.Variable(learning_rate, dtype=tf.float32, name="YF_lr", @@ -97,13 +89,13 @@ def __init__(self, name="YF_mu", trainable=False) - # Tuning factor for learning rates step or decaying scheme + # Tuning factor for learning rates step or decaying scheme. self.lr_factor = tf.Variable(1.0, dtype=tf.float32, name="YF_lr_factor", trainable=False) - # Gradient Clipping Threshold + # Gradient Clipping Threshold. if clip_thresh is not None: self._clip_thresh_var = tf.Variable(clip_thresh, dtype=tf.float32, @@ -112,63 +104,63 @@ def __init__(self, else: self._clip_thresh_var = None - # Set initial lr and mu for momentum + # Set initial lr and mu for momentum. self._lr_m = self._lr_var * self.lr_factor self._mu_m = self._mu_var + delta_mu - # Init momentum optimizer - self._momentum_optimizer = \ - tf.train.MomentumOptimizer(self._lr_m, self._mu_m) + # Init momentum optimizer. + self._momentum_optimizer = tf.train.MomentumOptimizer( + self._lr_m, self._mu_m) - # Moving average for statistics + # Moving average for statistics. self._beta = beta self._moving_averager = None - # Step counting + # Step counting. self._step = tf.Variable(0, dtype=tf.int32, name="YF_step", trainable=False) - # YF_step + 1 op + # YF_step + 1 op. self._increment_step_op = None - # For conditional tuning + # For conditional tuning. self._do_tune = tf.greater(self._step, tf.constant(0)) - # Moving-averages + # Moving-averages. self._zero_debias = zero_debias - # For curvature range + # For curvature range. self.curvature_window_width = curvature_window_width self._curv_win = None - # Gradients and Variables + # Gradients and Variables. 
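For orientation, this is roughly how the optimizer defined here is driven once the module is importable as `tensor2tensor.utils.yellowfin`; the toy quadratic loss and the starting values are illustrative only, and YellowFin then tunes the learning rate and momentum itself:

```python
import tensorflow as tf
from tensor2tensor.utils import yellowfin

w = tf.Variable([3.0, -2.0])
loss = tf.reduce_sum(tf.square(w))  # toy quadratic objective
opt = yellowfin.YellowFinOptimizer(learning_rate=1.0, momentum=0.0)
train_op = opt.apply_gradients(opt.compute_gradients(loss))

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for _ in range(5):
    sess.run(train_op)
```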
self._grad = None self._vars = None - # Get per var g**2, norm**2 and mean(norm**2) + # Get per var g**2, norm**2 and mean(norm**2). self._grad_squared = None self._grad_norm_squared = None self._grad_norm_squared_avg = None - # Mean(grad) and Mean(grad**2) to compute Variance + # Mean(grad) and Mean(grad**2) to compute Variance. self._grad_avg = None self._grad_avg_squared = None - # Max and Min curvature variations + # Max and Min curvature variations. self._h_max_t = None self._h_min_t = None self._h_min = None self._h_max = None - # Gradient Expected Variance + # Gradient Expected Variance. self._grad_var = None - # Gradient Norm and Mean(Gradient Norm) + # Gradient Norm and Mean(Gradient Norm). self._grad_norm = None self._grad_norm_avg = None - # Distance to optimum and Mean(Distance to optimum) + # Distance to optimum and Mean(Distance to optimum). self._d_t = None self._dist_to_opt_avg = None @@ -177,31 +169,27 @@ def __init__(self, # and (zero_devias) moving-averages. self._moving_averager = None - def _curvature_range(self): - """Curvature range + """Curvature range. Returns: h_max_t, h_min_t ops """ - self._curv_win = \ - tf.Variable(np.zeros([self.curvature_window_width, ]), - dtype=tf.float32, - name="curv_win", - trainable=False) - - self._curv_win = \ - tf.scatter_update(self._curv_win, - self._step % self.curvature_window_width, - self._grad_norm_squared) + self._curv_win = tf.Variable(np.zeros([self.curvature_window_width,]), + dtype=tf.float32, + name="curv_win", + trainable=False) + + self._curv_win = tf.scatter_update(self._curv_win, + self._step % self.curvature_window_width, + self._grad_norm_squared) # Note here the iterations start from iteration 0 valid_window = tf.slice(self._curv_win, - tf.constant([0, ]), + tf.constant([0,]), tf.expand_dims( - tf.minimum( - tf.constant(self.curvature_window_width), - self._step + 1), - dim=0)) + tf.minimum( + tf.constant(self.curvature_window_width), + self._step + 1), axis=0)) self._h_min_t = tf.reduce_min(valid_window) self._h_max_t = tf.reduce_max(valid_window) @@ -212,24 +200,23 @@ def _curvature_range(self): self._h_min = tf.identity(self._moving_averager.average(self._h_min_t)) self._h_max = tf.identity(self._moving_averager.average(self._h_max_t)) curv_range_ops.append(avg_op) - return curv_range_ops # h_max_t, h_min_t - + return curv_range_ops # h_max_t, h_min_t def _grad_variance(self): - """Estimate of gradient Variance + """Estimate of gradient Variance. Returns: - C_t ops + C_t ops. """ grad_var_ops = [] tensor_to_avg = [] for t, g in zip(self._vars, self._grad): - if isinstance(g, ops.IndexedSlices): - tensor_to_avg.append( \ - tf.reshape(tf.unsorted_segment_sum(g.values, - g.indices, - g.dense_shape[0]), - shape=t.get_shape())) + if isinstance(g, tf.IndexedSlices): + tensor_to_avg.append( + tf.reshape(tf.unsorted_segment_sum(g.values, + g.indices, + g.dense_shape[0]), + shape=t.get_shape())) else: tensor_to_avg.append(g) avg_op = self._moving_averager.apply(tensor_to_avg) @@ -244,9 +231,8 @@ def _grad_variance(self): self._grad_var = self._grad_norm_squared_avg - self._grad_avg_squared return grad_var_ops # C_t - def _dist_to_opt(self): - """Distance to optimum + """Distance to optimum. 
Returns: D_t ops @@ -254,7 +240,7 @@ def _dist_to_opt(self): dist_to_opt_ops = [] # Running average of the norm of gradeint self._grad_norm = tf.sqrt(self._grad_norm_squared) - avg_op = self._moving_averager.apply([self._grad_norm, ]) + avg_op = self._moving_averager.apply([self._grad_norm,]) dist_to_opt_ops.append(avg_op) with tf.control_dependencies([avg_op]): self._grad_norm_avg = self._moving_averager.average(self._grad_norm) @@ -265,21 +251,19 @@ def _dist_to_opt(self): avg_op = self._moving_averager.apply([self._d_t]) dist_to_opt_ops.append(avg_op) with tf.control_dependencies([avg_op]): - self._dist_to_opt_avg = \ - tf.identity(self._moving_averager.average(self._d_t)) + self._dist_to_opt_avg = tf.identity( + self._moving_averager.average(self._d_t)) return dist_to_opt_ops # D_t - def _prepare_variables(self): - """Prepare Variables for YellowFin + """Prepare Variables for YellowFin. Returns: Grad**2, Norm, Norm**2, Mean(Norm**2) ops """ - self._moving_averager = \ - tf.train.ExponentialMovingAverage(decay=self._beta, - zero_debias=self._zero_debias) - assert self._grad != None and len(self._grad) > 0 + self._moving_averager = tf.train.ExponentialMovingAverage( + decay=self._beta, zero_debias=self._zero_debias) + assert self._grad # List for the returned Operations prepare_variables_op = [] @@ -293,39 +277,37 @@ def _prepare_variables(self): with ops.colocate_with(v): self._grad_squared.append(tf.square(g)) - # Norm squared - self._grad_norm_squared = [tf.reduce_sum(g_sq) \ - for g_sq in self._grad_squared] + # Norm squared. + self._grad_norm_squared = [tf.reduce_sum(g_sq) + for g_sq in self._grad_squared] # The following running average on squared norm of gradient # is shared by grad_var and dist_to_opt avg_op = self._moving_averager.apply(self._grad_norm_squared) with tf.control_dependencies([avg_op]): - self._grad_norm_squared_avg = \ - [self._moving_averager.average(val) for val in self._grad_norm_squared] + self._grad_norm_squared_avg = [self._moving_averager.average(val) + for val in self._grad_norm_squared] self._grad_norm_squared = tf.add_n(self._grad_norm_squared) self._grad_norm_squared_avg = tf.add_n(self._grad_norm_squared_avg) prepare_variables_op.append(avg_op) return tf.group(*prepare_variables_op) - def _get_lr_tensor(self): - """Get lr minimzing the surrogate + """Get lr minimzing the surrogate. Returns: - lr_t + The lr_t. """ - lr = (1.0 - tf.sqrt(self._mu) )**2 / self._h_min + lr = (1.0 - tf.sqrt(self._mu))**2 / self._h_min return lr - def _get_mu_tensor(self): - """Get the min mu which minimize the surrogate + """Get the min mu which minimize the surrogate. Returns: - mu_t + The mu_t. 
""" const_fact = self._dist_to_opt_avg**2 * self._h_min**2 / 2 / self._grad_var coef = tf.Variable([-1.0, 3.0, 0.0, 1.0], @@ -340,28 +322,23 @@ def _get_mu_tensor(self): stateful=False) # Filter out the correct root - root_idx = \ - tf.logical_and( + root_idx = tf.logical_and( tf.logical_and( - tf.greater(tf.real(roots), tf.constant(0.0)), - tf.less(tf.real(roots), tf.constant(1.0))), + tf.greater(tf.real(roots), tf.constant(0.0)), + tf.less(tf.real(roots), tf.constant(1.0))), tf.less(tf.abs(tf.imag(roots)), 1e-5)) # In case there are two duplicated roots satisfying the above condition root = tf.reshape(tf.gather(tf.gather(roots, tf.where(root_idx)), - tf.constant(0)), + tf.constant(0)), shape=[]) - # Never Evaluated - #tf.assert_equal(tf.size(root), tf.constant(1)) - dr = self._h_max / self._h_min mu = tf.maximum(tf.real(root)**2, ((tf.sqrt(dr) - 1)/(tf.sqrt(dr) + 1))**2) return mu - def _yellowfin(self): - """YellowFin auto-tuning optimizer based on momentum SGD + """YellowFin auto-tuning optimizer based on momentum SGD. Returns: YF ops @@ -371,16 +348,16 @@ def _yellowfin(self): Single-Step, Auto-Tuning) """ - # List for the returned Operations + # List for the returned Operations. yellowfin_ops = [] - # Curvature range ops + # Curvature range ops. curv_range_ops = self._curvature_range() yellowfin_ops += curv_range_ops - # Estimate of gradient Variance ops + # Estimate of gradient Variance ops. grad_var_ops = self._grad_variance() yellowfin_ops += grad_var_ops - # Distance to optimum ops + # Distance to optimum ops. dist_to_opt_ops = self._dist_to_opt() yellowfin_ops += dist_to_opt_ops @@ -388,15 +365,14 @@ def _yellowfin(self): # squared distance from the optimum of a local quadratic # approximation after a single step while keeping all directions in the # robust region. - self._mu = \ - tf.identity(tf.cond(self._do_tune, lambda: self._get_mu_tensor(), - lambda: self._mu_var)) + self._mu = tf.identity(tf.cond(self._do_tune, self._get_mu_tensor, + lambda: self._mu_var)) with tf.control_dependencies([self._mu]): - self._lr = \ - tf.identity(tf.cond(self._do_tune, lambda: self._get_lr_tensor(), - lambda: self._lr_var)) + self._lr = tf.identity(tf.cond(self._do_tune, + self._get_lr_tensor, + lambda: self._lr_var)) - # Tune learning rate and momentum + # Tune learning rate and momentum. with tf.control_dependencies([self._mu, self._lr]): self._mu = self._beta * self._mu_var + (1 - self._beta) * self._mu self._lr = self._beta * self._lr_var + (1 - self._beta) * self._lr @@ -406,9 +382,8 @@ def _yellowfin(self): yellowfin_ops = tf.group(*yellowfin_ops) return yellowfin_ops - def apply_gradients(self, grads_and_vars, global_step=None, name=None): - """Applying gradients aand tune hyperparams with YellowFin + """Applying gradients aand tune hyperparams with YellowFin. Args: grads_and_vars: List of (gradient, variable) pairs as returned by @@ -429,22 +404,20 @@ def apply_gradients(self, grads_and_vars, global_step=None, name=None): self._grad, self._vars = zip(*[(g, t) for g, t in grads_and_vars if g is not None]) - # Var Update with Momentum + # Var update with Momentum. with tf.variable_scope("apply_updates"): # Gradient Clipping? 
if self._clip_thresh_var is not None: - self._grads_clip, self._grads_norm = \ - tf.clip_by_global_norm(self._grad, self._clip_thresh_var) + self._grads_clip, self._grads_norm = tf.clip_by_global_norm( + self._grad, self._clip_thresh_var) - apply_grad_op = \ - self._momentum_optimizer.apply_gradients( \ + apply_grad_op = self._momentum_optimizer.apply_gradients( zip(self._grads_clip, self._vars), global_step=global_step) else: - apply_grad_op = \ - self._momentum_optimizer.apply_gradients( \ + apply_grad_op = self._momentum_optimizer.apply_gradients( zip(self._grad, self._vars), global_step=global_step) - # Begin lr and mu tuning + # Begin lr and mu tuning. with tf.variable_scope("prepare_yellowFin_variables"): prepare_variables_op = self._prepare_variables() @@ -452,22 +425,14 @@ def apply_gradients(self, grads_and_vars, global_step=None, name=None): with tf.control_dependencies([prepare_variables_op]): yellowfin_op = self._yellowfin() - # Update YellowFin step variable + # Update YellowFin step variable. with tf.control_dependencies([yellowfin_op]): - self._increment_step_op = state_ops.assign_add(self._step, 1).op - - # # Global_step variable Update. Commented because the update is made by self._momentum_optimizer - # if global_step is not None: - # with tf.control_dependencies([yellowfin_op]): - # with ops.colocate_with(global_step): - # global_step_op = state_ops.assign_add(global_step, 1).op + self._increment_step_op = tf.assign_add(self._step, 1).op return tf.group(apply_grad_op, prepare_variables_op, yellowfin_op, self._increment_step_op) - # global_step_op) - def compute_gradients(self, loss, @@ -478,7 +443,7 @@ def compute_gradients(self, colocate_gradients_with_ops=False, name=None, grad_loss=None): - """Compute gradients through momentum optimizer + """Compute gradients through momentum optimizer. Args: loss: A Tensor containing the value to minimize. @@ -501,14 +466,13 @@ def compute_gradients(self, A list of (gradient, variable) pairs. Variable is always present, but gradient can be None. """ - return self._momentum_optimizer.compute_gradients( \ - loss, - var_list=var_list, - gate_gradients=gate_gradients, - aggregation_method=aggregation_method, - colocate_gradients_with_ops=colocate_gradients_with_ops, - grad_loss=grad_loss) - + return self._momentum_optimizer.compute_gradients( + loss, + var_list=var_list, + gate_gradients=gate_gradients, + aggregation_method=aggregation_method, + colocate_gradients_with_ops=colocate_gradients_with_ops, + grad_loss=grad_loss) def minimize(self, loss, @@ -519,7 +483,8 @@ def minimize(self, colocate_gradients_with_ops=False, name=None, grad_loss=None): - """Adapted from Tensorflow Optimizer base class member function: + """Adapted from Tensorflow Optimizer base class member function. + Add operations to minimize `loss` by updating `var_list`. This method simply combines calls `compute_gradients()` and `apply_gradients()`. If you want to process the gradient before applying @@ -545,9 +510,11 @@ def minimize(self, Returns: An Operation that updates the variables in var_list. If global_step was not None, that operation also increments global_step. + + Raises: + ValueError: if no gradients are provided for any variable. 
""" - grads_and_vars = \ - self._optimizer.compute_gradients( \ + grads_and_vars = self._optimizer.compute_gradients( loss, var_list=var_list, gate_gradients=gate_gradients, diff --git a/tensor2tensor/utils/yellowfin_test.py b/tensor2tensor/utils/yellowfin_test.py index c4a318990..c4727175b 100644 --- a/tensor2tensor/utils/yellowfin_test.py +++ b/tensor2tensor/utils/yellowfin_test.py @@ -1,44 +1,41 @@ -# MIT License +# Copyright 2017 The Tensor2Tensor Authors. # -# Copyright (c) 2017 JianGoForIt +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: +# http://www.apache.org/licenses/LICENSE-2.0 # -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""YellowFin Test Module for TensorFlow.""" -#import os -# os.environ['TF_CPP_MIN_LOG_LEVEL']='2' -import tensorflow as tf +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Dependency imports + import numpy as np + from tensor2tensor.utils.yellowfin import YellowFinOptimizer -from tensorflow.python.ops import variables -import time + +import tensorflow as tf n_dim = 1000000 -n_iter = 50 +n_iter = 0 -class TrainerUtilsTest(tf.test.TestCase): - def tuneEverything(self, x0squared, C, T, gmin, gmax): +class YellowFinTest(tf.test.TestCase): + + def tuneEverything(self, x0squared, c, t, gmin, gmax): # First tune based on dynamic range - if C == 0: + if c == 0: dr = gmax / gmin mustar = ((np.sqrt(dr) - 1) / (np.sqrt(dr) + 1))**2 alpha_star = (1 + np.sqrt(mustar))**2/gmax @@ -46,7 +43,7 @@ def tuneEverything(self, x0squared, C, T, gmin, gmax): return alpha_star, mustar dist_to_opt = x0squared - grad_var = C + grad_var = c max_curv = gmax min_curv = gmin const_fact = dist_to_opt * min_curv**2 / 2 / grad_var @@ -63,31 +60,29 @@ def tuneEverything(self, x0squared, C, T, gmin, gmax): mu = max(((np.sqrt(dr) - 1) / (np.sqrt(dr) + 1))**2, root**2) lr_min = (1 - np.sqrt(mu))**2 / min_curv - lr_max = (1 + np.sqrt(mu))**2 / max_curv alpha_star = lr_min mustar = mu return alpha_star, mustar - def testMeasurement(self): opt = YellowFinOptimizer(zero_debias=False) - w = tf.Variable(np.ones([n_dim, ]), + w = tf.Variable(np.ones([n_dim,]), dtype=tf.float32, name="w", trainable=True) - b = tf.Variable(np.ones([1, ], dtype=np.float32), + b = tf.Variable(np.ones([1,], dtype=np.float32), dtype=tf.float32, name="b", trainable=True) x = tf.constant(np.ones([n_dim,], dtype=np.float32), dtype=tf.float32) - loss = tf.multiply(w, x) + b + _ = tf.multiply(w, x) + b # loss tvars = tf.trainable_variables() - w_grad_val = tf.placeholder(tf.float32, shape=(n_dim, )) - b_grad_val = tf.placeholder(tf.float32, shape=(1, )) + w_grad_val = tf.placeholder(tf.float32, shape=(n_dim,)) + b_grad_val = tf.placeholder(tf.float32, shape=(1,)) apply_op = opt.apply_gradients(zip([w_grad_val, b_grad_val], tvars)) init_op = tf.global_variables_initializer() @@ -100,8 +95,8 @@ def testMeasurement(self): g_avg = 0.0 target_dist = 0.0 for i in range(n_iter): - feed_dict = {w_grad_val: (i + 1) * np.ones([n_dim, ], dtype=np.float32), - b_grad_val: (i + 1) * np.ones([1, ], dtype=np.float32)} + feed_dict = {w_grad_val: (i + 1) * np.ones([n_dim,], dtype=np.float32), + b_grad_val: (i + 1) * np.ones([1,], dtype=np.float32)} res = sess.run([opt._curv_win, opt._h_max, opt._h_min, @@ -109,49 +104,44 @@ def testMeasurement(self): opt._dist_to_opt_avg, apply_op], feed_dict=feed_dict) - g_norm_squared_avg = 0.999 * g_norm_squared_avg \ - + 0.001 * np.sum(((i + 1) * np.ones([n_dim + 1, ]))**2) - g_norm_avg = 0.999 * g_norm_avg \ - + 0.001 * np.linalg.norm((i + 1)*np.ones([n_dim + 1, ])) + g_norm_squared_avg = ( + 0.999 * g_norm_squared_avg + + 0.001 * np.sum(((i + 1) * np.ones([n_dim + 1,]))**2)) + g_norm_avg = (0.999 * g_norm_avg + + 0.001 * np.linalg.norm((i + 1)*np.ones([n_dim + 1,]))) g_avg = 0.999 * g_avg + 0.001 * (i + 1) target_h_max = 0.999 * target_h_max + 0.001 * (i + 1)**2*(n_dim + 1) - target_h_min = 0.999 * target_h_min + \ - 0.001 * max(1, i + 2 - 20)**2 * (n_dim + 1) + target_h_min = (0.999 * target_h_min + + 0.001 * max(1, i + 2 - 20)**2 * (n_dim + 1)) target_var = g_norm_squared_avg - g_avg**2 * (n_dim + 1) - target_dist = 0.999 * target_dist + \ - 0.001 * g_norm_avg / 
g_norm_squared_avg + target_dist = (0.999 * target_dist + + 0.001 * g_norm_avg / g_norm_squared_avg) - # print "iter ", i, " h max ", res[1], target_h_max, " h min ", res[2], target_h_min, \ - # " var ", res[3], target_var, " dist ", res[4], target_dist assert np.abs(target_h_max - res[1]) < np.abs(target_h_max) * 1e-3 assert np.abs(target_h_min - res[2]) < np.abs(target_h_min) * 1e-3 assert np.abs(target_var - res[3]) < np.abs(res[3]) * 1e-3 assert np.abs(target_dist - res[4]) < np.abs(res[4]) * 1e-3 - print "[Test-INFO] Sync measurement test passed!" - def testLrMu(self): opt = YellowFinOptimizer(learning_rate=0.5, momentum=0.5, zero_debias=False) - w = tf.Variable(np.ones([n_dim, ]), + w = tf.Variable(np.ones([n_dim,]), dtype=tf.float32, name="w", trainable=True) - b = tf.Variable(np.ones([1, ], - dtype=np.float32), + b = tf.Variable(np.ones([1,], + dtype=np.float32), dtype=tf.float32, name="b", trainable=True) - x = tf.constant(np.ones([n_dim, ], - dtype=np.float32), - dtype=tf.float32) - loss = tf.multiply(w, x) + b + x = tf.constant(np.ones([n_dim,], dtype=np.float32), dtype=tf.float32) + _ = tf.multiply(w, x) + b # loss tvars = tf.trainable_variables() - w_grad_val = tf.Variable(np.zeros([n_dim, ]), - dtype=tf.float32, - trainable=False) - b_grad_val = tf.Variable(np.zeros([1, ]), + w_grad_val = tf.Variable(np.zeros([n_dim,]), + dtype=tf.float32, + trainable=False) + b_grad_val = tf.Variable(np.zeros([1,]), dtype=tf.float32, trainable=False) apply_op = opt.apply_gradients(zip([w_grad_val, b_grad_val], tvars)) @@ -169,9 +159,9 @@ def testLrMu(self): target_mu = 0.5 for i in range(n_iter): - sess.run(tf.assign(w_grad_val, (i + 1) * np.ones([n_dim, ], + sess.run(tf.assign(w_grad_val, (i + 1) * np.ones([n_dim,], dtype=np.float32))) - sess.run(tf.assign(b_grad_val, (i + 1) * np.ones([1, ], + sess.run(tf.assign(b_grad_val, (i + 1) * np.ones([1,], dtype=np.float32))) res = sess.run([opt._curv_win, @@ -186,40 +176,37 @@ def testLrMu(self): res[5] = opt._lr_var.eval() res[6] = opt._mu_var.eval() - g_norm_squared_avg = 0.999 * g_norm_squared_avg \ - + 0.001 * np.sum(((i + 1) * np.ones([n_dim + 1, ]))**2) - g_norm_avg = 0.999 * g_norm_avg \ - + 0.001 * np.linalg.norm((i + 1)*np.ones([n_dim + 1, ])) + g_norm_squared_avg = ( + 0.999 * g_norm_squared_avg + + 0.001 * np.sum(((i + 1) * np.ones([n_dim + 1,]))**2)) + g_norm_avg = (0.999 * g_norm_avg + + 0.001 * np.linalg.norm((i + 1)*np.ones([n_dim + 1,]))) g_avg = 0.999 * g_avg + 0.001 * (i + 1) target_h_max = 0.999 * target_h_max + 0.001 * (i + 1)**2 * (n_dim + 1) - target_h_min = 0.999 * target_h_min + \ - 0.001 * max(1, i + 2 - 20)**2 * (n_dim + 1) + target_h_min = (0.999 * target_h_min + + 0.001 * max(1, i + 2 - 20)**2 * (n_dim + 1)) target_var = g_norm_squared_avg - g_avg**2 * (n_dim + 1) - target_dist = 0.999 * target_dist + \ - 0.001 * g_norm_avg / g_norm_squared_avg + target_dist = (0.999 * target_dist + + 0.001 * g_norm_avg / g_norm_squared_avg) if i > 0: lr, mu = self.tuneEverything(target_dist**2, - target_var, - 1, - target_h_min, - target_h_max) + target_var, + 1, + target_h_min, + target_h_max) target_lr = 0.999 * target_lr + 0.001 * lr target_mu = 0.999 * target_mu + 0.001 * mu - # print "iter ", i, " h max ", res[1], target_h_max, \ - # " h min ", res[2], target_h_min, " var ", res[3], target_var, \ - # " dist ", res[4], target_dist - # print "iter ", i, " lr ", res[5], target_lr, " mu ", res[6], target_mu - assert np.abs(target_h_max - res[1]) < np.abs(target_h_max) * 1e-3 assert np.abs(target_h_min - res[2]) < np.abs(target_h_min) 
* 1e-3 assert np.abs(target_var - res[3]) < np.abs(res[3]) * 1e-3 assert np.abs(target_dist - res[4]) < np.abs(res[4]) * 1e-3 - assert target_lr == 0.0 or np.abs(target_lr - res[5]) < np.abs(res[5]) * 1e-3 - assert target_mu == 0.0 or np.abs(target_mu - res[6]) < np.abs(res[6]) * 5e-3 - print "[Test-INFO] lr and mu computing test passed!" + assert (target_lr == 0.0 or + (np.abs(target_lr - res[5]) < np.abs(res[5]) * 1e-3)) + assert (target_mu == 0.0 or + (np.abs(target_mu - res[6]) < np.abs(res[6]) * 5e-3)) if __name__ == "__main__":
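The assertions above compare the optimizer's internal statistics against hand-maintained exponential moving averages with decay 0.999 and no bias correction, matching the `zero_debias=False` construction. The bookkeeping pattern the test reproduces, with made-up inputs:

```python
def ema_update(avg, value, decay=0.999):
  """One step of the running averages the test tracks (no zero-debiasing)."""
  return decay * avg + (1.0 - decay) * value

h_max_avg = 0.0
for step in range(5):
  fake_curvature = float((step + 1) ** 2)  # stand-in for the (i + 1)**2 terms above
  h_max_avg = ema_update(h_max_avg, fake_curvature)
print(h_max_avg)
```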