diff --git a/README.md b/README.md index 27bb47947..059fbe429 100644 --- a/README.md +++ b/README.md @@ -153,7 +153,7 @@ python -c "from tensor2tensor.models.transformer import Transformer" specification. * Support for multi-GPU machines and synchronous (1 master, many workers) and asynchronous (independent workers synchronizing through a parameter server) - distributed training. + [distributed training](https://github.com/tensorflow/tensor2tensor/tree/master/docs/distributed_training.md). * Easily swap amongst datasets and models by command-line flag with the data generation script `t2t-datagen` and the training script `t2t-trainer`. @@ -173,8 +173,10 @@ and many common sequence datasets are already available for generation and use. **Problems** define training-time hyperparameters for the dataset and task, mainly by setting input and output **modalities** (e.g. symbol, image, audio, -label) and vocabularies, if applicable. All problems are defined in -[`problem_hparams.py`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/problem_hparams.py). +label) and vocabularies, if applicable. All problems are either defined in +[`problem_hparams.py`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/problem_hparams.py) +or registered with `@registry.register_problem` (run `t2t-datagen` to see +the list of all available problems). **Modalities**, defined in [`modality.py`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/utils/modality.py), abstract away the input and output data types so that **models** may deal with @@ -211,7 +213,7 @@ inference. Users can easily switch between problems, models, and hyperparameter sets by using the `--model`, `--problems`, and `--hparams_set` flags. Specific hyperparameters can be overridden with the `--hparams` flag. `--schedule` and related flags control local and distributed training/evaluation -([distributed training documentation](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/docs/distributed_training.md)). +([distributed training documentation](https://github.com/tensorflow/tensor2tensor/tree/master/docs/distributed_training.md)). --- @@ -222,7 +224,7 @@ enables easily adding new ones and easily swapping amongst them by command-line flag. You can add your own components without editing the T2T codebase by specifying the `--t2t_usr_dir` flag in `t2t-trainer`. -You can currently do so for models, hyperparameter sets, and modalities. Please +You can do so for models, hyperparameter sets, modalities, and problems. Please do submit a pull request if your component might be useful to others. Here's an example with a new hyperparameter set: @@ -253,9 +255,18 @@ You'll see under the registered HParams your `transformer_my_very_own_hparams_set`, which you can directly use on the command line with the `--hparams_set` flag. +`t2t-datagen` also supports the `--t2t_usr_dir` flag for `Problem` +registrations. + ## Adding a dataset -See the [data generators +To add a new dataset, subclass +[`Problem`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/problem.py) +and register it with `@registry.register_problem`. See +[`WMTEnDeTokens8k`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/wmt.py) +for an example. + +Also see the [data generators README](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/README.md).
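For illustration, here is a rough sketch of the end-to-end workflow the README changes above describe: subclass `Problem`, register it with `@registry.register_problem`, and expose it to `t2t-datagen`/`t2t-trainer` through `--t2t_usr_dir`. The module path `my_usr_dir`, the class `MyLengthProblem`, and the toy generator are hypothetical names used only for this sketch; the `Problem` methods and the `generator_utils.generate_dataset_and_shuffle` call follow the patterns used elsewhere in this change.

```
# Hypothetical my_usr_dir/__init__.py, imported via --t2t_usr_dir=my_usr_dir.
import numpy as np

from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import problem
from tensor2tensor.utils import registry


def length_generator(nbr_cases):
  """Toy generator: inputs are runs of 2s, the target is the run length."""
  for _ in range(nbr_cases):
    length = np.random.randint(100) + 1
    # 0 (PAD) and 1 (EOS) are reserved ids, so the input symbols start at 2.
    yield {"inputs": [2] * length, "targets": [length]}


@registry.register_problem
class MyLengthProblem(problem.Problem):
  """Toy problem: predict the length of the input sequence."""

  def generate_data(self, data_dir, tmp_dir, num_shards=None):
    num_shards = num_shards or 10
    generator_utils.generate_dataset_and_shuffle(
        length_generator(100000),
        self.training_filepaths(data_dir, num_shards, shuffled=False),
        length_generator(1000),
        self.dev_filepaths(data_dir, 1, shuffled=False))

  def hparams(self, defaults, unused_model_hparams):
    p = defaults
    vocab_size = 128  # covers the reserved ids, the 2s, and lengths <= 100
    p.input_modality = {"inputs": (registry.Modalities.SYMBOL, vocab_size)}
    p.target_modality = (registry.Modalities.SYMBOL, vocab_size)
    p.input_space_id = problem.SpaceID.DIGIT_0
    p.target_space_id = problem.SpaceID.DIGIT_1
```

With this module on disk, `t2t-datagen --t2t_usr_dir=my_usr_dir --problem=my_length_problem` should pick up the registration (the registry derives the snake_case problem name from the class name).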
--- diff --git a/tensor2tensor/docs/distributed_training.md b/docs/distributed_training.md similarity index 100% rename from tensor2tensor/docs/distributed_training.md rename to docs/distributed_training.md diff --git a/docs/index.md b/docs/index.md new file mode 100644 index 000000000..a5eeba137 --- /dev/null +++ b/docs/index.md @@ -0,0 +1,23 @@ +# T2T: Tensor2Tensor Transformers + +Check us out on + +GitHub + + +. + +[![PyPI +version](https://badge.fury.io/py/tensor2tensor.svg)](https://badge.fury.io/py/tensor2tensor) +[![GitHub +Issues](https://img.shields.io/github/issues/tensorflow/tensor2tensor.svg)](https://github.com/tensorflow/tensor2tensor/issues) +[![Contributions +welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CONTRIBUTING.md) +[![Gitter](https://img.shields.io/gitter/room/nwjs/nw.js.svg)](https://gitter.im/tensor2tensor/Lobby) +[![License](https://img.shields.io/badge/License-Apache%202.0-brightgreen.svg)](https://opensource.org/licenses/Apache-2.0) + +See our +[README](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/README.md) +for documentation. + +More documentation and tutorials coming soon... diff --git a/setup.py b/setup.py index 00325cff2..d8fd19cf4 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setup( name='tensor2tensor', - version='1.0.14', + version='1.1.0', description='Tensor2Tensor', author='Google Inc.', author_email='no-reply@google.com', diff --git a/tensor2tensor/bin/t2t-datagen b/tensor2tensor/bin/t2t-datagen old mode 100755 new mode 100644 index cbf0a6164..b0fd816a2 --- a/tensor2tensor/bin/t2t-datagen +++ b/tensor2tensor/bin/t2t-datagen @@ -35,7 +35,6 @@ import tempfile import numpy as np -from tensor2tensor.data_generators import algorithmic from tensor2tensor.data_generators import algorithmic_math from tensor2tensor.data_generators import all_problems # pylint: disable=unused-import from tensor2tensor.data_generators import audio @@ -60,52 +59,22 @@ flags.DEFINE_string("tmp_dir", "/tmp/t2t_datagen", "Temporary storage directory.") flags.DEFINE_string("problem", "", "The name of the problem to generate data for.") +flags.DEFINE_string("exclude_problems", "", + "Comma-separated list of problems to exclude.") flags.DEFINE_integer("num_shards", 10, "How many shards to use.") flags.DEFINE_integer("max_cases", 0, "Maximum number of cases to generate (unbounded if 0).") flags.DEFINE_integer("random_seed", 429459, "Random seed to use.") - flags.DEFINE_string("t2t_usr_dir", "", "Path to a Python module that will be imported. The " "__init__.py file should include the necessary imports. " "The imported files should contain registrations, " - "e.g. @registry.register_model calls, that will then be " - "available to the t2t-datagen.") + "e.g. @registry.register_problem calls, that will then be " + "available to t2t-datagen.") # Mapping from problems that we can generate data for to their generators.
# pylint: disable=g-long-lambda _SUPPORTED_PROBLEM_GENERATORS = { - "algorithmic_shift_decimal40": ( - lambda: algorithmic.shift_generator(20, 10, 40, 100000), - lambda: algorithmic.shift_generator(20, 10, 80, 10000)), - "algorithmic_reverse_binary40": ( - lambda: algorithmic.reverse_generator(2, 40, 100000), - lambda: algorithmic.reverse_generator(2, 400, 10000)), - "algorithmic_reverse_decimal40": ( - lambda: algorithmic.reverse_generator(10, 40, 100000), - lambda: algorithmic.reverse_generator(10, 400, 10000)), - "algorithmic_addition_binary40": ( - lambda: algorithmic.addition_generator(2, 40, 100000), - lambda: algorithmic.addition_generator(2, 400, 10000)), - "algorithmic_addition_decimal40": ( - lambda: algorithmic.addition_generator(10, 40, 100000), - lambda: algorithmic.addition_generator(10, 400, 10000)), - "algorithmic_multiplication_binary40": ( - lambda: algorithmic.multiplication_generator(2, 40, 100000), - lambda: algorithmic.multiplication_generator(2, 400, 10000)), - "algorithmic_multiplication_decimal40": ( - lambda: algorithmic.multiplication_generator(10, 40, 100000), - lambda: algorithmic.multiplication_generator(10, 400, 10000)), - "algorithmic_reverse_nlplike_decimal8K": ( - lambda: algorithmic.reverse_generator_nlplike(8000, 70, 100000, - 10, 1.300), - lambda: algorithmic.reverse_generator_nlplike(8000, 70, 10000, - 10, 1.300)), - "algorithmic_reverse_nlplike_decimal32K": ( - lambda: algorithmic.reverse_generator_nlplike(32000, 70, 100000, - 10, 1.050), - lambda: algorithmic.reverse_generator_nlplike(32000, 70, 10000, - 10, 1.050)), "algorithmic_algebra_inverse": ( lambda: algorithmic_math.algebra_inverse(26, 0, 2, 100000), lambda: algorithmic_math.algebra_inverse(26, 3, 3, 10000)), @@ -125,29 +94,9 @@ _SUPPORTED_PROBLEM_GENERATORS = { 2**14, 2**9), lambda: wsj_parsing.parsing_token_generator(FLAGS.tmp_dir, False, 2**14, 2**9)), - "wmt_enfr_characters": ( - lambda: wmt.enfr_character_generator(FLAGS.tmp_dir, True), - lambda: wmt.enfr_character_generator(FLAGS.tmp_dir, False)), - "wmt_enfr_tokens_8k": ( - lambda: wmt.enfr_wordpiece_token_generator(FLAGS.tmp_dir, True, 2**13), - lambda: wmt.enfr_wordpiece_token_generator(FLAGS.tmp_dir, False, 2**13) - ), - "wmt_enfr_tokens_32k": ( - lambda: wmt.enfr_wordpiece_token_generator(FLAGS.tmp_dir, True, 2**15), - lambda: wmt.enfr_wordpiece_token_generator(FLAGS.tmp_dir, False, 2**15) - ), - "wmt_ende_characters": ( - lambda: wmt.ende_character_generator(FLAGS.tmp_dir, True), - lambda: wmt.ende_character_generator(FLAGS.tmp_dir, False)), "wmt_ende_bpe32k": ( lambda: wmt.ende_bpe_token_generator(FLAGS.tmp_dir, True), lambda: wmt.ende_bpe_token_generator(FLAGS.tmp_dir, False)), - "wmt_zhen_tokens_32k": ( - lambda: wmt.zhen_wordpiece_token_generator(FLAGS.tmp_dir, True, - 2**15, 2**15), - lambda: wmt.zhen_wordpiece_token_generator(FLAGS.tmp_dir, False, - 2**15, 2**15) - ), "lm1b_32k": ( lambda: lm1b.generator(FLAGS.tmp_dir, True), lambda: lm1b.generator(FLAGS.tmp_dir, False) @@ -286,6 +235,9 @@ def main(_): # Calculate the list of problems to generate. 
problems = sorted( list(_SUPPORTED_PROBLEM_GENERATORS) + registry.list_problems()) + for exclude in FLAGS.exclude_problems.split(","): + if exclude: + problems = [p for p in problems if exclude not in p] if FLAGS.problem and FLAGS.problem[-1] == "*": problems = [p for p in problems if p.startswith(FLAGS.problem[:-1])] elif FLAGS.problem: diff --git a/tensor2tensor/bin/t2t-trainer b/tensor2tensor/bin/t2t-trainer old mode 100755 new mode 100644 index 6b3f4de71..8a801e70e --- a/tensor2tensor/bin/t2t-trainer +++ b/tensor2tensor/bin/t2t-trainer @@ -29,14 +29,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import importlib -import os -import sys - # Dependency imports from tensor2tensor.utils import trainer_utils as utils from tensor2tensor.utils import usr_dir + import tensorflow as tf flags = tf.flags @@ -49,6 +46,7 @@ flags.DEFINE_string("t2t_usr_dir", "", "e.g. @registry.register_model calls, that will then be " "available to the t2t-trainer.") + def main(_): tf.logging.set_verbosity(tf.logging.INFO) usr_dir.import_usr_dir(FLAGS.t2t_usr_dir) diff --git a/tensor2tensor/data_generators/README.md b/tensor2tensor/data_generators/README.md index f8495c38f..310bc39df 100644 --- a/tensor2tensor/data_generators/README.md +++ b/tensor2tensor/data_generators/README.md @@ -1,7 +1,7 @@ -# Data generators for T2T models. +# T2T Problems. -This directory contains data generators for a number of problems. We use a -naming scheme for the problems, they have names of the form +This directory contains `Problem` specifications for a number of problems. We +use a naming scheme for the problems: they have names of the form `[task-family]_[task]_[specifics]`. Data for all currently supported problems can be generated by calling the main generator binary (`t2t-datagen`). For example: @@ -20,53 +20,51 @@ All tasks produce TFRecord files of `tensorflow.Example` protocol buffers. ## Adding a new problem -1. Implement and register a Python generator for the dataset -1. Add a problem specification to `problem_hparams.py` specifying input and - output modalities +To add a new problem, subclass +[`Problem`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/problem.py) +and register it with `@registry.register_problem`. See +[`WMTEnDeTokens8k`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/wmt.py) +for an example. -To add a new problem, you first need to create python generators for training -and development data for the problem. The python generators should yield -dictionaries with string keys and values being lists of {int, float, str}. -Here is a very simple generator for a data-set where inputs are lists of 1s with -length upto 100 and targets are lists of length 1 with an integer denoting the -length of the input list. +`Problem`s support data generation, training, and decoding. + +Data generation is handled by `Problem.generate_data`, which should produce 2 +datasets, training and dev, named according to +`Problem.training_filepaths` and `Problem.dev_filepaths`. +`Problem.generate_data` should also produce any other files that may be required +for training/decoding, e.g. a vocabulary file. + +A particularly easy way to implement `Problem.generate_data` for your dataset is +to create 2 Python generators, one for the training data and another for the +dev data, and pass them to `generator_utils.generate_dataset_and_shuffle`.
See +[`WMTEnDeTokens8k.generate_data`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/wmt.py) +for an example of usage. + +The generators should yield dictionaries with string keys and values being lists +of {int, float, str}. Here is a very simple generator for a data-set where +inputs are lists of 2s with length up to 100 and targets are lists of length 1 +with an integer denoting the length of the input list. ``` def length_generator(nbr_cases): for _ in xrange(nbr_cases): length = np.random.randint(100) + 1 - yield {"inputs": [1] * length, "targets": [length]} + yield {"inputs": [2] * length, "targets": [length]} ``` -Note that our data reader uses 0 for padding, so it is a good idea to never -generate 0s, except if all your examples have the same size (in which case -they'll never be padded anyway) or if you're doing padding on your own (in which -case please use 0s for padding). When adding the python generator function, -please also add unit tests to check if the code runs. +Note that our data reader uses 0 for padding and other parts of the code assume +end-of-string (EOS) is 1, so it is a good idea to never generate 0s or 1s, +except if all your examples have the same size (in which case they'll never be +padded anyway) or if you're doing padding on your own (in which case please use +0s for padding). When adding the python generator function, please also add unit +tests to check if the code runs. The generator can do arbitrary setup before beginning to yield examples - for example, downloading data, generating vocabulary files, etc. Some examples: -* [Algorithmic generators](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/algorithmic.py) +* [Algorithmic problems](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/algorithmic.py) and their [unit tests](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/algorithmic_test.py) -* [WMT generators](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/wmt.py) +* [WMT problems](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/wmt.py) and their [unit tests](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/wmt_test.py) - -When your python generator is ready and tested, add it to the -`_SUPPORTED_PROBLEM_GENERATORS` dictionary in the -[data -generator](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/bin/t2t-datagen). -The keys are problem names, and the values are pairs of (training-set-generator -function, dev-set-generator function). For the generator above, one could add -the following lines: - -``` - "algorithmic_length_upto100": - (lambda: algorithmic.length_generator(10000), - lambda: algorithmic.length_generator(1000)), -``` - -Note the lambdas above: we don't want to call the generators too early.
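The reserved-id convention above can also be handled in one place instead of in every generator. Below is a small sketch, assuming a raw generator that yields symbols starting at 0, of the wrapping pattern that the new `AlgorithmicProblem.generate_data` in `algorithmic.py` (next in this diff) applies with `text_encoder.NUM_RESERVED_TOKENS` and `text_encoder.EOS_ID`; the helper name is hypothetical.

```
from tensor2tensor.data_generators import text_encoder


def shift_and_append_eos(raw_generator):
  """Shift symbols past the reserved ids (PAD=0, EOS=1) and append EOS."""
  for case in raw_generator:
    yield {
        feature: ([i + text_encoder.NUM_RESERVED_TOKENS for i in values] +
                  [text_encoder.EOS_ID])
        for feature, values in case.items()
    }
```

Keeping this shift in the `Problem` rather than in each generator means the task-specific generators stay free of padding/EOS bookkeeping.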
- diff --git a/tensor2tensor/data_generators/algorithmic.py b/tensor2tensor/data_generators/algorithmic.py index 7e522bfa0..2169e1910 100644 --- a/tensor2tensor/data_generators/algorithmic.py +++ b/tensor2tensor/data_generators/algorithmic.py @@ -25,48 +25,86 @@ from tensor2tensor.data_generators import generator_utils as utils from tensor2tensor.data_generators import problem +from tensor2tensor.data_generators import text_encoder from tensor2tensor.utils import registry -@registry.register_problem -class AlgorithmicIdentityBinary40(problem.Problem): - """Problem spec for algorithmic binary identity task.""" +class AlgorithmicProblem(problem.Problem): + """Base class for algorithmic problems.""" @property def num_symbols(self): - return 2 + raise NotImplementedError() + + @property + def train_generator(self): + """Generator; takes 3 args: nbr_symbols, max_length, nbr_cases.""" + raise NotImplementedError() + + @property + def dev_generator(self): + return self.train_generator + + @property + def train_length(self): + return 40 + + @property + def dev_length(self): + return 400 + + @property + def train_size(self): + return 100000 + + @property + def dev_size(self): + return 10000 + + @property + def num_shards(self): + return 10 + + def generate_data(self, data_dir, _, num_shards=None): + if num_shards is None: + num_shards = self.num_shards + + def generator_eos(generator): + """Shift by NUM_RESERVED_IDS and append EOS token.""" + for case in generator: + new_case = {} + for feature in case: + new_case[feature] = [i + text_encoder.NUM_RESERVED_TOKENS + for i in case[feature]] + [text_encoder.EOS_ID] + yield new_case + + train_generator_eos = lambda: generator_eos( # pylint: disable=g-long-lambda + self.train_generator(self.num_symbols, + self.train_length, self.train_size)) + dev_generator_eos = lambda: generator_eos( # pylint: disable=g-long-lambda + self.dev_generator(self.num_symbols, self.dev_length, self.dev_size)) - def generate_data(self, data_dir, _, num_shards=100): utils.generate_dataset_and_shuffle( - identity_generator(self.num_symbols, 40, 100000), + train_generator_eos(), self.training_filepaths(data_dir, num_shards, shuffled=True), - identity_generator(self.num_symbols, 400, 10000), + dev_generator_eos(), self.dev_filepaths(data_dir, 1, shuffled=True), shuffle=False) def hparams(self, defaults, unused_model_hparams): p = defaults - vocab_size = self.num_symbols + self._encoders["inputs"].num_reserved_ids + vocab_size = self.num_symbols + text_encoder.NUM_RESERVED_TOKENS p.input_modality = {"inputs": (registry.Modalities.SYMBOL, vocab_size)} p.target_modality = (registry.Modalities.SYMBOL, vocab_size) p.input_space_id = problem.SpaceID.DIGIT_0 p.target_space_id = problem.SpaceID.DIGIT_1 -@registry.register_problem -class AlgorithmicIdentityDecimal40(AlgorithmicIdentityBinary40): - """Problem spec for algorithmic decimal identity task.""" - - @property - def num_symbols(self): - return 10 - - def identity_generator(nbr_symbols, max_length, nbr_cases): """Generator for the identity (copy) task on sequences of symbols. The length of the sequence is drawn uniformly at random from [1, max_length] - and then symbols are drawn uniformly at random from [2, nbr_symbols + 2) until + and then symbols are drawn uniformly at random from [0, nbr_symbols) until nbr_cases sequences have been produced. 
Args: @@ -80,15 +118,37 @@ def identity_generator(nbr_symbols, max_length, nbr_cases): """ for _ in xrange(nbr_cases): l = np.random.randint(max_length) + 1 - inputs = [np.random.randint(nbr_symbols) + 2 for _ in xrange(l)] - yield {"inputs": inputs, "targets": inputs + [1]} # [1] for EOS + inputs = [np.random.randint(nbr_symbols) for _ in xrange(l)] + yield {"inputs": inputs, "targets": inputs} + + +@registry.register_problem +class AlgorithmicIdentityBinary40(AlgorithmicProblem): + """Problem spec for algorithmic binary identity task.""" + + @property + def num_symbols(self): + return 2 + + @property + def train_generator(self): + return identity_generator + + +@registry.register_problem +class AlgorithmicIdentityDecimal40(AlgorithmicIdentityBinary40): + """Problem spec for algorithmic decimal identity task.""" + + @property + def num_symbols(self): + return 10 def shift_generator(nbr_symbols, shift, max_length, nbr_cases): """Generator for the shift task on sequences of symbols. The length of the sequence is drawn uniformly at random from [1, max_length] - and then symbols are drawn uniformly at random from [2, nbr_symbols - shift] + and then symbols are drawn uniformly at random from [0, nbr_symbols - shift] until nbr_cases sequences have been produced (output[i] = input[i] + shift). Args: @@ -103,18 +163,35 @@ def shift_generator(nbr_symbols, shift, max_length, nbr_cases): """ for _ in xrange(nbr_cases): l = np.random.randint(max_length) + 1 - inputs = [np.random.randint(nbr_symbols - shift) + 2 for _ in xrange(l)] + inputs = [np.random.randint(nbr_symbols - shift) for _ in xrange(l)] yield { "inputs": inputs, - "targets": [i + shift for i in inputs] + [1] - } # [1] for EOS + "targets": [i + shift for i in inputs] + } + + +@registry.register_problem +class AlgorithmicShiftDecimal40(AlgorithmicProblem): + """Problem spec for algorithmic decimal shift task.""" + + @property + def num_symbols(self): + return 20 + + @property + def train_generator(self): + return lambda nbr_sym, l, size: shift_generator(nbr_sym, 10, l, size) + + @property + def dev_length(self): + return 80 def reverse_generator(nbr_symbols, max_length, nbr_cases): """Generator for the reversing task on sequences of symbols. The length of the sequence is drawn uniformly at random from [1, max_length] - and then symbols are drawn uniformly at random from [2, nbr_symbols] until + and then symbols are drawn uniformly at random from [0, nbr_symbols) until nbr_cases sequences have been produced. 
Args: @@ -128,11 +205,33 @@ def reverse_generator(nbr_symbols, max_length, nbr_cases): """ for _ in xrange(nbr_cases): l = np.random.randint(max_length) + 1 - inputs = [np.random.randint(nbr_symbols) + 2 for _ in xrange(l)] + inputs = [np.random.randint(nbr_symbols) for _ in xrange(l)] yield { "inputs": inputs, - "targets": list(reversed(inputs)) + [1] - } # [1] for EOS + "targets": list(reversed(inputs)) + } + + +@registry.register_problem +class AlgorithmicReverseBinary40(AlgorithmicProblem): + """Problem spec for algorithmic binary reversing task.""" + + @property + def num_symbols(self): + return 2 + + @property + def train_generator(self): + return reverse_generator + + +@registry.register_problem +class AlgorithmicReverseDecimal40(AlgorithmicReverseBinary40): + """Problem spec for algorithmic decimal reversing task.""" + + @property + def num_symbols(self): + return 10 def zipf_distribution(nbr_symbols, alpha): @@ -166,11 +265,8 @@ def zipf_random_sample(distr_map, sample_len): """ u = np.random.random(sample_len) # Random produces values in range [0.0,1.0); even if it is almost - # improbable(but possible) that it can generate a clear 0.000..0, - # we have made a sanity check to overcome this issue. On the other hand, - # t+1 is enough from saving us to generate PAD(0) and EOS(1) which are - # reservated symbols. - return [t + 1 if t > 0 else t + 2 for t in np.searchsorted(distr_map, u)] + # improbable(but possible) that it can generate a clear 0.000..0. + return list(np.searchsorted(distr_map, u)) def reverse_generator_nlplike(nbr_symbols, @@ -182,7 +278,7 @@ def reverse_generator_nlplike(nbr_symbols, The length of the sequence is drawn from a Gaussian(Normal) distribution at random from [1, max_length] and with std deviation of 1%, - then symbols are drawn from Zipf's law at random from [2, nbr_symbols] until + then symbols are drawn from Zipf's law at random from [0, nbr_symbols) until nbr_cases sequences have been produced. Args: @@ -206,8 +302,44 @@ def reverse_generator_nlplike(nbr_symbols, inputs = zipf_random_sample(distr_map, l) yield { "inputs": inputs, - "targets": list(reversed(inputs)) + [1] - } # [1] for EOS + "targets": list(reversed(inputs)) + } + + +@registry.register_problem +class AlgorithmicReverseNlplike8K(AlgorithmicProblem): + """Problem spec for algorithmic nlp-like reversing task.""" + + @property + def num_symbols(self): + return 8000 + + @property + def train_generator(self): + return lambda nbr_sym, length, size: reverse_generator_nlplike( # pylint: disable=g-long-lambda + nbr_sym, length, size, 10, 1.300) + + @property + def train_length(self): + return 70 + + @property + def dev_length(self): + return 70 + + +@registry.register_problem +class AlgorithmicReverseNlplike32K(AlgorithmicReverseNlplike8K): + """Problem spec for algorithmic nlp-like reversing task, 32K vocab.""" + + @property + def num_symbols(self): + return 32000 + + @property + def train_generator(self): + return lambda nbr_sym, length, size: reverse_generator_nlplike( # pylint: disable=g-long-lambda + nbr_sym, length, size, 10, 1.050) def lower_endian_to_number(l, base): @@ -235,7 +367,7 @@ def addition_generator(base, max_length, nbr_cases): The length of each number is drawn uniformly at random from [1, max_length/2] and then digits are drawn uniformly at random. The numbers are added and - separated by [base+1] in the input. Stops at nbr_cases. + separated by [base] in the input. Stops at nbr_cases. Args: base: in which base are the numbers. 
@@ -257,10 +389,31 @@ def addition_generator(base, max_length, nbr_cases): n1 = random_number_lower_endian(l1, base) n2 = random_number_lower_endian(l2, base) result = lower_endian_to_number(n1, base) + lower_endian_to_number(n2, base) - # We shift digits by 1 on input and output to leave 0 for padding. - inputs = [i + 2 for i in n1] + [base + 2] + [i + 2 for i in n2] - targets = [i + 2 for i in number_to_lower_endian(result, base)] - yield {"inputs": inputs, "targets": targets + [1]} # [1] for EOS + inputs = n1 + [base] + n2 + targets = number_to_lower_endian(result, base) + yield {"inputs": inputs, "targets": targets} + + +@registry.register_problem +class AlgorithmicAdditionBinary40(AlgorithmicProblem): + """Problem spec for algorithmic binary addition task.""" + + @property + def num_symbols(self): + return 2 + + @property + def train_generator(self): + return addition_generator + + +@registry.register_problem +class AlgorithmicAdditionDecimal40(AlgorithmicAdditionBinary40): + """Problem spec for algorithmic decimal addition task.""" + + @property + def num_symbols(self): + return 10 def multiplication_generator(base, max_length, nbr_cases): @@ -268,7 +421,7 @@ def multiplication_generator(base, max_length, nbr_cases): The length of each number is drawn uniformly at random from [1, max_length/2] and then digits are drawn uniformly at random. The numbers are multiplied - and separated by [base+1] in the input. Stops at nbr_cases. + and separated by [base] in the input. Stops at nbr_cases. Args: base: in which base are the numbers. @@ -291,7 +444,28 @@ def multiplication_generator(base, max_length, nbr_cases): n1 = random_number_lower_endian(l1, base) n2 = random_number_lower_endian(l2, base) result = lower_endian_to_number(n1, base) * lower_endian_to_number(n2, base) - # We shift digits by 1 on input and output to leave 0 for padding. 
- inputs = [i + 2 for i in n1] + [base + 2] + [i + 2 for i in n2] - targets = [i + 2 for i in number_to_lower_endian(result, base)] - yield {"inputs": inputs, "targets": targets + [1]} # [1] for EOS + inputs = n1 + [base] + n2 + targets = number_to_lower_endian(result, base) + yield {"inputs": inputs, "targets": targets} + + +@registry.register_problem +class AlgorithmicMultiplicationBinary40(AlgorithmicProblem): + """Problem spec for algorithmic binary multiplication task.""" + + @property + def num_symbols(self): + return 2 + + @property + def train_generator(self): + return multiplication_generator + + +@registry.register_problem +class AlgorithmicMultiplicationDecimal40(AlgorithmicMultiplicationBinary40): + """Problem spec for algorithmic decimal multiplication task.""" + + @property + def num_symbols(self): + return 10 diff --git a/tensor2tensor/data_generators/algorithmic_test.py b/tensor2tensor/data_generators/algorithmic_test.py index 9961e6173..fb8ff6719 100644 --- a/tensor2tensor/data_generators/algorithmic_test.py +++ b/tensor2tensor/data_generators/algorithmic_test.py @@ -31,14 +31,14 @@ def testIdentityGenerator(self): counter = 0 for d in algorithmic.identity_generator(3, 8, 10): counter += 1 - self.assertEqual(d["inputs"] + [1], d["targets"]) + self.assertEqual(d["inputs"], d["targets"]) self.assertEqual(counter, 10) def testReverseGenerator(self): counter = 0 for d in algorithmic.reverse_generator(3, 8, 10): counter += 1 - self.assertEqual(list(reversed(d["inputs"])) + [1], d["targets"]) + self.assertEqual(list(reversed(d["inputs"])), d["targets"]) self.assertEqual(counter, 10) def testZipfDistribution(self): @@ -53,7 +53,7 @@ def testReverseGeneratorNlpLike(self): counter = 0 for d in algorithmic.reverse_generator_nlplike(3, 8, 10): counter += 1 - self.assertEqual(list(reversed(d["inputs"])) + [1], d["targets"]) + self.assertEqual(list(reversed(d["inputs"])), d["targets"]) self.assertEqual(counter, 10) def testLowerEndianToNumber(self): @@ -78,20 +78,20 @@ def testAdditionGenerator(self): counter = 0 for d in algorithmic.addition_generator(4, 8, 10): counter += 1 - self.assertEqual(d["inputs"].count(6), 1) - self.assertEqual(d["inputs"].count(0), 0) - self.assertEqual(d["targets"].count(6), 0) - self.assertEqual(d["targets"].count(0), 0) + self.assertEqual(d["inputs"].count(4), 1) + self.assertEqual(d["inputs"].count(5), 0) + self.assertEqual(d["targets"].count(4), 0) + self.assertEqual(d["targets"].count(5), 0) self.assertEqual(counter, 10) def testMultiplicationGenerator(self): counter = 0 for d in algorithmic.multiplication_generator(4, 8, 10): counter += 1 - self.assertEqual(d["inputs"].count(6), 1) - self.assertEqual(d["inputs"].count(0), 0) - self.assertEqual(d["targets"].count(6), 0) - self.assertEqual(d["targets"].count(0), 0) + self.assertEqual(d["inputs"].count(4), 1) + self.assertEqual(d["inputs"].count(5), 0) + self.assertEqual(d["targets"].count(4), 0) + self.assertEqual(d["targets"].count(5), 0) self.assertEqual(counter, 10) diff --git a/tensor2tensor/data_generators/generator_utils.py b/tensor2tensor/data_generators/generator_utils.py old mode 100755 new mode 100644 index 20f3959d8..b34a87138 --- a/tensor2tensor/data_generators/generator_utils.py +++ b/tensor2tensor/data_generators/generator_utils.py @@ -244,7 +244,6 @@ def gunzip_file(gz_path, new_path): "http://www.statmt.org/wmt13/training-parallel-un.tgz", ["un/undoc.2000.fr-en.en", "un/undoc.2000.fr-en.fr"] ], - # Macedonian-English [ 
"https://github.com/stefan-it/nmt-mk-en/raw/master/data/setimes.mk-en.train.tgz", # pylint: disable=line-too-long ["train.mk", "train.en"] diff --git a/tensor2tensor/data_generators/problem.py b/tensor2tensor/data_generators/problem.py index 0eb1987fa..1182ed7d1 100644 --- a/tensor2tensor/data_generators/problem.py +++ b/tensor2tensor/data_generators/problem.py @@ -92,6 +92,7 @@ class Problem(object): get sharded filenames. If shuffled=False, the filenames will contain an "unshuffled" suffix; you should then shuffle the data shard-by-shard with generator_utils.shuffle_dataset. + - Allows to specify the number of shards, optionally (can be omitted). - Subclasses must override * dataset_filename() - Base filename for problem. @@ -113,7 +114,7 @@ class Problem(object): # BEGIN SUBCLASS INTERFACE # ============================================================================ - def generate_data(self, data_dir, tmp_dir, num_shards=100): + def generate_data(self, data_dir, tmp_dir, num_shards=None): raise NotImplementedError() def hparams(self, defaults, model_hparams): diff --git a/tensor2tensor/data_generators/problem_hparams.py b/tensor2tensor/data_generators/problem_hparams.py index 70b9dada8..5922ab59a 100644 --- a/tensor2tensor/data_generators/problem_hparams.py +++ b/tensor2tensor/data_generators/problem_hparams.py @@ -217,20 +217,6 @@ def test_problem_hparams(unused_model_hparams, input_vocab_size, return p -def algorithmic(vocab_size, unused_model_hparams): - """Default parameters for algorithmic tasks.""" - p = default_problem_hparams() - p.input_modality = {"inputs": (registry.Modalities.SYMBOL, vocab_size)} - p.target_modality = (registry.Modalities.SYMBOL, vocab_size) - p.vocabulary = { - "inputs": text_encoder.TextEncoder(), - "targets": text_encoder.TextEncoder(), - } - p.input_space_id = 10 - p.target_space_id = 11 - return p - - def audio_timit_characters(unused_model_hparams): """English audio transcription benchmark.""" p = default_problem_hparams() @@ -351,10 +337,9 @@ def wiki_32k(model_hparams): p = default_problem_hparams() encoder = text_encoder.SubwordTextEncoder( os.path.join(model_hparams.data_dir, "wiki_32k.subword_text_encoder")) - p.input_modality = { - "inputs": (registry.Modalities.SYMBOL, encoder.vocab_size) - } - p.target_modality = (registry.Modalities.SYMBOL, encoder.vocab_size) + modality_spec = (registry.Modalities.SYMBOL, encoder.vocab_size) + p.input_modality = {"inputs": modality_spec} + p.target_modality = modality_spec p.vocabulary = { "inputs": encoder, "targets": encoder @@ -378,50 +363,6 @@ def lmptb_10k(model_hparams): return p -def wmt_enfr_characters(unused_model_hparams): - """English to French translation benchmark.""" - p = default_problem_hparams() - p.input_modality = {"inputs": (registry.Modalities.SYMBOL, 256)} - p.target_modality = (registry.Modalities.SYMBOL, 256) - p.vocabulary = { - "inputs": text_encoder.ByteTextEncoder(), - "targets": text_encoder.ByteTextEncoder(), - } - p.loss_multiplier = 2.0 - p.input_space_id = 2 - p.target_space_id = 5 - return p - - -def wmt_enfr_tokens(model_hparams, wrong_vocab_size): - """English to French translation benchmark. - - Args: - model_hparams: a tf.contrib.training.HParams - wrong_vocab_size: a number used in the filename indicating the approximate - vocabulary size. This is not to be confused with the actual vocabulary - size. - Returns: - a tf.contrib.training.HParams - """ - p = default_problem_hparams() - # This vocab file must be present within the data directory. 
- vocab_filename = os.path.join(model_hparams.data_dir, - "tokens.vocab.%d" % wrong_vocab_size) - subtokenizer = text_encoder.SubwordTextEncoder(vocab_filename) - p.input_modality = { - "inputs": (registry.Modalities.SYMBOL, subtokenizer.vocab_size) - } - p.target_modality = (registry.Modalities.SYMBOL, subtokenizer.vocab_size) - p.vocabulary = { - "inputs": subtokenizer, - "targets": subtokenizer, - } - p.input_space_id = 3 - p.target_space_id = 6 - return p - - def wmt_ende_bpe32k(model_hparams): """English to German translation benchmark.""" p = default_problem_hparams() @@ -441,47 +382,6 @@ def wmt_ende_bpe32k(model_hparams): return p -def wmt_ende_characters(unused_model_hparams): - """English to German translation benchmark.""" - p = default_problem_hparams() - p.input_modality = {"inputs": (registry.Modalities.SYMBOL, 256)} - p.target_modality = (registry.Modalities.SYMBOL, 256) - p.vocabulary = { - "inputs": text_encoder.ByteTextEncoder(), - "targets": text_encoder.ByteTextEncoder(), - } - p.loss_multiplier = 2.0 - p.input_space_id = 2 - p.target_space_id = 7 - return p - - -def wmt_zhen_tokens(model_hparams, wrong_vocab_size): - """Chinese to English translation benchmark.""" - p = default_problem_hparams() - # This vocab file must be present within the data directory. - if model_hparams.shared_embedding_and_softmax_weights == 1: - model_hparams.shared_embedding_and_softmax_weights = 0 - source_vocab_filename = os.path.join(model_hparams.data_dir, - "tokens.vocab.zh.%d" % wrong_vocab_size) - target_vocab_filename = os.path.join(model_hparams.data_dir, - "tokens.vocab.en.%d" % wrong_vocab_size) - source_token = text_encoder.SubwordTextEncoder(source_vocab_filename) - target_token = text_encoder.SubwordTextEncoder(target_vocab_filename) - p.input_modality = { - "inputs": (registry.Modalities.SYMBOL, source_token.vocab_size) - } - p.target_modality = (registry.Modalities.SYMBOL, target_token.vocab_size) - p.vocabulary = { - "inputs": source_token, - "targets": target_token, - } - p.loss_multiplier = 1.4 - p.input_space_id = 16 - p.target_space_id = 4 - return p - - def wmt_parsing_characters(model_hparams): """English to parse tree translation benchmark.""" del model_hparams # Unused. @@ -699,15 +599,6 @@ def img2img_imagenet(unused_model_hparams): # Dictionary of named hyperparameter settings for various problems. # This is only accessed through the problem_hparams function below. 
PROBLEM_HPARAMS_MAP = { - "algorithmic_addition_binary40": lambda p: algorithmic(4, p), - "algorithmic_addition_decimal40": lambda p: algorithmic(12, p), - "algorithmic_multiplication_binary40": lambda p: algorithmic(4, p), - "algorithmic_multiplication_decimal40": lambda p: algorithmic(12, p), - "algorithmic_reverse_binary40": lambda p: algorithmic(4, p), - "algorithmic_reverse_decimal40": lambda p: algorithmic(12, p), - "algorithmic_reverse_nlplike_decimal8K": lambda p: algorithmic(8002, p), - "algorithmic_reverse_nlplike_decimal32K": lambda p: algorithmic(32002, p), - "algorithmic_shift_decimal40": lambda p: algorithmic(22, p), "audio_timit_characters_tune": audio_timit_characters, "audio_timit_characters_test": audio_timit_characters, "audio_timit_tokens_8k_tune": lambda p: audio_timit_tokens(p, 2**13), @@ -724,15 +615,7 @@ def img2img_imagenet(unused_model_hparams): "wmt_parsing_tokens_8k": lambda p: wmt_parsing_tokens(p, 2**13), "wsj_parsing_tokens_16k": lambda p: wsj_parsing_tokens( # pylint: disable=g-long-lambda p, "wsj", 2**14, 2**9), - "wmt_enfr_characters": wmt_enfr_characters, - "wmt_enfr_tokens_8k": lambda p: wmt_enfr_tokens(p, 2**13), - "wmt_enfr_tokens_32k": lambda p: wmt_enfr_tokens(p, 2**15), - "wmt_enfr_tokens_32k_shuffled": lambda p: wmt_enfr_tokens(p, 2**15), - "wmt_enfr_tokens_32k_combined": lambda p: wmt_enfr_tokens(p, 2**15), - "wmt_enfr_tokens_128k": lambda p: wmt_enfr_tokens(p, 2**17), - "wmt_ende_characters": wmt_ende_characters, "wmt_ende_bpe32k": wmt_ende_bpe32k, - "wmt_zhen_tokens_32k": lambda p: wmt_zhen_tokens(p, 2**15), "image_cifar10_tune": image_cifar10, "image_cifar10_test": image_cifar10, "image_mnist_tune": image_mnist, diff --git a/tensor2tensor/data_generators/text_encoder.py b/tensor2tensor/data_generators/text_encoder.py index c812ced4f..e0ac1901e 100644 --- a/tensor2tensor/data_generators/text_encoder.py +++ b/tensor2tensor/data_generators/text_encoder.py @@ -42,8 +42,8 @@ EOS = "" RESERVED_TOKENS = [PAD, EOS] NUM_RESERVED_TOKENS = len(RESERVED_TOKENS) -PAD_TOKEN = RESERVED_TOKENS.index(PAD) # Normally 0 -EOS_TOKEN = RESERVED_TOKENS.index(EOS) # Normally 1 +PAD_ID = RESERVED_TOKENS.index(PAD) # Normally 0 +EOS_ID = RESERVED_TOKENS.index(EOS) # Normally 1 if PY2: RESERVED_TOKENS_BYTES = RESERVED_TOKENS @@ -51,6 +51,13 @@ RESERVED_TOKENS_BYTES = [bytes(PAD, "ascii"), bytes(EOS, "ascii")] +# Regular expression for unescaping token strings. +# '\u' is converted to '_' +# '\\' is converted to '\' +# '\213;' is converted to unichr(213) +_UNESCAPE_REGEX = re.compile(u"|".join([r"\\u", r"\\\\", r"\\([0-9]+);"])) + + def native_to_unicode_py2(s): """Python 2: transform native string to Unicode.""" if isinstance(s, unicode): @@ -505,12 +512,6 @@ def _escape_token(self, token): ret += u"\\%d;" % ord(c) return ret - # Regular expression for unescaping token strings - # '\u' is converted to '_' - # '\\' is converted to '\' - # '\213;' is converted to unichr(213) - _UNESCAPE_REGEX = re.compile(u'|'.join([r"\\u", r"\\\\", r"\\([0-9]+);"])) - def _unescape_token(self, escaped_token): """Inverse of _escape_token(). diff --git a/tensor2tensor/data_generators/tokenizer_test.py b/tensor2tensor/data_generators/tokenizer_test.py old mode 100755 new mode 100644 index 45a1f7e41..c279290ed --- a/tensor2tensor/data_generators/tokenizer_test.py +++ b/tensor2tensor/data_generators/tokenizer_test.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Copyright 2017 The Tensor2Tensor Authors. 
# # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +# coding=utf-8 """Tests for tensor2tensor.data_generators.tokenizer.""" from __future__ import absolute_import diff --git a/tensor2tensor/data_generators/wmt.py b/tensor2tensor/data_generators/wmt.py index de5a25e13..2e1f1e8af 100644 --- a/tensor2tensor/data_generators/wmt.py +++ b/tensor2tensor/data_generators/wmt.py @@ -38,77 +38,97 @@ FLAGS = tf.flags.FLAGS -@registry.register_problem("wmt_ende_tokens_8k") -class WMTEnDeTokens8k(problem.Problem): - """Problem spec for WMT En-De translation.""" +# End-of-sentence marker. +EOS = text_encoder.EOS_ID - @property - def target_vocab_size(self): - return 2**13 # 8192 - def feature_encoders(self, data_dir): - return _default_wmt_feature_encoders(data_dir, self.target_vocab_size) +def _default_token_feature_encoders(data_dir, target_vocab_size): + vocab_filename = os.path.join(data_dir, "tokens.vocab.%d" % target_vocab_size) + subtokenizer = text_encoder.SubwordTextEncoder(vocab_filename) + return { + "inputs": subtokenizer, + "targets": subtokenizer, + } - def generate_data(self, data_dir, tmp_dir, num_shards=100): - generator_utils.generate_dataset_and_shuffle( - ende_wordpiece_token_generator(tmp_dir, True, self.target_vocab_size), - self.training_filepaths(data_dir, num_shards, shuffled=False), - ende_wordpiece_token_generator(tmp_dir, False, self.target_vocab_size), - self.dev_filepaths(data_dir, 1, shuffled=False)) - def hparams(self, defaults, unused_model_hparams): - p = defaults - vocab_size = self._encoders["inputs"].vocab_size - p.input_modality = {"inputs": (registry.Modalities.SYMBOL, vocab_size)} - p.target_modality = (registry.Modalities.SYMBOL, vocab_size) - p.input_space_id = problem.SpaceID.EN_TOK - p.target_space_id = problem.SpaceID.DE_TOK +def _default_character_feature_encoders(): + return { + "inputs": text_encoder.ByteTextEncoder(), + "targets": text_encoder.ByteTextEncoder(), + } -@registry.register_problem("wmt_ende_tokens_32k") -class WMTEnDeTokens32k(WMTEnDeTokens8k): +class WMTProblem(problem.Problem): + """Base class for WMT problems.""" @property - def target_vocab_size(self): - return 2**15 # 32768 + def is_character_level(self): + return False + @property + def targeted_vocab_size(self): + raise NotImplementedError() # Not needed if self.is_character_level. 
-def _default_wmt_feature_encoders(data_dir, target_vocab_size): - vocab_filename = os.path.join(data_dir, "tokens.vocab.%d" % target_vocab_size) - subtokenizer = text_encoder.SubwordTextEncoder(vocab_filename) - return { - "inputs": subtokenizer, - "targets": subtokenizer, - } + @property + def train_generator(self): + """Generator; takes tmp_dir, is_training, possibly targeted_vocab_size.""" + raise NotImplementedError() -@registry.register_problem("setimes_mken_tokens_32k") -class SETimesMkEnTokens32k(problem.Problem): - """Problem spec for SETimes Mk-En translation.""" + @property + def dev_generator(self): + return self.train_generator @property - def target_vocab_size(self): - return 2**15 # 32768 + def input_space_id(self): + raise NotImplementedError() - def feature_encoders(self, data_dir): - return _default_wmt_feature_encoders(data_dir, self.target_vocab_size) + @property + def target_space_id(self): + raise NotImplementedError() + + @property + def num_shards(self): + return 100 + + def generate_data(self, data_dir, tmp_dir, num_shards=None): + if num_shards is None: + num_shards = self.num_shards + if self.is_character_level: + generator_utils.generate_dataset_and_shuffle( + self.train_generator(tmp_dir, True), + self.training_filepaths(data_dir, num_shards, shuffled=False), + self.dev_generator(tmp_dir, False), + self.dev_filepaths(data_dir, 1, shuffled=False)) + else: + generator_utils.generate_dataset_and_shuffle( + self.train_generator(tmp_dir, True, self.targeted_vocab_size), + self.training_filepaths(data_dir, num_shards, shuffled=False), + self.dev_generator(tmp_dir, False, self.targeted_vocab_size), + self.dev_filepaths(data_dir, 1, shuffled=False)) - def generate_data(self, data_dir, tmp_dir, num_shards=100): - generator_utils.generate_dataset_and_shuffle( - mken_wordpiece_token_generator(tmp_dir, True, self.target_vocab_size), - self.training_filepaths(data_dir, num_shards, shuffled=False), - mken_wordpiece_token_generator(tmp_dir, False, self.target_vocab_size), - self.dev_filepaths(data_dir, 1, shuffled=False)) + def feature_encoders(self, data_dir): + if self.is_character_level: + return _default_character_feature_encoders() + return _default_token_feature_encoders(data_dir, self.targeted_vocab_size) def hparams(self, defaults, unused_model_hparams): p = defaults - vocab_size = self._encoders["inputs"].vocab_size - p.input_modality = {"inputs": (registry.Modalities.SYMBOL, vocab_size)} - p.target_modality = (registry.Modalities.SYMBOL, vocab_size) - p.input_space_id = problem.SpaceID.MK_TOK - p.target_space_id = problem.SpaceID.EN_TOK + if self.is_character_level: + source_vocab_size = 256 + target_vocab_size = 256 + else: + source_vocab_size = self._encoders["inputs"].vocab_size + target_vocab_size = self._encoders["targets"].vocab_size + p.input_modality = {"inputs": (registry.Modalities.SYMBOL, + source_vocab_size)} + p.target_modality = (registry.Modalities.SYMBOL, target_vocab_size) + p.input_space_id = self.input_space_id + p.target_space_id = self.target_space_id + if self.is_character_level: + p.loss_multiplier = 2.0 -# End-of-sentence marker. -EOS = text_encoder.EOS_TOKEN + +# Generic generators used later for multiple problems. 
def character_generator(source_path, target_path, character_vocab, eos=None): @@ -233,29 +253,7 @@ def bi_vocabs_token_generator(source_path, source, target = source_file.readline(), target_file.readline() -def _get_wmt_ende_dataset(directory, filename): - """Extract the WMT en-de corpus `filename` to directory unless it's there.""" - train_path = os.path.join(directory, filename) - if not (tf.gfile.Exists(train_path + ".de") and - tf.gfile.Exists(train_path + ".en")): - # We expect that this file has been downloaded from: - # https://drive.google.com/open?id=0B_bZck-ksdkpM25jRUN2X2UxMm8 and placed - # in `directory`. - corpus_file = os.path.join(directory, FLAGS.ende_bpe_path) - with tarfile.open(corpus_file, "r:gz") as corpus_tar: - corpus_tar.extractall(directory) - return train_path - - -def ende_bpe_token_generator(tmp_dir, train): - """Instance of token generator for the WMT en->de task, training set.""" - dataset_path = ("train.tok.clean.bpe.32000" - if train else "newstest2013.tok.bpe.32000") - train_path = _get_wmt_ende_dataset(tmp_dir, dataset_path) - token_path = os.path.join(tmp_dir, "vocab.bpe.32000") - token_vocab = text_encoder.TokenTextEncoder(vocab_filename=token_path) - return token_generator(train_path + ".en", train_path + ".de", token_vocab, - EOS) +# Data-set URLs. _ENDE_TRAIN_DATASETS = [ @@ -336,6 +334,34 @@ def ende_bpe_token_generator(tmp_dir, train): ]] +# Generators. + + +def _get_wmt_ende_dataset(directory, filename): + """Extract the WMT en-de corpus `filename` to directory unless it's there.""" + train_path = os.path.join(directory, filename) + if not (tf.gfile.Exists(train_path + ".de") and + tf.gfile.Exists(train_path + ".en")): + # We expect that this file has been downloaded from: + # https://drive.google.com/open?id=0B_bZck-ksdkpM25jRUN2X2UxMm8 and placed + # in `directory`. 
+ corpus_file = os.path.join(directory, FLAGS.ende_bpe_path) + with tarfile.open(corpus_file, "r:gz") as corpus_tar: + corpus_tar.extractall(directory) + return train_path + + +def ende_bpe_token_generator(tmp_dir, train): + """Instance of token generator for the WMT en->de task, training set.""" + dataset_path = ("train.tok.clean.bpe.32000" + if train else "newstest2013.tok.bpe.32000") + train_path = _get_wmt_ende_dataset(tmp_dir, dataset_path) + token_path = os.path.join(tmp_dir, "vocab.bpe.32000") + token_vocab = text_encoder.TokenTextEncoder(vocab_filename=token_path) + return token_generator(train_path + ".en", train_path + ".de", token_vocab, + EOS) + + def _compile_data(tmp_dir, datasets, filename): """Concatenate all `datasets` and save to `filename`.""" filename = os.path.join(tmp_dir, filename) @@ -386,6 +412,35 @@ def ende_wordpiece_token_generator(tmp_dir, train, vocab_size): symbolizer_vocab, EOS) +@registry.register_problem("wmt_ende_tokens_8k") +class WMTEnDeTokens8k(WMTProblem): + """Problem spec for WMT En-De translation.""" + + @property + def targeted_vocab_size(self): + return 2**13 # 8192 + + @property + def train_generator(self): + return ende_wordpiece_token_generator + + @property + def input_space_id(self): + return problem.SpaceID.EN_TOK + + @property + def target_space_id(self): + return problem.SpaceID.DE_TOK + + +@registry.register_problem("wmt_ende_tokens_32k") +class WMTEnDeTokens32k(WMTEnDeTokens8k): + + @property + def targeted_vocab_size(self): + return 2**15 # 32768 + + def ende_character_generator(tmp_dir, train): character_vocab = text_encoder.ByteTextEncoder() datasets = _ENDE_TRAIN_DATASETS if train else _ENDE_TEST_DATASETS @@ -395,8 +450,29 @@ def ende_character_generator(tmp_dir, train): character_vocab, EOS) -def zhen_wordpiece_token_generator(tmp_dir, train, source_vocab_size, - target_vocab_size): +@registry.register_problem("wmt_ende_characters") +class WMTEnDeCharacters(WMTProblem): + """Problem spec for WMT En-De translation.""" + + @property + def is_character_level(self): + return True + + @property + def train_generator(self): + return ende_character_generator + + @property + def input_space_id(self): + return problem.SpaceID.EN_CHR + + @property + def target_space_id(self): + return problem.SpaceID.DE_CHR + + +def zhen_wordpiece_token_bigenerator(tmp_dir, train, source_vocab_size, + target_vocab_size): """Wordpiece generator for the WMT'17 zh-en dataset.""" datasets = _ZHEN_TRAIN_DATASETS if train else _ZHEN_TEST_DATASETS source_datasets = [[item[0], [item[1][0]]] for item in datasets] @@ -413,6 +489,53 @@ def zhen_wordpiece_token_generator(tmp_dir, train, source_vocab_size, source_vocab, target_vocab, EOS) +def zhen_wordpiece_token_generator(tmp_dir, train, vocab_size): + return zhen_wordpiece_token_bigenerator(tmp_dir, train, + vocab_size, vocab_size) + + +@registry.register_problem("wmt_zhen_tokens_8k") +class WMTZhEnTokens8k(WMTProblem): + """Problem spec for WMT Zh-En translation.""" + + @property + def targeted_vocab_size(self): + return 2**13 # 8192 + + @property + def train_generator(self): + return zhen_wordpiece_token_generator + + @property + def input_space_id(self): + return problem.SpaceID.ZH_TOK + + @property + def target_space_id(self): + return problem.SpaceID.EN_TOK + + def feature_encoders(self, data_dir): + vocab_size = self.targeted_vocab_size + source_vocab_filename = os.path.join(data_dir, + "tokens.vocab.zh.%d" % vocab_size) + target_vocab_filename = os.path.join(data_dir, + "tokens.vocab.en.%d" % vocab_size) + 
source_token = text_encoder.SubwordTextEncoder(source_vocab_filename) + target_token = text_encoder.SubwordTextEncoder(target_vocab_filename) + return { + "inputs": source_token, + "targets": target_token, + } + + +@registry.register_problem("wmt_zhen_tokens_32k") +class WMTZhEnTokens32k(WMTZhEnTokens8k): + + @property + def targeted_vocab_size(self): + return 2**15 # 32768 + + def enfr_wordpiece_token_generator(tmp_dir, train, vocab_size): """Instance of token generator for the WMT en->fr task.""" symbolizer_vocab = generator_utils.get_or_generate_vocab( @@ -424,6 +547,35 @@ def enfr_wordpiece_token_generator(tmp_dir, train, vocab_size): symbolizer_vocab, EOS) +@registry.register_problem("wmt_enfr_tokens_8k") +class WMTEnFrTokens8k(WMTProblem): + """Problem spec for WMT En-Fr translation.""" + + @property + def targeted_vocab_size(self): + return 2**13 # 8192 + + @property + def train_generator(self): + return enfr_wordpiece_token_generator + + @property + def input_space_id(self): + return problem.SpaceID.EN_TOK + + @property + def target_space_id(self): + return problem.SpaceID.FR_TOK + + +@registry.register_problem("wmt_enfr_tokens_32k") +class WMTEnFrTokens32k(WMTEnFrTokens8k): + + @property + def targeted_vocab_size(self): + return 2**15 # 32768 + + def enfr_character_generator(tmp_dir, train): """Instance of character generator for the WMT en->fr task.""" character_vocab = text_encoder.ByteTextEncoder() @@ -433,6 +585,28 @@ def enfr_character_generator(tmp_dir, train): return character_generator(data_path + ".lang1", data_path + ".lang2", character_vocab, EOS) + +@registry.register_problem("wmt_enfr_characters") +class WMTEnFrCharacters(WMTProblem): + """Problem spec for WMT En-Fr translation.""" + + @property + def is_character_level(self): + return True + + @property + def train_generator(self): + return enfr_character_generator + + @property + def input_space_id(self): + return problem.SpaceID.EN_CHR + + @property + def target_space_id(self): + return problem.SpaceID.FR_CHR + + def mken_wordpiece_token_generator(tmp_dir, train, vocab_size): """Wordpiece generator for the SETimes Mk-En dataset.""" datasets = _MKEN_TRAIN_DATASETS if train else _MKEN_TEST_DATASETS @@ -447,6 +621,27 @@ def mken_wordpiece_token_generator(tmp_dir, train, vocab_size): symbolizer_vocab, EOS) +@registry.register_problem("setimes_mken_tokens_32k") +class SETimesMkEnTokens32k(WMTProblem): + """Problem spec for SETimes Mk-En translation.""" + + @property + def targeted_vocab_size(self): + return 2**15 # 32768 + + @property + def train_generator(self): + return mken_wordpiece_token_generator + + @property + def input_space_id(self): + return problem.SpaceID.MK_TOK + + @property + def target_space_id(self): + return problem.SpaceID.EN_TOK + + def parsing_character_generator(tmp_dir, train): character_vocab = text_encoder.ByteTextEncoder() filename = "parsing_%s" % ("train" if train else "dev") diff --git a/tensor2tensor/models/attention_lm.py b/tensor2tensor/models/attention_lm.py index 947dc9306..752de038e 100644 --- a/tensor2tensor/models/attention_lm.py +++ b/tensor2tensor/models/attention_lm.py @@ -101,8 +101,6 @@ def attention_lm_decoder(decoder_input, y: a Tensors """ x = decoder_input - # Summaries don't work in multi-problem setting yet. 
- summaries = "problems" not in hparams.values() or len(hparams.problems) == 1 with tf.variable_scope(name): for layer in xrange(hparams.num_hidden_layers): with tf.variable_scope("layer_%d" % layer): @@ -117,7 +115,6 @@ def attention_lm_decoder(decoder_input, hparams.hidden_size, hparams.num_heads, hparams.attention_dropout, - summaries=summaries, name="decoder_self_attention")) x = residual_fn(x, common_layers.conv_hidden_relu( diff --git a/tensor2tensor/models/attention_lm_moe.py b/tensor2tensor/models/attention_lm_moe.py index 952ff1a71..2754e8366 100644 --- a/tensor2tensor/models/attention_lm_moe.py +++ b/tensor2tensor/models/attention_lm_moe.py @@ -69,7 +69,6 @@ def residual_fn(x, y): hparams.hidden_size, hparams.num_heads, hparams.attention_dropout, - summaries=True, name="decoder_self_attention") x = dp(residual_fn, x, y) with tf.variable_scope("ffn"): diff --git a/tensor2tensor/models/common_attention.py b/tensor2tensor/models/common_attention.py index 49cd40285..6aa8a2a07 100644 --- a/tensor2tensor/models/common_attention.py +++ b/tensor2tensor/models/common_attention.py @@ -280,13 +280,13 @@ def attention_image_summary(attn, image_shapes=None): (query_rows, query_cols, query_channels, memory_rows, memory_cols, memory_channels). """ - num_heads = attn.get_shape().as_list()[1] + num_heads = tf.shape(attn)[1] # [batch, query_length, memory_length, num_heads] image = tf.transpose(attn, [0, 2, 3, 1]) image = tf.pow(image, 0.2) # for high-dynamic-range # Each head will correspond to one of RGB. # pad the heads to be a multiple of 3 - image = tf.pad(image, [[0, 0], [0, 0], [0, 0], [0, -num_heads % 3]]) + image = tf.pad(image, [[0, 0], [0, 0], [0, 0], [0, tf.mod(-num_heads, 3)]]) image = split_last_dimension(image, 3) image = tf.reduce_max(image, 4) if image_shapes is not None: @@ -312,7 +312,6 @@ def dot_product_attention(q, v, bias, dropout_rate=0.0, - summaries=False, image_shapes=None, name=None): """dot-product attention. @@ -323,7 +322,6 @@ def dot_product_attention(q, v: a Tensor with shape [batch, heads, length_kv, depth_v] bias: bias Tensor (see attention_bias()) dropout_rate: a floating point number - summaries: a boolean image_shapes: optional tuple of integer scalars. see comments for attention_image_summary() name: an optional string @@ -340,11 +338,99 @@ def dot_product_attention(q, weights = tf.nn.softmax(logits, name="attention_weights") # dropping out the attention links for each of the heads weights = tf.nn.dropout(weights, 1.0 - dropout_rate) - if summaries and not tf.get_variable_scope().reuse: + if not tf.get_variable_scope().reuse: attention_image_summary(weights, image_shapes) return tf.matmul(weights, v) +def masked_local_attention_1d( + q, k, v, block_length=128, name=None): + """Attention to the source position and a neigborhood to the left of it. + + The sequence is divided into blocks of length block_size. + Attention for a given query position can only see memory positions + less than or equal to the query position, in the corresponding block + and the previous block. + + If mask_right is True, then a target position cannot see greater source + positions. 
+ + Args: + q: a Tensor with shape [batch, heads, length, depth_k] + k: a Tensor with shape [batch, heads, length, depth_k] + v: a Tensor with shape [batch, heads, length, depth_v] + block_length: an integer + name: an optional string + + Returns: + a Tensor of shape [batch, heads, length, depth_v] + """ + with tf.variable_scope(name, default_name="local_attention_1d", + values=[q, k, v]): + v_shape = v.get_shape() + batch = tf.shape(q)[0] + heads = tf.shape(q)[1] + length = tf.shape(q)[2] + # If (length < 2 * block_length), then we use only one block. + block_length = tf.where(tf.less(length, block_length * 2), + length, block_length) + depth_k = tf.shape(q)[3] + depth_v = tf.shape(v)[3] + original_length = length + padding_size = tf.mod(-length, block_length) + length += padding_size + padding = [[0, 0], [0, 0], [0, padding_size], [0, 0]] + q = tf.pad(q, padding) + k = tf.pad(k, padding) + v = tf.pad(v, padding) + num_blocks = tf.div(length, block_length) + + # compute attention for the first query block. + first_q = tf.slice(q, [0, 0, 0, 0], [-1, -1, block_length, -1]) + first_k = tf.slice(k, [0, 0, 0, 0], [-1, -1, block_length, -1]) + first_v = tf.slice(v, [0, 0, 0, 0], [-1, -1, block_length, -1]) + first_output = dot_product_attention( + first_q, first_k, first_v, attention_bias_lower_triangle(block_length), + name="fist_block") + + # compute attention for all subsequent query blocks. + q = tf.reshape(q, [batch, heads, num_blocks, block_length, depth_k]) + k = tf.reshape(k, [batch, heads, num_blocks, block_length, depth_k]) + v = tf.reshape(v, [batch, heads, num_blocks, block_length, depth_v]) + + def local(x): + """Create a local version of the keys or values.""" + prev_block = tf.slice( + x, [0, 0, 0, 0, 0], [-1, -1, num_blocks - 1, -1, -1]) + cur_block = tf.slice( + x, [0, 0, 1, 0, 0], [-1, -1, -1, -1, -1]) + return tf.concat([prev_block, cur_block], 3) + local_k = local(k) + local_v = local(v) + tail_q = tf.slice(q, [0, 0, 1, 0, 0], [-1, -1, -1, -1, -1]) + + local_length = tf.shape(local_k)[3] + + # [batch, heads, num_blocks - 1, block_length, local_length] + attention = tf.matmul(tail_q, local_k, transpose_b=True) + + # make sure source_pos <= target_pos + good_part = tf.matrix_band_part( + tf.ones([block_length, local_length]), -1, tf.to_int64(block_length)) + mask = (1.0 - good_part) * -1e9 + attention += tf.reshape(mask, [1, 1, 1, block_length, local_length]) + attention = tf.nn.softmax(attention) + # TODO(noam): figure out how to show a summary for the remaining blocks. + # The naive way currently causes errors due to empty tensors. + # output: [batch, heads, num_blocks-1, block_length, depth_v] + output = tf.matmul(attention, local_v) + output = tf.reshape(output, [batch, heads, -1, depth_v]) + output = tf.concat([first_output, output], axis=2) + output = tf.slice(output, [0, 0, 0, 0], [-1, -1, original_length, -1]) + output.set_shape(v_shape) + return output + + def multihead_attention(query_antecedent, memory_antecedent, bias, @@ -353,8 +439,9 @@ def multihead_attention(query_antecedent, output_depth, num_heads, dropout_rate, - summaries=False, image_shapes=None, + attention_type="dot_product", + block_length=128, name=None): """Multihead scaled-dot-product attention with input/output transformations. @@ -367,9 +454,10 @@ def multihead_attention(query_antecedent, output_depth: an integer num_heads: an integer dividing total_key_depth and total_value_depth dropout_rate: a floating point number - summaries: a boolean image_shapes: optional tuple of integer scalars. 
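A hedged usage sketch of the new `attention_type` and `block_length` arguments; the tensor shapes and hyperparameter values below are placeholders, not taken from any model in this patch:

```python
# Self-attention over a long sequence using the new local masked attention.
import tensorflow as tf

from tensor2tensor.models import common_attention

x = tf.zeros([2, 256, 512])  # [batch, length, hidden_size], placeholder values
y = common_attention.multihead_attention(
    query_antecedent=x,
    memory_antecedent=None,   # None means self-attention
    bias=None,                # masking is handled inside local attention
    total_key_depth=512,
    total_value_depth=512,
    output_depth=512,
    num_heads=8,
    dropout_rate=0.0,
    attention_type="local_mask_right",
    block_length=64)
```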
see comments for attention_image_summary() + attention_type: a string, either "dot_product" or "local_mask_right" + block_length: an integer - relevent for "local_mask_right" name: an optional string Returns: @@ -414,8 +502,12 @@ def multihead_attention(query_antecedent, v = split_heads(v, num_heads) key_depth_per_head = total_key_depth // num_heads q *= key_depth_per_head**-0.5 - x = dot_product_attention( - q, k, v, bias, dropout_rate, summaries, image_shapes) + if attention_type == "dot_product": + x = dot_product_attention( + q, k, v, bias, dropout_rate, image_shapes) + else: + assert attention_type == "local_mask_right" + x = masked_local_attention_1d(q, k, v, block_length=block_length) x = combine_heads(x) x = common_layers.conv1d(x, output_depth, 1, name="output_transform") return x diff --git a/tensor2tensor/models/common_hparams.py b/tensor2tensor/models/common_hparams.py index f067b724e..ff856968b 100644 --- a/tensor2tensor/models/common_hparams.py +++ b/tensor2tensor/models/common_hparams.py @@ -72,6 +72,9 @@ def basic_params1(): # setting the max length in a minibatch. 0 means default behavior, # max_length = hparams.batch_size * length_multiplier max_length=0, + # If set to True, drop sequences longer than max_length during eval. + # This affects the validity of the evaluation metrics. + eval_drop_long_sequences=int(False), # in SymbolModality, share the output embeddings and the softmax # variables. # You can also share the input embeddings with the output embeddings diff --git a/tensor2tensor/models/common_layers.py b/tensor2tensor/models/common_layers.py index 1e7050570..638535aa2 100644 --- a/tensor2tensor/models/common_layers.py +++ b/tensor2tensor/models/common_layers.py @@ -777,7 +777,7 @@ def moe_layer(data_parallelism, xs_2d = dp(tf.reshape, xs, [[-1, model_hidden_size]] * dp.n) # Call the MoE moe_out_2d, importance, load, _, _ = moe.Eval( - dp.devices, xs_2d, train, identifiers=None, summaries=True) + dp.devices, xs_2d, train, identifiers=None) # Reshape the output to the original shape. moe_out = dp(tf.reshape, moe_out_2d, dp(tf.shape, xs)) # These losses encourage equal load on the different experts. @@ -785,7 +785,7 @@ def moe_layer(data_parallelism, return moe_out, loss -def simple_attention(target, source, bias=None, summaries=True): +def simple_attention(target, source, bias=None): """A simple attention function. Args: @@ -795,7 +795,6 @@ def simple_attention(target, source, bias=None, summaries=True): `[batch, source_timesteps_1, source_timesteps_2, depth]` bias: an optional `Tensor` with shape `[batch, timesteps, 1, 1]` used to mask the attention to not attend to padding of input. - summaries: Boolean, whether to output summaries. Returns: a `Tensor` with same shape as `target` @@ -814,7 +813,7 @@ def simple_attention(target, source, bias=None, summaries=True): if bias is not None: attention += tf.expand_dims(tf.squeeze(bias, axis=[2, 3]), axis=1) attention = tf.nn.softmax(attention) - if summaries and not tf.get_variable_scope().reuse: + if not tf.get_variable_scope().reuse: tf.summary.image("attention", tf.expand_dims(attention, 3), max_outputs=5) attended = tf.matmul(attention, source) return tf.reshape(attended, target_shape) @@ -861,8 +860,7 @@ def multiscale_conv_sum(inputs, output_size, dilation_rates_and_kernel_sizes, def multiscale_conv_and_attention(x, padding, hparams, - source=None, - summaries=True): + source=None): """A common part of t2t layers. 
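Relatedly, the `eval_drop_long_sequences` hyperparameter added to `basic_params1` above can be flipped inside a custom hparams set; a minimal sketch, assuming the usual registration flow (the set name below is invented):

```python
# Hypothetical hparams set that drops over-long sequences during eval.
from tensor2tensor.models import common_hparams
from tensor2tensor.utils import registry


@registry.register_hparams
def my_drop_long_eval_hparams():
  hparams = common_hparams.basic_params1()
  hparams.max_length = 256
  # Drop eval sequences longer than max_length; note that this changes the
  # set of examples over which evaluation metrics are computed.
  hparams.eval_drop_long_sequences = int(True)
  return hparams
```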
First, do a linear multiscale convolution @@ -875,7 +873,6 @@ def multiscale_conv_and_attention(x, padding: a padding type hparams: hyperparameters for model source: optional source tensor for attention. (encoder output) - summaries: Boolean, whether to output summaries. Returns: a Tensor. @@ -893,7 +890,7 @@ def multiscale_conv_and_attention(x, x = conv(x, hparams.hidden_size, (1, 1)) x = noam_norm(x + conv_sum) if source is not None: - x = noam_norm(x + simple_attention(x, source, summaries=summaries)) + x = noam_norm(x + simple_attention(x, source)) return x @@ -930,8 +927,7 @@ def conv_with_pools(inputs, output_size, kernel_size, pool_sizes, pooling_type, def conv_with_pools_and_attention(x, padding, hparams, - source=None, - summaries=True): + source=None): """A common part of t2t layers. First, do conv_with_pools @@ -944,7 +940,6 @@ def conv_with_pools_and_attention(x, padding: a padding type hparams: hyperparameters for model source: optional source tensor for attention. (encoder output) - summaries: Boolean, whether to output summaries. Returns: a Tensor. @@ -959,7 +954,7 @@ def conv_with_pools_and_attention(x, conv_sum += x x = noam_norm(conv_sum) if source is not None: - x = noam_norm(x + simple_attention(x, source, summaries=summaries)) + x = noam_norm(x + simple_attention(x, source)) return x @@ -1057,7 +1052,6 @@ def attention_1d_v0(source, transform_source=True, transform_target=True, transform_output=True, - summaries=True, name=None): """multi-headed attention. @@ -1075,7 +1069,6 @@ def attention_1d_v0(source, transform_source: a boolean transform_target: a boolean transform_output: a boolean - summaries: a boolean name: an optional string Returns: @@ -1116,7 +1109,7 @@ def _maybe_transform(t, size, should_transform, name): mask = (1.0 - mask) * -1e9 attention += mask attention = tf.nn.softmax(attention) - if summaries and not tf.get_variable_scope().reuse: + if not tf.get_variable_scope().reuse: # Compute a color image summary. image = tf.reshape(attention, [batch, num_heads, target_length, source_length]) @@ -1162,7 +1155,6 @@ def conv_hidden_relu(inputs, output_size, kernel_size=(1, 1), second_kernel_size=(1, 1), - summaries=True, dropout=0.0, **kwargs): """Hidden layer with RELU activation followed by linear projection.""" @@ -1183,7 +1175,7 @@ def conv_hidden_relu(inputs, **kwargs) if dropout != 0.0: h = tf.nn.dropout(h, 1.0 - dropout) - if summaries and not tf.get_variable_scope().reuse: + if not tf.get_variable_scope().reuse: tf.summary.histogram("hidden_density_logit", relu_density_logit( h, list(range(inputs.shape.ndims - 1)))) diff --git a/tensor2tensor/models/long_answer.py b/tensor2tensor/models/long_answer.py new file mode 100644 index 000000000..7bb6a4a55 --- /dev/null +++ b/tensor2tensor/models/long_answer.py @@ -0,0 +1,274 @@ +# Copyright 2017 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Model to generate long answers to short questions. + +E.g. wiki_32k title->article dataset. 
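Stepping back to the `summaries` cleanup in `common_layers.py` above: the flag is gone, and summaries are now guarded only by variable-scope reuse. A minimal sketch of that convention (the helper name is invented for illustration):

```python
# Summaries are emitted unless the surrounding variable scope is reused,
# which avoids duplicate summaries from replicated shards or towers.
import tensorflow as tf


def maybe_histogram_summary(name, tensor):
  if not tf.get_variable_scope().reuse:
    tf.summary.histogram(name, tensor)
```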
+ +Variant on attention_lm_moe.py + - prepend the inputs to the targets. + - use masked local attention to avoid quadratic space and time blowup for + long sequences. + +This model is still highly experimental and under rapid iteration. + +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Dependency imports + +from six.moves import xrange # pylint: disable=redefined-builtin + +from tensor2tensor.models import common_attention +from tensor2tensor.models import common_hparams +from tensor2tensor.models import common_layers +from tensor2tensor.utils import registry +from tensor2tensor.utils import t2t_model + +import tensorflow as tf + + +@registry.register_model +class LongAnswer(t2t_model.T2TModel): + """Attention net. See file docstring.""" + + def model_fn_body_sharded(self, sharded_features): + # Remove dropout if not training + hparams = self._hparams + dp = self._data_parallelism + targets = sharded_features["targets"] + targets = dp(tf.squeeze, targets, 2) + inputs = sharded_features["inputs"] + inputs = dp(tf.squeeze, inputs, 2) + + decoder_input = dp(long_answer_prepare_decoder, inputs, targets, hparams) + + def residual_fn(x, y): + return common_layers.layer_norm(x + tf.nn.dropout( + y, 1.0 - hparams.residual_dropout)) + + x = dp(tf.nn.dropout, decoder_input, 1.0 - hparams.residual_dropout) + extra_loss = 0.0 + for layer in xrange(hparams.num_hidden_layers): + with tf.variable_scope("layer_%d" % layer): + with tf.variable_scope("attention"): + y = dp(common_attention.multihead_attention, + x, + None, + None, + hparams.attention_key_channels or hparams.hidden_size, + hparams.attention_value_channels or hparams.hidden_size, + hparams.hidden_size, + hparams.num_heads, + hparams.attention_dropout, + attention_type="local_mask_right", + block_length=hparams.block_length, + name="decoder_self_attention") + x = dp(residual_fn, x, y) + with tf.variable_scope("ffn"): + if str(layer) in hparams.moe_layers.split(","): + y, loss = common_layers.moe_layer( + dp, self._ps_devices, x, + hparams.mode == tf.contrib.learn.ModeKeys.TRAIN, + hparams.hidden_size, + hparams.moe_hidden_size, hparams.moe_n1, hparams.moe_n2, + hparams.moe_loss_coef) + extra_loss += loss + else: + y = dp(common_layers.conv_hidden_relu, + x, + hparams.filter_size, + hparams.hidden_size, + dropout=hparams.relu_dropout) + x = dp(residual_fn, x, y) + x = dp(long_answer_output, x, inputs) + return x, extra_loss + + +def long_answer_prepare_decoder(inputs, targets, hparams): + """Prepare one shard of the model for the decoder. + + Args: + inputs: a Tensor. + targets: a Tensor. + hparams: run hyperparameters + + Returns: + decoder_input: a Tensor, bottom of decoder stack + """ + decoder_input = tf.concat([ + length_embedding(targets, hparams), inputs, + common_layers.shift_left_3d(targets)], 1) + if hparams.pos == "timing": + decoder_input = common_attention.add_timing_signal_1d(decoder_input) + return decoder_input + + +def length_embedding(targets, hparams): + """An embedding indicating approximate target length. + + This is a bit of a hack, where we want to be able to request a particular + target length during inference. + During training, we sometimes provide a target length. + During eval, we never provide a target length. + + Args: + targets: a Tensor. + hparams: run hyperparameters + + Returns: + a Tensor with shape [batch, 1, hparams.hidden_size] + """ + # encode the approx target length in case we want to specify it + # during inference. 
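The layout produced by `long_answer_prepare_decoder` above is: one length-embedding position, then the (embedded) inputs, then the right-shifted targets. A rough shape-only sketch follows; all tensors are placeholders, and the pad-and-drop line stands in for `common_layers.shift_left_3d`:

```python
# Shape-only illustration of the decoder input built for LongAnswer.
import tensorflow as tf

batch, hidden = 2, 8
length_emb = tf.zeros([batch, 1, hidden])   # stands in for length_embedding(...)
inputs = tf.zeros([batch, 5, hidden])       # embedded question tokens
targets = tf.zeros([batch, 7, hidden])      # embedded answer tokens
shifted = tf.pad(targets, [[0, 0], [1, 0], [0, 0]])[:, :-1, :]  # shift right by one
decoder_input = tf.concat([length_emb, inputs, shifted], axis=1)
# decoder_input: [batch, 1 + 5 + 7, hidden]; long_answer_output later strips
# the first 1 + inputs_length positions so only answer positions remain.
```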
+ batch = tf.shape(targets)[0] + padded_target_length = tf.shape(targets)[1] + if hparams.mode == tf.contrib.learn.ModeKeys.TRAIN: + lengths = padded_target_length * tf.to_int32( + tf.less(tf.random_uniform([batch]), + hparams.answer_length_prob_train)) + elif hparams.mode == tf.contrib.learn.ModeKeys.EVAL: + lengths = 0 + else: + assert hparams.mode == tf.contrib.learn.ModeKeys.INFER + lengths = hparams.answer_length_infer + lengths = tf.to_int32(tf.log(tf.to_float(lengths + 1))) + lengths = tf.zeros([batch], dtype=tf.int32) + lengths + ret = tf.gather( + tf.get_variable("answer_length", [100, hparams.hidden_size]), lengths) + return tf.expand_dims(ret, 1) + + +def long_answer_output(x, inputs): + """Strip initial part corresponding to the inputs and the length embedding.""" + x = tf.slice(x, [0, tf.shape(inputs)[1] + 1, 0], [-1, -1, -1]) + x = tf.expand_dims(x, 2) + return x + + +@registry.register_hparams +def long_answer_base(): + """Set of hyperparameters. + + Returns: + a hparams object + """ + hparams = common_hparams.basic_params1() + hparams.hidden_size = 1024 + hparams.batch_size = 8192 + hparams.max_length = 8192 + hparams.dropout = 0.0 + hparams.batching_mantissa_bits = 3 + hparams.clip_grad_norm = 0. # i.e. no gradient clipping + hparams.optimizer_adam_epsilon = 1e-9 + hparams.learning_rate_decay_scheme = "noam" + hparams.learning_rate = 0.1 + hparams.learning_rate_warmup_steps = 1000 + hparams.initializer_gain = 1.0 + hparams.num_hidden_layers = 4 + hparams.initializer = "uniform_unit_scaling" + hparams.weight_decay = 0.0 + hparams.optimizer_adam_beta1 = 0.9 + hparams.optimizer_adam_beta2 = 0.98 + hparams.num_sampled_classes = 0 + hparams.label_smoothing = 0.0 + hparams.shared_embedding_and_softmax_weights = int(True) + hparams.sampling_method = "random" + hparams.add_hparam("filter_size", 2048) # Add new ones like this. + # comma-separated list of layer numbers. + # At each of these layers, we replace the ffn with a mixture of experts. + hparams.add_hparam("moe_layers", "2") + # If moe_n2 is None, then use a flat MoE with moe_n1 experts. + # If moe_n2 is an integer, then use a hierarchical MoE + # consisting of moe_n1 groups of moe_n2 experts each. + hparams.add_hparam("moe_n1", 64) + hparams.add_hparam("moe_n2", 0) + hparams.add_hparam("moe_hidden_size", 2048) + hparams.add_hparam("moe_loss_coef", 1e-2) + # attention-related flags + hparams.add_hparam("num_heads", 8) + hparams.add_hparam("attention_key_channels", 0) + hparams.add_hparam("attention_value_channels", 0) + # All hyperparameters ending in "dropout" are automatically set to 0.0 + # when not in training mode. + hparams.add_hparam("attention_dropout", 0.0) + hparams.add_hparam("relu_dropout", 0.0) + hparams.add_hparam("residual_dropout", 0.0) + hparams.add_hparam("pos", "timing") # timing, none + hparams.add_hparam("block_length", 512) + hparams.add_hparam("answer_length_prob_train", 0.5) + hparams.add_hparam("answer_length_infer", 1000) + # We cannot handle long sequence at this point, so drop them, during eval. + # This affects evaluation metrics. + # TODO(noam): find a different workaround + hparams.eval_drop_long_sequences = int(True) + return hparams + + +@registry.register_hparams +def long_answer_tiny(): + """Cheap model for validation. + + Returns: + an hparams object. 
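The length conditioning in `length_embedding` above maps the requested target length to a coarse logarithmic bucket before looking it up in a `[100, hidden_size]` table. A quick arithmetic check of that bucketing (pure Python; natural log, as with `tf.log`):

```python
# Bucket index used by length_embedding: floor(ln(length + 1)).
import math

for length in [0, 10, 100, 1000, 8192]:
  print(length, int(math.log(length + 1)))
# 0 -> 0, 10 -> 2, 100 -> 4, 1000 -> 6, 8192 -> 9
```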
+ """ + hparams = long_answer_base() + hparams.num_hidden_layers = 3 + hparams.hidden_size = 512 + hparams.filter_size = 1024 + hparams.moe_layers = "2" + hparams.moe_hidden_size = 1024 + hparams.block_length = 128 + hparams.moe_n1 = 8 + hparams.batch_size = 2048 + hparams.max_length = 2048 + return hparams + + +@registry.register_hparams +def long_answer_small(): + """Cheap model for single-gpu training. + + Returns: + an hparams object. + """ + hparams = long_answer_base() + hparams.num_hidden_layers = 4 + hparams.hidden_size = 512 + hparams.filter_size = 2048 + hparams.moe_n1 = 128 + hparams.moe_layers = "2" + hparams.moe_hidden_size = 2048 + return hparams + + +@registry.register_hparams +def long_answer_large(): + """Large model for distributed training. + + Returns: + an hparams object. + """ + hparams = long_answer_base() + hparams.num_hidden_layers = 5 + hparams.moe_layers = "3" + hparams.hidden_size = 1024 + hparams.filter_size = 4096 + hparams.moe_hidden_size = 4096 + hparams.moe_n1 = 128 + hparams.block_length = 1024 + return hparams diff --git a/tensor2tensor/models/models.py b/tensor2tensor/models/models.py index 0ca11996e..2cf639426 100644 --- a/tensor2tensor/models/models.py +++ b/tensor2tensor/models/models.py @@ -26,6 +26,7 @@ from tensor2tensor.models import attention_lm_moe from tensor2tensor.models import bluenet from tensor2tensor.models import bytenet +from tensor2tensor.models import long_answer from tensor2tensor.models import lstm from tensor2tensor.models import modalities from tensor2tensor.models import multimodel diff --git a/tensor2tensor/models/multimodel.py b/tensor2tensor/models/multimodel.py index 6f12db86d..bf06dfd65 100644 --- a/tensor2tensor/models/multimodel.py +++ b/tensor2tensor/models/multimodel.py @@ -138,7 +138,6 @@ def flatten(inputs): hparams.hidden_size, hparams.num_heads, hparams.attention_dropout, - summaries=False, name="decoder_self_attention") z = dp(common_attention.multihead_attention, y, @@ -149,7 +148,6 @@ def flatten(inputs): hparams.hidden_size, hparams.num_heads, hparams.attention_dropout, - summaries=False, name="encdec_attention") x = dp(residual_fn3, x, y, z, hparams) with tf.variable_scope("ffn"): diff --git a/tensor2tensor/models/slicenet.py b/tensor2tensor/models/slicenet.py index 43913eab1..2ad4c89d1 100644 --- a/tensor2tensor/models/slicenet.py +++ b/tensor2tensor/models/slicenet.py @@ -64,8 +64,7 @@ def attention(targets_shifted, inputs_encoded, norm_fn, hparams, bias=None): hparams.hidden_size, hparams.num_heads, hparams.attention_dropout, - name="self_attention", - summaries=False) + name="self_attention") qv = common_attention.multihead_attention( qv, inputs_encoded, @@ -75,12 +74,11 @@ def attention(targets_shifted, inputs_encoded, norm_fn, hparams, bias=None): hparams.hidden_size, hparams.num_heads, hparams.attention_dropout, - name="encdec_attention", - summaries=False) + name="encdec_attention") return tf.expand_dims(qv, 2) elif hparams.attention_type == "simple": targets_with_attention = common_layers.simple_attention( - targets_timed, inputs_encoded, bias=bias, summaries=False) + targets_timed, inputs_encoded, bias=bias) return norm_fn(targets_shifted + targets_with_attention, name="attn_norm") diff --git a/tensor2tensor/models/transformer.py b/tensor2tensor/models/transformer.py index b341d6fe0..b24f7fa50 100644 --- a/tensor2tensor/models/transformer.py +++ b/tensor2tensor/models/transformer.py @@ -60,8 +60,6 @@ def residual_fn(x, y): return common_layers.layer_norm(x + tf.nn.dropout( y, 1.0 - 
hparams.residual_dropout)) - # encoder_input = tf.squeeze(encoder_input, 2) - # decoder_input = tf.squeeze(decoder_input, 2) encoder_input = tf.nn.dropout(encoder_input, 1.0 - hparams.residual_dropout) decoder_input = tf.nn.dropout(decoder_input, 1.0 - hparams.residual_dropout) encoder_output = transformer_encoder(encoder_input, residual_fn, @@ -145,8 +143,6 @@ def transformer_encoder(encoder_input, y: a Tensors """ x = encoder_input - # Summaries don't work in multi-problem setting yet. - summaries = "problems" not in hparams.values() or len(hparams.problems) == 1 with tf.variable_scope(name): for layer in xrange(hparams.num_hidden_layers): with tf.variable_scope("layer_%d" % layer): @@ -161,7 +157,6 @@ def transformer_encoder(encoder_input, hparams.hidden_size, hparams.num_heads, hparams.attention_dropout, - summaries=summaries, name="encoder_self_attention")) x = residual_fn(x, transformer_ffn_layer(x, hparams)) return x @@ -191,8 +186,6 @@ def transformer_decoder(decoder_input, y: a Tensors """ x = decoder_input - # Summaries don't work in multi-problem setting yet. - summaries = "problems" not in hparams.values() or len(hparams.problems) == 1 with tf.variable_scope(name): for layer in xrange(hparams.num_hidden_layers): with tf.variable_scope("layer_%d" % layer): @@ -207,7 +200,6 @@ def transformer_decoder(decoder_input, hparams.hidden_size, hparams.num_heads, hparams.attention_dropout, - summaries=summaries, name="decoder_self_attention")) x = residual_fn( x, @@ -220,7 +212,6 @@ def transformer_decoder(decoder_input, hparams.hidden_size, hparams.num_heads, hparams.attention_dropout, - summaries=summaries, name="encdec_attention")) x = residual_fn(x, transformer_ffn_layer(x, hparams)) return x diff --git a/tensor2tensor/models/transformer_alternative.py b/tensor2tensor/models/transformer_alternative.py index aed074d56..280dbc713 100644 --- a/tensor2tensor/models/transformer_alternative.py +++ b/tensor2tensor/models/transformer_alternative.py @@ -140,8 +140,6 @@ def alt_transformer_decoder(decoder_input, """Alternative decoder.""" x = decoder_input - # Summaries don't work in multi-problem setting yet. - summaries = "problems" not in hparams.values() or len(hparams.problems) == 1 with tf.variable_scope(name): for layer in xrange(hparams.num_hidden_layers): with tf.variable_scope("layer_%d" % layer): @@ -155,7 +153,6 @@ def alt_transformer_decoder(decoder_input, hparams.hidden_size, hparams.num_heads, hparams.attention_dropout, - summaries=summaries, name="encdec_attention") x_ = residual_fn(x_, composite_layer(x_, mask, hparams)) diff --git a/tensor2tensor/models/transformer_test.py b/tensor2tensor/models/transformer_test.py index ca099c653..997b5d172 100644 --- a/tensor2tensor/models/transformer_test.py +++ b/tensor2tensor/models/transformer_test.py @@ -38,6 +38,7 @@ def _testTransformer(self, net): hparams = transformer.transformer_tiny() p_hparams = problem_hparams.test_problem_hparams(hparams, vocab_size, vocab_size) + hparams.problems = [p_hparams] inputs = -1 + np.random.random_integers( vocab_size, size=(batch_size, input_length, 1, 1)) targets = -1 + np.random.random_integers( diff --git a/tensor2tensor/utils/data_reader.py b/tensor2tensor/utils/data_reader.py index a3e9835ac..cb84b9e3e 100644 --- a/tensor2tensor/utils/data_reader.py +++ b/tensor2tensor/utils/data_reader.py @@ -181,10 +181,13 @@ def input_pipeline(data_file_pattern, capacity, mode): """Input pipeline, returns a dictionary of tensors from queues.""" # Read from image TFRecords if the file has "image" in its name. 
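The updated `transformer_test.py` above now attaches the problem hyperparameters to the model hparams before building the model. A hedged sketch of the same setup outside the test (the vocabulary sizes are placeholders):

```python
# Attach problem hparams to the model hparams, as the updated test does.
from tensor2tensor.data_generators import problem_hparams
from tensor2tensor.models import transformer

hparams = transformer.transformer_tiny()
p_hparams = problem_hparams.test_problem_hparams(hparams, 9, 9)
hparams.problems = [p_hparams]
```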
if data_file_pattern and "image" in data_file_pattern: + label_key = "image/class/label" + if "fsns" in data_file_pattern: + label_key = "image/unpadded_label" data_fields = { "image/encoded": tf.FixedLenFeature((), tf.string), "image/format": tf.FixedLenFeature((), tf.string), - "image/class/label": tf.VarLenFeature(tf.int64) + label_key: tf.VarLenFeature(tf.int64) } data_items_to_decoders = { "inputs": @@ -193,7 +196,7 @@ def input_pipeline(data_file_pattern, capacity, mode): format_key="image/format", channels=1 if "mnist" in data_file_pattern else 3), "targets": - tf.contrib.slim.tfexample_decoder.Tensor("image/class/label"), + tf.contrib.slim.tfexample_decoder.Tensor(label_key), } elif data_file_pattern and "audio" in data_file_pattern: data_type = tf.int64 if "timit" in data_file_pattern else tf.float32 diff --git a/tensor2tensor/utils/trainer_utils.py b/tensor2tensor/utils/trainer_utils.py old mode 100755 new mode 100644 index 66a01487c..f7d3010a9 --- a/tensor2tensor/utils/trainer_utils.py +++ b/tensor2tensor/utils/trainer_utils.py @@ -550,6 +550,13 @@ def nth_model(n): optimizer=opt, colocate_gradients_with_ops=True) + # Remove summaries that will fail to run because they are in conditionals. + # TODO(cwhipkey): Test with this code removed, later in 2017. + summaries = tf.get_collection_ref(tf.GraphKeys.SUMMARIES) + for i in range(len(summaries)-1, -1, -1): + if summaries[i].name.startswith("cond_"): + del summaries[i] + tf.logging.info("Global model_fn finished.") return run_info, total_loss, train_op @@ -1037,7 +1044,10 @@ def input_fn(): capacity *= num_datashards examples = data_reader.input_pipeline(data_file_patterns[n], capacity, mode) - drop_long_sequences = mode == tf.contrib.learn.ModeKeys.TRAIN + if mode == tf.contrib.learn.ModeKeys.TRAIN: + drop_long_sequences = True + else: + drop_long_sequences = hparams.eval_drop_long_sequences batch_size_multiplier = hparams.problems[n].batch_size_multiplier feature_map = data_reader.batch_examples( examples, diff --git a/tensor2tensor/utils/usr_dir.py b/tensor2tensor/utils/usr_dir.py index ed5623c8e..0a2d0d15c 100644 --- a/tensor2tensor/utils/usr_dir.py +++ b/tensor2tensor/utils/usr_dir.py @@ -12,16 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Utility to load code from an external directory supplied by user.""" +"""Utility to load code from an external user-supplied directory.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import importlib import os import sys -import importlib + +# Dependency imports + import tensorflow as tf def import_usr_dir(usr_dir): - """Import user module, if provided.""" + """Import module at usr_dir, if provided.""" if not usr_dir: return dir_path = os.path.expanduser(usr_dir)
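Finally, for the `usr_dir` module being cleaned up above, a minimal usage sketch; the path below is hypothetical, and the directory's `__init__.py` is expected to import whatever modules perform the `@registry.register_*` calls:

```python
# Import a user-supplied directory so its registrations become visible.
from tensor2tensor.utils import usr_dir

usr_dir.import_usr_dir("~/my_t2t_usr_dir")  # no-op when the argument is empty
```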