diff --git a/README.md b/README.md
index 27bb47947..059fbe429 100644
--- a/README.md
+++ b/README.md
@@ -153,7 +153,7 @@ python -c "from tensor2tensor.models.transformer import Transformer"
specification.
* Support for multi-GPU machines and synchronous (1 master, many workers) and
asynchronous (independent workers synchronizing through a parameter server)
- distributed training.
+ [distributed training](https://github.com/tensorflow/tensor2tensor/tree/master/docs/distributed_training.md).
* Easily swap amongst datasets and models by command-line flag with the data
generation script `t2t-datagen` and the training script `t2t-trainer`.
@@ -173,8 +173,10 @@ and many common sequence datasets are already available for generation and use.
**Problems** define training-time hyperparameters for the dataset and task,
mainly by setting input and output **modalities** (e.g. symbol, image, audio,
-label) and vocabularies, if applicable. All problems are defined in
-[`problem_hparams.py`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/problem_hparams.py).
+label) and vocabularies, if applicable. All problems are defined either in
+[`problem_hparams.py`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/problem_hparams.py)
+or registered with `@registry.register_problem` (run `t2t-datagen` to see
+the list of all available problems).
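
For illustration, registration alone is enough to make a problem addressable by name; a minimal, hypothetical sketch (the class name below is made up):

```
# A hypothetical sketch of problem registration; not part of the library.
from tensor2tensor.data_generators import problem
from tensor2tensor.utils import registry


@registry.register_problem
class MyToyProblem(problem.Problem):
  """Placeholder problem used only to illustrate registration."""


# Assuming the registry's default snake_case naming, the list now includes
# "my_toy_problem" alongside the built-in problems.
print(registry.list_problems())
```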
**Modalities**, defined in
[`modality.py`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/utils/modality.py),
abstract away the input and output data types so that **models** may deal with
@@ -211,7 +213,7 @@ inference. Users can easily switch between problems, models, and hyperparameter
sets by using the `--model`, `--problems`, and `--hparams_set` flags. Specific
hyperparameters can be overridden with the `--hparams` flag. `--schedule` and
related flags control local and distributed training/evaluation
-([distributed training documentation](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/docs/distributed_training.md)).
+([distributed training documentation](https://github.com/tensorflow/tensor2tensor/tree/master/docs/distributed_training.md)).
---
@@ -222,7 +224,7 @@ enables easily adding new ones and easily swapping amongst them by command-line
flag. You can add your own components without editing the T2T codebase by
specifying the `--t2t_usr_dir` flag in `t2t-trainer`.
-You can currently do so for models, hyperparameter sets, and modalities. Please
+You can do so for models, hyperparameter sets, modalities, and problems. Please
do submit a pull request if your component might be useful to others.
Here's an example with a new hyperparameter set:
@@ -253,9 +255,18 @@ You'll see under the registered HParams your
`transformer_my_very_own_hparams_set`, which you can directly use on the command
line with the `--hparams_set` flag.
+`t2t-datagen` also supports the `--t2t_usr_dir` flag for `Problem`
+registrations.
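
As a sketch, such a user directory's `__init__.py` might look like the following; only `transformer_my_very_own_hparams_set` comes from the example above, and the module path and override value are assumptions:

```
# Hypothetical contents of my_usr_dir/__init__.py, picked up via --t2t_usr_dir.
# The hidden_size override is purely illustrative.
from tensor2tensor.models import transformer
from tensor2tensor.utils import registry


@registry.register_hparams
def transformer_my_very_own_hparams_set():
  hparams = transformer.transformer_base()
  hparams.hidden_size = 1024
  return hparams
```

A `@registry.register_problem` class can live in the same `__init__.py` and will then be visible to both `t2t-trainer` and `t2t-datagen`.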
+
## Adding a dataset
-See the [data generators
+To add a new dataset, subclass
+[`Problem`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/problem.py)
+and register it with `@registry.register_problem`. See
+[`WMTEnDeTokens8k`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/wmt.py)
+for an example.
+
+Also see the [data generators
README](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/README.md).
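
Below is a minimal, hypothetical sketch of such a subclass; the class name, toy generator, vocabulary size, and shard count are illustrative only:

```
# A hypothetical minimal Problem; only the APIs it calls come from the library.
from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import problem
from tensor2tensor.utils import registry


def _toy_generator(nbr_cases):
  # Yield dicts mapping feature names to lists of ints; symbols start at 2 to
  # keep clear of the reserved PAD/EOS ids (see the data generators README).
  for i in range(nbr_cases):
    seq = [2 + (i % 5)] * 3
    yield {"inputs": seq, "targets": seq}


@registry.register_problem
class MyToyCopy(problem.Problem):
  """Copies tiny integer sequences (illustration only)."""

  def generate_data(self, data_dir, tmp_dir, num_shards=1):
    generator_utils.generate_dataset_and_shuffle(
        _toy_generator(10000),
        self.training_filepaths(data_dir, num_shards, shuffled=False),
        _toy_generator(1000),
        self.dev_filepaths(data_dir, 1, shuffled=False))

  def hparams(self, defaults, model_hparams):
    p = defaults
    vocab_size = 16  # illustrative; must cover every generated symbol id
    p.input_modality = {"inputs": (registry.Modalities.SYMBOL, vocab_size)}
    p.target_modality = (registry.Modalities.SYMBOL, vocab_size)
```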
---
diff --git a/tensor2tensor/docs/distributed_training.md b/docs/distributed_training.md
similarity index 100%
rename from tensor2tensor/docs/distributed_training.md
rename to docs/distributed_training.md
diff --git a/docs/index.md b/docs/index.md
new file mode 100644
index 000000000..a5eeba137
--- /dev/null
+++ b/docs/index.md
@@ -0,0 +1,23 @@
+# T2T: Tensor2Tensor Transformers
+
+Check us out on
+[GitHub](https://github.com/tensorflow/tensor2tensor).
+
+[![PyPI
+version](https://badge.fury.io/py/tensor2tensor.svg)](https://badge.fury.io/py/tensor2tensor)
+[![GitHub
+Issues](https://img.shields.io/github/issues/tensorflow/tensor2tensor.svg)](https://github.com/tensorflow/tensor2tensor/issues)
+[![Contributions
+welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CONTRIBUTING.md)
+[![Gitter](https://img.shields.io/gitter/room/nwjs/nw.js.svg)](https://gitter.im/tensor2tensor/Lobby)
+[![License](https://img.shields.io/badge/License-Apache%202.0-brightgreen.svg)](https://opensource.org/licenses/Apache-2.0)
+
+See our
+[README](https://github.com/tensorflow/tensor2tensor/tree/master/README.md)
+for documentation.
+
+More documentation and tutorials coming soon...
diff --git a/setup.py b/setup.py
index 00325cff2..d8fd19cf4 100644
--- a/setup.py
+++ b/setup.py
@@ -5,7 +5,7 @@
setup(
name='tensor2tensor',
- version='1.0.14',
+ version='1.1.0',
description='Tensor2Tensor',
author='Google Inc.',
author_email='no-reply@google.com',
diff --git a/tensor2tensor/bin/t2t-datagen b/tensor2tensor/bin/t2t-datagen
old mode 100755
new mode 100644
index cbf0a6164..b0fd816a2
--- a/tensor2tensor/bin/t2t-datagen
+++ b/tensor2tensor/bin/t2t-datagen
@@ -35,7 +35,6 @@ import tempfile
import numpy as np
-from tensor2tensor.data_generators import algorithmic
from tensor2tensor.data_generators import algorithmic_math
from tensor2tensor.data_generators import all_problems # pylint: disable=unused-import
from tensor2tensor.data_generators import audio
@@ -60,52 +59,22 @@ flags.DEFINE_string("tmp_dir", "/tmp/t2t_datagen",
"Temporary storage directory.")
flags.DEFINE_string("problem", "",
"The name of the problem to generate data for.")
+flags.DEFINE_string("exclude_problems", "",
+ "Comma-separates list of problems to exclude.")
flags.DEFINE_integer("num_shards", 10, "How many shards to use.")
flags.DEFINE_integer("max_cases", 0,
"Maximum number of cases to generate (unbounded if 0).")
flags.DEFINE_integer("random_seed", 429459, "Random seed to use.")
-
flags.DEFINE_string("t2t_usr_dir", "",
"Path to a Python module that will be imported. The "
"__init__.py file should include the necessary imports. "
"The imported files should contain registrations, "
- "e.g. @registry.register_model calls, that will then be "
- "available to the t2t-datagen.")
+ "e.g. @registry.register_problem calls, that will then be "
+ "available to t2t-datagen.")
# Mapping from problems that we can generate data for to their generators.
# pylint: disable=g-long-lambda
_SUPPORTED_PROBLEM_GENERATORS = {
- "algorithmic_shift_decimal40": (
- lambda: algorithmic.shift_generator(20, 10, 40, 100000),
- lambda: algorithmic.shift_generator(20, 10, 80, 10000)),
- "algorithmic_reverse_binary40": (
- lambda: algorithmic.reverse_generator(2, 40, 100000),
- lambda: algorithmic.reverse_generator(2, 400, 10000)),
- "algorithmic_reverse_decimal40": (
- lambda: algorithmic.reverse_generator(10, 40, 100000),
- lambda: algorithmic.reverse_generator(10, 400, 10000)),
- "algorithmic_addition_binary40": (
- lambda: algorithmic.addition_generator(2, 40, 100000),
- lambda: algorithmic.addition_generator(2, 400, 10000)),
- "algorithmic_addition_decimal40": (
- lambda: algorithmic.addition_generator(10, 40, 100000),
- lambda: algorithmic.addition_generator(10, 400, 10000)),
- "algorithmic_multiplication_binary40": (
- lambda: algorithmic.multiplication_generator(2, 40, 100000),
- lambda: algorithmic.multiplication_generator(2, 400, 10000)),
- "algorithmic_multiplication_decimal40": (
- lambda: algorithmic.multiplication_generator(10, 40, 100000),
- lambda: algorithmic.multiplication_generator(10, 400, 10000)),
- "algorithmic_reverse_nlplike_decimal8K": (
- lambda: algorithmic.reverse_generator_nlplike(8000, 70, 100000,
- 10, 1.300),
- lambda: algorithmic.reverse_generator_nlplike(8000, 70, 10000,
- 10, 1.300)),
- "algorithmic_reverse_nlplike_decimal32K": (
- lambda: algorithmic.reverse_generator_nlplike(32000, 70, 100000,
- 10, 1.050),
- lambda: algorithmic.reverse_generator_nlplike(32000, 70, 10000,
- 10, 1.050)),
"algorithmic_algebra_inverse": (
lambda: algorithmic_math.algebra_inverse(26, 0, 2, 100000),
lambda: algorithmic_math.algebra_inverse(26, 3, 3, 10000)),
@@ -125,29 +94,9 @@ _SUPPORTED_PROBLEM_GENERATORS = {
2**14, 2**9),
lambda: wsj_parsing.parsing_token_generator(FLAGS.tmp_dir, False,
2**14, 2**9)),
- "wmt_enfr_characters": (
- lambda: wmt.enfr_character_generator(FLAGS.tmp_dir, True),
- lambda: wmt.enfr_character_generator(FLAGS.tmp_dir, False)),
- "wmt_enfr_tokens_8k": (
- lambda: wmt.enfr_wordpiece_token_generator(FLAGS.tmp_dir, True, 2**13),
- lambda: wmt.enfr_wordpiece_token_generator(FLAGS.tmp_dir, False, 2**13)
- ),
- "wmt_enfr_tokens_32k": (
- lambda: wmt.enfr_wordpiece_token_generator(FLAGS.tmp_dir, True, 2**15),
- lambda: wmt.enfr_wordpiece_token_generator(FLAGS.tmp_dir, False, 2**15)
- ),
- "wmt_ende_characters": (
- lambda: wmt.ende_character_generator(FLAGS.tmp_dir, True),
- lambda: wmt.ende_character_generator(FLAGS.tmp_dir, False)),
"wmt_ende_bpe32k": (
lambda: wmt.ende_bpe_token_generator(FLAGS.tmp_dir, True),
lambda: wmt.ende_bpe_token_generator(FLAGS.tmp_dir, False)),
- "wmt_zhen_tokens_32k": (
- lambda: wmt.zhen_wordpiece_token_generator(FLAGS.tmp_dir, True,
- 2**15, 2**15),
- lambda: wmt.zhen_wordpiece_token_generator(FLAGS.tmp_dir, False,
- 2**15, 2**15)
- ),
"lm1b_32k": (
lambda: lm1b.generator(FLAGS.tmp_dir, True),
lambda: lm1b.generator(FLAGS.tmp_dir, False)
@@ -286,6 +235,9 @@ def main(_):
# Calculate the list of problems to generate.
problems = sorted(
list(_SUPPORTED_PROBLEM_GENERATORS) + registry.list_problems())
+ for exclude in FLAGS.exclude_problems.split(","):
+ if exclude:
+ problems = [p for p in problems if exclude not in p]
if FLAGS.problem and FLAGS.problem[-1] == "*":
problems = [p for p in problems if p.startswith(FLAGS.problem[:-1])]
elif FLAGS.problem:
diff --git a/tensor2tensor/bin/t2t-trainer b/tensor2tensor/bin/t2t-trainer
old mode 100755
new mode 100644
index 6b3f4de71..8a801e70e
--- a/tensor2tensor/bin/t2t-trainer
+++ b/tensor2tensor/bin/t2t-trainer
@@ -29,14 +29,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import importlib
-import os
-import sys
-
# Dependency imports
from tensor2tensor.utils import trainer_utils as utils
from tensor2tensor.utils import usr_dir
+
import tensorflow as tf
flags = tf.flags
@@ -49,6 +46,7 @@ flags.DEFINE_string("t2t_usr_dir", "",
"e.g. @registry.register_model calls, that will then be "
"available to the t2t-trainer.")
+
def main(_):
tf.logging.set_verbosity(tf.logging.INFO)
usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
diff --git a/tensor2tensor/data_generators/README.md b/tensor2tensor/data_generators/README.md
index f8495c38f..310bc39df 100644
--- a/tensor2tensor/data_generators/README.md
+++ b/tensor2tensor/data_generators/README.md
@@ -1,7 +1,7 @@
-# Data generators for T2T models.
+# T2T Problems.
-This directory contains data generators for a number of problems. We use a
-naming scheme for the problems, they have names of the form
+This directory contains `Problem` specifications for a number of problems. We
+use a naming scheme for the problems; they have names of the form
`[task-family]_[task]_[specifics]`. Data for all currently supported problems
can be generated by calling the main generator binary (`t2t-datagen`). For
example:
@@ -20,53 +20,51 @@ All tasks produce TFRecord files of `tensorflow.Example` protocol buffers.
## Adding a new problem
-1. Implement and register a Python generator for the dataset
-1. Add a problem specification to `problem_hparams.py` specifying input and
- output modalities
+To add a new problem, subclass
+[`Problem`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/problem.py)
+and register it with `@registry.register_problem`. See
+[`WMTEnDeTokens8k`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/wmt.py)
+for an example.
-To add a new problem, you first need to create python generators for training
-and development data for the problem. The python generators should yield
-dictionaries with string keys and values being lists of {int, float, str}.
-Here is a very simple generator for a data-set where inputs are lists of 1s with
-length upto 100 and targets are lists of length 1 with an integer denoting the
-length of the input list.
+`Problem`s support data generation, training, and decoding.
+
+Data generation is handled by `Problem.generate_data`, which should produce two
+datasets, training and dev, named according to
+`Problem.training_filepaths` and `Problem.dev_filepaths`.
+`Problem.generate_data` should also produce any other files that may be required
+for training/decoding, e.g. a vocabulary file.
+
+A particularly easy way to implement `Problem.generate_data` for your dataset is
+to create two Python generators, one for the training data and another for the
+dev data, and pass them to `generator_utils.generate_dataset_and_shuffle`. See
+[`WMTEnDeTokens8k.generate_data`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/wmt.py)
+for an example of usage.
+
+The generators should yield dictionaries with string keys and values being lists
+of {int, float, str}. Here is a very simple generator for a dataset where
+inputs are lists of 2s with length up to 100 and targets are lists of length 1
+with an integer denoting the length of the input list.
```
def length_generator(nbr_cases):
for _ in xrange(nbr_cases):
length = np.random.randint(100) + 1
- yield {"inputs": [1] * length, "targets": [length]}
+ yield {"inputs": [2] * length, "targets": [length]}
```
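
Wiring a generator like `length_generator` above into a registered `Problem` is then mostly boilerplate around `generator_utils.generate_dataset_and_shuffle`; the class below is a hypothetical sketch, not part of the library:

```
# Hypothetical sketch; assumes length_generator above lives in this module.
from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import problem
from tensor2tensor.utils import registry


@registry.register_problem
class AlgorithmicLengthUpto100(problem.Problem):
  """Predicts the length of a sequence of 2s (illustration only)."""

  def generate_data(self, data_dir, tmp_dir, num_shards=10):
    generator_utils.generate_dataset_and_shuffle(
        length_generator(10000),
        self.training_filepaths(data_dir, num_shards, shuffled=False),
        length_generator(1000),
        self.dev_filepaths(data_dir, 1, shuffled=False))
```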
-Note that our data reader uses 0 for padding, so it is a good idea to never
-generate 0s, except if all your examples have the same size (in which case
-they'll never be padded anyway) or if you're doing padding on your own (in which
-case please use 0s for padding). When adding the python generator function,
-please also add unit tests to check if the code runs.
+Note that our data reader uses 0 for padding and other parts of the code assume
+end-of-string (EOS) is 1, so it is a good idea to never generate 0s or 1s,
+except if all your examples have the same size (in which case they'll never be
+padded anyway) or if you're doing padding on your own (in which case please use
+0s for padding). When adding the Python generator function, please also add unit
+tests to check if the code runs.
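
Alternatively, a generator may emit raw symbols starting at 0 and have them shifted past the reserved ids (with EOS appended) before writing, which is what the algorithmic problems in this change do; a condensed sketch of that wrapper (the function name is mine):

```
# Condensed sketch of the reserved-id shift used by the algorithmic problems:
# raw symbols may include 0, so shift them past PAD/EOS and append EOS_ID.
from tensor2tensor.data_generators import text_encoder


def shift_and_append_eos(generator):
  for case in generator:
    yield {
        feature: [i + text_encoder.NUM_RESERVED_TOKENS for i in values] +
                 [text_encoder.EOS_ID]
        for feature, values in case.items()
    }
```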
The generator can do arbitrary setup before beginning to yield examples - for
example, downloading data, generating vocabulary files, etc.
Some examples:
-* [Algorithmic generators](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/algorithmic.py)
+* [Algorithmic problems](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/algorithmic.py)
and their [unit tests](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/algorithmic_test.py)
-* [WMT generators](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/wmt.py)
+* [WMT problems](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/wmt.py)
and their [unit tests](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/wmt_test.py)
-
-When your python generator is ready and tested, add it to the
-`_SUPPORTED_PROBLEM_GENERATORS` dictionary in the
-[data
-generator](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/bin/t2t-datagen).
-The keys are problem names, and the values are pairs of (training-set-generator
-function, dev-set-generator function). For the generator above, one could add
-the following lines:
-
-```
- "algorithmic_length_upto100":
- (lambda: algorithmic.length_generator(10000),
- lambda: algorithmic.length_generator(1000)),
-```
-
-Note the lambdas above: we don't want to call the generators too early.
-
diff --git a/tensor2tensor/data_generators/algorithmic.py b/tensor2tensor/data_generators/algorithmic.py
index 7e522bfa0..2169e1910 100644
--- a/tensor2tensor/data_generators/algorithmic.py
+++ b/tensor2tensor/data_generators/algorithmic.py
@@ -25,48 +25,86 @@
from tensor2tensor.data_generators import generator_utils as utils
from tensor2tensor.data_generators import problem
+from tensor2tensor.data_generators import text_encoder
from tensor2tensor.utils import registry
-@registry.register_problem
-class AlgorithmicIdentityBinary40(problem.Problem):
- """Problem spec for algorithmic binary identity task."""
+class AlgorithmicProblem(problem.Problem):
+ """Base class for algorithmic problems."""
@property
def num_symbols(self):
- return 2
+ raise NotImplementedError()
+
+ @property
+ def train_generator(self):
+ """Generator; takes 3 args: nbr_symbols, max_length, nbr_cases."""
+ raise NotImplementedError()
+
+ @property
+ def dev_generator(self):
+ return self.train_generator
+
+ @property
+ def train_length(self):
+ return 40
+
+ @property
+ def dev_length(self):
+ return 400
+
+ @property
+ def train_size(self):
+ return 100000
+
+ @property
+ def dev_size(self):
+ return 10000
+
+ @property
+ def num_shards(self):
+ return 10
+
+ def generate_data(self, data_dir, _, num_shards=None):
+ if num_shards is None:
+ num_shards = self.num_shards
+
+ def generator_eos(generator):
+ """Shift by NUM_RESERVED_IDS and append EOS token."""
+ for case in generator:
+ new_case = {}
+ for feature in case:
+ new_case[feature] = [i + text_encoder.NUM_RESERVED_TOKENS
+ for i in case[feature]] + [text_encoder.EOS_ID]
+ yield new_case
+
+ train_generator_eos = lambda: generator_eos( # pylint: disable=g-long-lambda
+ self.train_generator(self.num_symbols,
+ self.train_length, self.train_size))
+ dev_generator_eos = lambda: generator_eos( # pylint: disable=g-long-lambda
+ self.dev_generator(self.num_symbols, self.dev_length, self.dev_size))
- def generate_data(self, data_dir, _, num_shards=100):
utils.generate_dataset_and_shuffle(
- identity_generator(self.num_symbols, 40, 100000),
+ train_generator_eos(),
self.training_filepaths(data_dir, num_shards, shuffled=True),
- identity_generator(self.num_symbols, 400, 10000),
+ dev_generator_eos(),
self.dev_filepaths(data_dir, 1, shuffled=True),
shuffle=False)
def hparams(self, defaults, unused_model_hparams):
p = defaults
- vocab_size = self.num_symbols + self._encoders["inputs"].num_reserved_ids
+ vocab_size = self.num_symbols + text_encoder.NUM_RESERVED_TOKENS
p.input_modality = {"inputs": (registry.Modalities.SYMBOL, vocab_size)}
p.target_modality = (registry.Modalities.SYMBOL, vocab_size)
p.input_space_id = problem.SpaceID.DIGIT_0
p.target_space_id = problem.SpaceID.DIGIT_1
-@registry.register_problem
-class AlgorithmicIdentityDecimal40(AlgorithmicIdentityBinary40):
- """Problem spec for algorithmic decimal identity task."""
-
- @property
- def num_symbols(self):
- return 10
-
-
def identity_generator(nbr_symbols, max_length, nbr_cases):
"""Generator for the identity (copy) task on sequences of symbols.
The length of the sequence is drawn uniformly at random from [1, max_length]
- and then symbols are drawn uniformly at random from [2, nbr_symbols + 2) until
+ and then symbols are drawn uniformly at random from [0, nbr_symbols) until
nbr_cases sequences have been produced.
Args:
@@ -80,15 +118,37 @@ def identity_generator(nbr_symbols, max_length, nbr_cases):
"""
for _ in xrange(nbr_cases):
l = np.random.randint(max_length) + 1
- inputs = [np.random.randint(nbr_symbols) + 2 for _ in xrange(l)]
- yield {"inputs": inputs, "targets": inputs + [1]} # [1] for EOS
+ inputs = [np.random.randint(nbr_symbols) for _ in xrange(l)]
+ yield {"inputs": inputs, "targets": inputs}
+
+
+@registry.register_problem
+class AlgorithmicIdentityBinary40(AlgorithmicProblem):
+ """Problem spec for algorithmic binary identity task."""
+
+ @property
+ def num_symbols(self):
+ return 2
+
+ @property
+ def train_generator(self):
+ return identity_generator
+
+
+@registry.register_problem
+class AlgorithmicIdentityDecimal40(AlgorithmicIdentityBinary40):
+ """Problem spec for algorithmic decimal identity task."""
+
+ @property
+ def num_symbols(self):
+ return 10
def shift_generator(nbr_symbols, shift, max_length, nbr_cases):
"""Generator for the shift task on sequences of symbols.
The length of the sequence is drawn uniformly at random from [1, max_length]
- and then symbols are drawn uniformly at random from [2, nbr_symbols - shift]
+ and then symbols are drawn uniformly at random from [0, nbr_symbols - shift]
until nbr_cases sequences have been produced (output[i] = input[i] + shift).
Args:
@@ -103,18 +163,35 @@ def shift_generator(nbr_symbols, shift, max_length, nbr_cases):
"""
for _ in xrange(nbr_cases):
l = np.random.randint(max_length) + 1
- inputs = [np.random.randint(nbr_symbols - shift) + 2 for _ in xrange(l)]
+ inputs = [np.random.randint(nbr_symbols - shift) for _ in xrange(l)]
yield {
"inputs": inputs,
- "targets": [i + shift for i in inputs] + [1]
- } # [1] for EOS
+ "targets": [i + shift for i in inputs]
+ }
+
+
+@registry.register_problem
+class AlgorithmicShiftDecimal40(AlgorithmicProblem):
+ """Problem spec for algorithmic decimal shift task."""
+
+ @property
+ def num_symbols(self):
+ return 20
+
+ @property
+ def train_generator(self):
+ return lambda nbr_sym, l, size: shift_generator(nbr_sym, 10, l, size)
+
+ @property
+ def dev_length(self):
+ return 80
def reverse_generator(nbr_symbols, max_length, nbr_cases):
"""Generator for the reversing task on sequences of symbols.
The length of the sequence is drawn uniformly at random from [1, max_length]
- and then symbols are drawn uniformly at random from [2, nbr_symbols] until
+ and then symbols are drawn uniformly at random from [0, nbr_symbols) until
nbr_cases sequences have been produced.
Args:
@@ -128,11 +205,33 @@ def reverse_generator(nbr_symbols, max_length, nbr_cases):
"""
for _ in xrange(nbr_cases):
l = np.random.randint(max_length) + 1
- inputs = [np.random.randint(nbr_symbols) + 2 for _ in xrange(l)]
+ inputs = [np.random.randint(nbr_symbols) for _ in xrange(l)]
yield {
"inputs": inputs,
- "targets": list(reversed(inputs)) + [1]
- } # [1] for EOS
+ "targets": list(reversed(inputs))
+ }
+
+
+@registry.register_problem
+class AlgorithmicReverseBinary40(AlgorithmicProblem):
+ """Problem spec for algorithmic binary reversing task."""
+
+ @property
+ def num_symbols(self):
+ return 2
+
+ @property
+ def train_generator(self):
+ return reverse_generator
+
+
+@registry.register_problem
+class AlgorithmicReverseDecimal40(AlgorithmicReverseBinary40):
+ """Problem spec for algorithmic decimal reversing task."""
+
+ @property
+ def num_symbols(self):
+ return 10
def zipf_distribution(nbr_symbols, alpha):
@@ -166,11 +265,8 @@ def zipf_random_sample(distr_map, sample_len):
"""
u = np.random.random(sample_len)
# Random produces values in range [0.0,1.0); even if it is almost
- # improbable(but possible) that it can generate a clear 0.000..0,
- # we have made a sanity check to overcome this issue. On the other hand,
- # t+1 is enough from saving us to generate PAD(0) and EOS(1) which are
- # reservated symbols.
- return [t + 1 if t > 0 else t + 2 for t in np.searchsorted(distr_map, u)]
+ # improbable (but possible) that it generates exactly 0.000..0.
+ return list(np.searchsorted(distr_map, u))
def reverse_generator_nlplike(nbr_symbols,
@@ -182,7 +278,7 @@ def reverse_generator_nlplike(nbr_symbols,
The length of the sequence is drawn from a Gaussian(Normal) distribution
at random from [1, max_length] and with std deviation of 1%,
- then symbols are drawn from Zipf's law at random from [2, nbr_symbols] until
+ then symbols are drawn from Zipf's law at random from [0, nbr_symbols) until
nbr_cases sequences have been produced.
Args:
@@ -206,8 +302,44 @@ def reverse_generator_nlplike(nbr_symbols,
inputs = zipf_random_sample(distr_map, l)
yield {
"inputs": inputs,
- "targets": list(reversed(inputs)) + [1]
- } # [1] for EOS
+ "targets": list(reversed(inputs))
+ }
+
+
+@registry.register_problem
+class AlgorithmicReverseNlplike8K(AlgorithmicProblem):
+ """Problem spec for algorithmic nlp-like reversing task."""
+
+ @property
+ def num_symbols(self):
+ return 8000
+
+ @property
+ def train_generator(self):
+ return lambda nbr_sym, length, size: reverse_generator_nlplike( # pylint: disable=g-long-lambda
+ nbr_sym, length, size, 10, 1.300)
+
+ @property
+ def train_length(self):
+ return 70
+
+ @property
+ def dev_length(self):
+ return 70
+
+
+@registry.register_problem
+class AlgorithmicReverseNlplike32K(AlgorithmicReverseNlplike8K):
+ """Problem spec for algorithmic nlp-like reversing task, 32K vocab."""
+
+ @property
+ def num_symbols(self):
+ return 32000
+
+ @property
+ def train_generator(self):
+ return lambda nbr_sym, length, size: reverse_generator_nlplike( # pylint: disable=g-long-lambda
+ nbr_sym, length, size, 10, 1.050)
def lower_endian_to_number(l, base):
@@ -235,7 +367,7 @@ def addition_generator(base, max_length, nbr_cases):
The length of each number is drawn uniformly at random from [1, max_length/2]
and then digits are drawn uniformly at random. The numbers are added and
- separated by [base+1] in the input. Stops at nbr_cases.
+ separated by [base] in the input. Stops at nbr_cases.
Args:
base: in which base are the numbers.
@@ -257,10 +389,31 @@ def addition_generator(base, max_length, nbr_cases):
n1 = random_number_lower_endian(l1, base)
n2 = random_number_lower_endian(l2, base)
result = lower_endian_to_number(n1, base) + lower_endian_to_number(n2, base)
- # We shift digits by 1 on input and output to leave 0 for padding.
- inputs = [i + 2 for i in n1] + [base + 2] + [i + 2 for i in n2]
- targets = [i + 2 for i in number_to_lower_endian(result, base)]
- yield {"inputs": inputs, "targets": targets + [1]} # [1] for EOS
+ inputs = n1 + [base] + n2
+ targets = number_to_lower_endian(result, base)
+ yield {"inputs": inputs, "targets": targets}
+
+
+@registry.register_problem
+class AlgorithmicAdditionBinary40(AlgorithmicProblem):
+ """Problem spec for algorithmic binary addition task."""
+
+ @property
+ def num_symbols(self):
+ return 2
+
+ @property
+ def train_generator(self):
+ return addition_generator
+
+
+@registry.register_problem
+class AlgorithmicAdditionDecimal40(AlgorithmicAdditionBinary40):
+ """Problem spec for algorithmic decimal addition task."""
+
+ @property
+ def num_symbols(self):
+ return 10
def multiplication_generator(base, max_length, nbr_cases):
@@ -268,7 +421,7 @@ def multiplication_generator(base, max_length, nbr_cases):
The length of each number is drawn uniformly at random from [1, max_length/2]
and then digits are drawn uniformly at random. The numbers are multiplied
- and separated by [base+1] in the input. Stops at nbr_cases.
+ and separated by [base] in the input. Stops at nbr_cases.
Args:
base: in which base are the numbers.
@@ -291,7 +444,28 @@ def multiplication_generator(base, max_length, nbr_cases):
n1 = random_number_lower_endian(l1, base)
n2 = random_number_lower_endian(l2, base)
result = lower_endian_to_number(n1, base) * lower_endian_to_number(n2, base)
- # We shift digits by 1 on input and output to leave 0 for padding.
- inputs = [i + 2 for i in n1] + [base + 2] + [i + 2 for i in n2]
- targets = [i + 2 for i in number_to_lower_endian(result, base)]
- yield {"inputs": inputs, "targets": targets + [1]} # [1] for EOS
+ inputs = n1 + [base] + n2
+ targets = number_to_lower_endian(result, base)
+ yield {"inputs": inputs, "targets": targets}
+
+
+@registry.register_problem
+class AlgorithmicMultiplicationBinary40(AlgorithmicProblem):
+ """Problem spec for algorithmic binary multiplication task."""
+
+ @property
+ def num_symbols(self):
+ return 2
+
+ @property
+ def train_generator(self):
+ return multiplication_generator
+
+
+@registry.register_problem
+class AlgorithmicMultiplicationDecimal40(AlgorithmicMultiplicationBinary40):
+ """Problem spec for algorithmic decimal multiplication task."""
+
+ @property
+ def num_symbols(self):
+ return 10
diff --git a/tensor2tensor/data_generators/algorithmic_test.py b/tensor2tensor/data_generators/algorithmic_test.py
index 9961e6173..fb8ff6719 100644
--- a/tensor2tensor/data_generators/algorithmic_test.py
+++ b/tensor2tensor/data_generators/algorithmic_test.py
@@ -31,14 +31,14 @@ def testIdentityGenerator(self):
counter = 0
for d in algorithmic.identity_generator(3, 8, 10):
counter += 1
- self.assertEqual(d["inputs"] + [1], d["targets"])
+ self.assertEqual(d["inputs"], d["targets"])
self.assertEqual(counter, 10)
def testReverseGenerator(self):
counter = 0
for d in algorithmic.reverse_generator(3, 8, 10):
counter += 1
- self.assertEqual(list(reversed(d["inputs"])) + [1], d["targets"])
+ self.assertEqual(list(reversed(d["inputs"])), d["targets"])
self.assertEqual(counter, 10)
def testZipfDistribution(self):
@@ -53,7 +53,7 @@ def testReverseGeneratorNlpLike(self):
counter = 0
for d in algorithmic.reverse_generator_nlplike(3, 8, 10):
counter += 1
- self.assertEqual(list(reversed(d["inputs"])) + [1], d["targets"])
+ self.assertEqual(list(reversed(d["inputs"])), d["targets"])
self.assertEqual(counter, 10)
def testLowerEndianToNumber(self):
@@ -78,20 +78,20 @@ def testAdditionGenerator(self):
counter = 0
for d in algorithmic.addition_generator(4, 8, 10):
counter += 1
- self.assertEqual(d["inputs"].count(6), 1)
- self.assertEqual(d["inputs"].count(0), 0)
- self.assertEqual(d["targets"].count(6), 0)
- self.assertEqual(d["targets"].count(0), 0)
+ self.assertEqual(d["inputs"].count(4), 1)
+ self.assertEqual(d["inputs"].count(5), 0)
+ self.assertEqual(d["targets"].count(4), 0)
+ self.assertEqual(d["targets"].count(5), 0)
self.assertEqual(counter, 10)
def testMultiplicationGenerator(self):
counter = 0
for d in algorithmic.multiplication_generator(4, 8, 10):
counter += 1
- self.assertEqual(d["inputs"].count(6), 1)
- self.assertEqual(d["inputs"].count(0), 0)
- self.assertEqual(d["targets"].count(6), 0)
- self.assertEqual(d["targets"].count(0), 0)
+ self.assertEqual(d["inputs"].count(4), 1)
+ self.assertEqual(d["inputs"].count(5), 0)
+ self.assertEqual(d["targets"].count(4), 0)
+ self.assertEqual(d["targets"].count(5), 0)
self.assertEqual(counter, 10)
diff --git a/tensor2tensor/data_generators/generator_utils.py b/tensor2tensor/data_generators/generator_utils.py
old mode 100755
new mode 100644
index 20f3959d8..b34a87138
--- a/tensor2tensor/data_generators/generator_utils.py
+++ b/tensor2tensor/data_generators/generator_utils.py
@@ -244,7 +244,6 @@ def gunzip_file(gz_path, new_path):
"http://www.statmt.org/wmt13/training-parallel-un.tgz",
["un/undoc.2000.fr-en.en", "un/undoc.2000.fr-en.fr"]
],
- # Macedonian-English
[
"https://github.com/stefan-it/nmt-mk-en/raw/master/data/setimes.mk-en.train.tgz", # pylint: disable=line-too-long
["train.mk", "train.en"]
diff --git a/tensor2tensor/data_generators/problem.py b/tensor2tensor/data_generators/problem.py
index 0eb1987fa..1182ed7d1 100644
--- a/tensor2tensor/data_generators/problem.py
+++ b/tensor2tensor/data_generators/problem.py
@@ -92,6 +92,7 @@ class Problem(object):
get sharded filenames. If shuffled=False, the filenames will contain
an "unshuffled" suffix; you should then shuffle the data
shard-by-shard with generator_utils.shuffle_dataset.
+ - Optionally allows specifying the number of shards (can be omitted).
- Subclasses must override
* dataset_filename()
- Base filename for problem.
@@ -113,7 +114,7 @@ class Problem(object):
# BEGIN SUBCLASS INTERFACE
# ============================================================================
- def generate_data(self, data_dir, tmp_dir, num_shards=100):
+ def generate_data(self, data_dir, tmp_dir, num_shards=None):
raise NotImplementedError()
def hparams(self, defaults, model_hparams):
diff --git a/tensor2tensor/data_generators/problem_hparams.py b/tensor2tensor/data_generators/problem_hparams.py
index 70b9dada8..5922ab59a 100644
--- a/tensor2tensor/data_generators/problem_hparams.py
+++ b/tensor2tensor/data_generators/problem_hparams.py
@@ -217,20 +217,6 @@ def test_problem_hparams(unused_model_hparams, input_vocab_size,
return p
-def algorithmic(vocab_size, unused_model_hparams):
- """Default parameters for algorithmic tasks."""
- p = default_problem_hparams()
- p.input_modality = {"inputs": (registry.Modalities.SYMBOL, vocab_size)}
- p.target_modality = (registry.Modalities.SYMBOL, vocab_size)
- p.vocabulary = {
- "inputs": text_encoder.TextEncoder(),
- "targets": text_encoder.TextEncoder(),
- }
- p.input_space_id = 10
- p.target_space_id = 11
- return p
-
-
def audio_timit_characters(unused_model_hparams):
"""English audio transcription benchmark."""
p = default_problem_hparams()
@@ -351,10 +337,9 @@ def wiki_32k(model_hparams):
p = default_problem_hparams()
encoder = text_encoder.SubwordTextEncoder(
os.path.join(model_hparams.data_dir, "wiki_32k.subword_text_encoder"))
- p.input_modality = {
- "inputs": (registry.Modalities.SYMBOL, encoder.vocab_size)
- }
- p.target_modality = (registry.Modalities.SYMBOL, encoder.vocab_size)
+ modality_spec = (registry.Modalities.SYMBOL, encoder.vocab_size)
+ p.input_modality = {"inputs": modality_spec}
+ p.target_modality = modality_spec
p.vocabulary = {
"inputs": encoder,
"targets": encoder
@@ -378,50 +363,6 @@ def lmptb_10k(model_hparams):
return p
-def wmt_enfr_characters(unused_model_hparams):
- """English to French translation benchmark."""
- p = default_problem_hparams()
- p.input_modality = {"inputs": (registry.Modalities.SYMBOL, 256)}
- p.target_modality = (registry.Modalities.SYMBOL, 256)
- p.vocabulary = {
- "inputs": text_encoder.ByteTextEncoder(),
- "targets": text_encoder.ByteTextEncoder(),
- }
- p.loss_multiplier = 2.0
- p.input_space_id = 2
- p.target_space_id = 5
- return p
-
-
-def wmt_enfr_tokens(model_hparams, wrong_vocab_size):
- """English to French translation benchmark.
-
- Args:
- model_hparams: a tf.contrib.training.HParams
- wrong_vocab_size: a number used in the filename indicating the approximate
- vocabulary size. This is not to be confused with the actual vocabulary
- size.
- Returns:
- a tf.contrib.training.HParams
- """
- p = default_problem_hparams()
- # This vocab file must be present within the data directory.
- vocab_filename = os.path.join(model_hparams.data_dir,
- "tokens.vocab.%d" % wrong_vocab_size)
- subtokenizer = text_encoder.SubwordTextEncoder(vocab_filename)
- p.input_modality = {
- "inputs": (registry.Modalities.SYMBOL, subtokenizer.vocab_size)
- }
- p.target_modality = (registry.Modalities.SYMBOL, subtokenizer.vocab_size)
- p.vocabulary = {
- "inputs": subtokenizer,
- "targets": subtokenizer,
- }
- p.input_space_id = 3
- p.target_space_id = 6
- return p
-
-
def wmt_ende_bpe32k(model_hparams):
"""English to German translation benchmark."""
p = default_problem_hparams()
@@ -441,47 +382,6 @@ def wmt_ende_bpe32k(model_hparams):
return p
-def wmt_ende_characters(unused_model_hparams):
- """English to German translation benchmark."""
- p = default_problem_hparams()
- p.input_modality = {"inputs": (registry.Modalities.SYMBOL, 256)}
- p.target_modality = (registry.Modalities.SYMBOL, 256)
- p.vocabulary = {
- "inputs": text_encoder.ByteTextEncoder(),
- "targets": text_encoder.ByteTextEncoder(),
- }
- p.loss_multiplier = 2.0
- p.input_space_id = 2
- p.target_space_id = 7
- return p
-
-
-def wmt_zhen_tokens(model_hparams, wrong_vocab_size):
- """Chinese to English translation benchmark."""
- p = default_problem_hparams()
- # This vocab file must be present within the data directory.
- if model_hparams.shared_embedding_and_softmax_weights == 1:
- model_hparams.shared_embedding_and_softmax_weights = 0
- source_vocab_filename = os.path.join(model_hparams.data_dir,
- "tokens.vocab.zh.%d" % wrong_vocab_size)
- target_vocab_filename = os.path.join(model_hparams.data_dir,
- "tokens.vocab.en.%d" % wrong_vocab_size)
- source_token = text_encoder.SubwordTextEncoder(source_vocab_filename)
- target_token = text_encoder.SubwordTextEncoder(target_vocab_filename)
- p.input_modality = {
- "inputs": (registry.Modalities.SYMBOL, source_token.vocab_size)
- }
- p.target_modality = (registry.Modalities.SYMBOL, target_token.vocab_size)
- p.vocabulary = {
- "inputs": source_token,
- "targets": target_token,
- }
- p.loss_multiplier = 1.4
- p.input_space_id = 16
- p.target_space_id = 4
- return p
-
-
def wmt_parsing_characters(model_hparams):
"""English to parse tree translation benchmark."""
del model_hparams # Unused.
@@ -699,15 +599,6 @@ def img2img_imagenet(unused_model_hparams):
# Dictionary of named hyperparameter settings for various problems.
# This is only accessed through the problem_hparams function below.
PROBLEM_HPARAMS_MAP = {
- "algorithmic_addition_binary40": lambda p: algorithmic(4, p),
- "algorithmic_addition_decimal40": lambda p: algorithmic(12, p),
- "algorithmic_multiplication_binary40": lambda p: algorithmic(4, p),
- "algorithmic_multiplication_decimal40": lambda p: algorithmic(12, p),
- "algorithmic_reverse_binary40": lambda p: algorithmic(4, p),
- "algorithmic_reverse_decimal40": lambda p: algorithmic(12, p),
- "algorithmic_reverse_nlplike_decimal8K": lambda p: algorithmic(8002, p),
- "algorithmic_reverse_nlplike_decimal32K": lambda p: algorithmic(32002, p),
- "algorithmic_shift_decimal40": lambda p: algorithmic(22, p),
"audio_timit_characters_tune": audio_timit_characters,
"audio_timit_characters_test": audio_timit_characters,
"audio_timit_tokens_8k_tune": lambda p: audio_timit_tokens(p, 2**13),
@@ -724,15 +615,7 @@ def img2img_imagenet(unused_model_hparams):
"wmt_parsing_tokens_8k": lambda p: wmt_parsing_tokens(p, 2**13),
"wsj_parsing_tokens_16k": lambda p: wsj_parsing_tokens( # pylint: disable=g-long-lambda
p, "wsj", 2**14, 2**9),
- "wmt_enfr_characters": wmt_enfr_characters,
- "wmt_enfr_tokens_8k": lambda p: wmt_enfr_tokens(p, 2**13),
- "wmt_enfr_tokens_32k": lambda p: wmt_enfr_tokens(p, 2**15),
- "wmt_enfr_tokens_32k_shuffled": lambda p: wmt_enfr_tokens(p, 2**15),
- "wmt_enfr_tokens_32k_combined": lambda p: wmt_enfr_tokens(p, 2**15),
- "wmt_enfr_tokens_128k": lambda p: wmt_enfr_tokens(p, 2**17),
- "wmt_ende_characters": wmt_ende_characters,
"wmt_ende_bpe32k": wmt_ende_bpe32k,
- "wmt_zhen_tokens_32k": lambda p: wmt_zhen_tokens(p, 2**15),
"image_cifar10_tune": image_cifar10,
"image_cifar10_test": image_cifar10,
"image_mnist_tune": image_mnist,
diff --git a/tensor2tensor/data_generators/text_encoder.py b/tensor2tensor/data_generators/text_encoder.py
index c812ced4f..e0ac1901e 100644
--- a/tensor2tensor/data_generators/text_encoder.py
+++ b/tensor2tensor/data_generators/text_encoder.py
@@ -42,8 +42,8 @@
EOS = ""
RESERVED_TOKENS = [PAD, EOS]
NUM_RESERVED_TOKENS = len(RESERVED_TOKENS)
-PAD_TOKEN = RESERVED_TOKENS.index(PAD) # Normally 0
-EOS_TOKEN = RESERVED_TOKENS.index(EOS) # Normally 1
+PAD_ID = RESERVED_TOKENS.index(PAD) # Normally 0
+EOS_ID = RESERVED_TOKENS.index(EOS) # Normally 1
if PY2:
RESERVED_TOKENS_BYTES = RESERVED_TOKENS
@@ -51,6 +51,13 @@
RESERVED_TOKENS_BYTES = [bytes(PAD, "ascii"), bytes(EOS, "ascii")]
+# Regular expression for unescaping token strings.
+# '\u' is converted to '_'
+# '\\' is converted to '\'
+# '\213;' is converted to unichr(213)
+_UNESCAPE_REGEX = re.compile(u"|".join([r"\\u", r"\\\\", r"\\([0-9]+);"]))
+
+
def native_to_unicode_py2(s):
"""Python 2: transform native string to Unicode."""
if isinstance(s, unicode):
@@ -505,12 +512,6 @@ def _escape_token(self, token):
ret += u"\\%d;" % ord(c)
return ret
- # Regular expression for unescaping token strings
- # '\u' is converted to '_'
- # '\\' is converted to '\'
- # '\213;' is converted to unichr(213)
- _UNESCAPE_REGEX = re.compile(u'|'.join([r"\\u", r"\\\\", r"\\([0-9]+);"]))
-
def _unescape_token(self, escaped_token):
"""Inverse of _escape_token().
diff --git a/tensor2tensor/data_generators/tokenizer_test.py b/tensor2tensor/data_generators/tokenizer_test.py
old mode 100755
new mode 100644
index 45a1f7e41..c279290ed
--- a/tensor2tensor/data_generators/tokenizer_test.py
+++ b/tensor2tensor/data_generators/tokenizer_test.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
# Copyright 2017 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+# coding=utf-8
"""Tests for tensor2tensor.data_generators.tokenizer."""
from __future__ import absolute_import
diff --git a/tensor2tensor/data_generators/wmt.py b/tensor2tensor/data_generators/wmt.py
index de5a25e13..2e1f1e8af 100644
--- a/tensor2tensor/data_generators/wmt.py
+++ b/tensor2tensor/data_generators/wmt.py
@@ -38,77 +38,97 @@
FLAGS = tf.flags.FLAGS
-@registry.register_problem("wmt_ende_tokens_8k")
-class WMTEnDeTokens8k(problem.Problem):
- """Problem spec for WMT En-De translation."""
+# End-of-sentence marker.
+EOS = text_encoder.EOS_ID
- @property
- def target_vocab_size(self):
- return 2**13 # 8192
- def feature_encoders(self, data_dir):
- return _default_wmt_feature_encoders(data_dir, self.target_vocab_size)
+def _default_token_feature_encoders(data_dir, target_vocab_size):
+ vocab_filename = os.path.join(data_dir, "tokens.vocab.%d" % target_vocab_size)
+ subtokenizer = text_encoder.SubwordTextEncoder(vocab_filename)
+ return {
+ "inputs": subtokenizer,
+ "targets": subtokenizer,
+ }
- def generate_data(self, data_dir, tmp_dir, num_shards=100):
- generator_utils.generate_dataset_and_shuffle(
- ende_wordpiece_token_generator(tmp_dir, True, self.target_vocab_size),
- self.training_filepaths(data_dir, num_shards, shuffled=False),
- ende_wordpiece_token_generator(tmp_dir, False, self.target_vocab_size),
- self.dev_filepaths(data_dir, 1, shuffled=False))
- def hparams(self, defaults, unused_model_hparams):
- p = defaults
- vocab_size = self._encoders["inputs"].vocab_size
- p.input_modality = {"inputs": (registry.Modalities.SYMBOL, vocab_size)}
- p.target_modality = (registry.Modalities.SYMBOL, vocab_size)
- p.input_space_id = problem.SpaceID.EN_TOK
- p.target_space_id = problem.SpaceID.DE_TOK
+def _default_character_feature_encoders():
+ return {
+ "inputs": text_encoder.ByteTextEncoder(),
+ "targets": text_encoder.ByteTextEncoder(),
+ }
-@registry.register_problem("wmt_ende_tokens_32k")
-class WMTEnDeTokens32k(WMTEnDeTokens8k):
+class WMTProblem(problem.Problem):
+ """Base class for WMT problems."""
@property
- def target_vocab_size(self):
- return 2**15 # 32768
+ def is_character_level(self):
+ return False
+ @property
+ def targeted_vocab_size(self):
+ raise NotImplementedError() # Not needed if self.is_character_level.
-def _default_wmt_feature_encoders(data_dir, target_vocab_size):
- vocab_filename = os.path.join(data_dir, "tokens.vocab.%d" % target_vocab_size)
- subtokenizer = text_encoder.SubwordTextEncoder(vocab_filename)
- return {
- "inputs": subtokenizer,
- "targets": subtokenizer,
- }
+ @property
+ def train_generator(self):
+ """Generator; takes tmp_dir, is_training, possibly targeted_vocab_size."""
+ raise NotImplementedError()
-@registry.register_problem("setimes_mken_tokens_32k")
-class SETimesMkEnTokens32k(problem.Problem):
- """Problem spec for SETimes Mk-En translation."""
+ @property
+ def dev_generator(self):
+ return self.train_generator
@property
- def target_vocab_size(self):
- return 2**15 # 32768
+ def input_space_id(self):
+ raise NotImplementedError()
- def feature_encoders(self, data_dir):
- return _default_wmt_feature_encoders(data_dir, self.target_vocab_size)
+ @property
+ def target_space_id(self):
+ raise NotImplementedError()
+
+ @property
+ def num_shards(self):
+ return 100
+
+ def generate_data(self, data_dir, tmp_dir, num_shards=None):
+ if num_shards is None:
+ num_shards = self.num_shards
+ if self.is_character_level:
+ generator_utils.generate_dataset_and_shuffle(
+ self.train_generator(tmp_dir, True),
+ self.training_filepaths(data_dir, num_shards, shuffled=False),
+ self.dev_generator(tmp_dir, False),
+ self.dev_filepaths(data_dir, 1, shuffled=False))
+ else:
+ generator_utils.generate_dataset_and_shuffle(
+ self.train_generator(tmp_dir, True, self.targeted_vocab_size),
+ self.training_filepaths(data_dir, num_shards, shuffled=False),
+ self.dev_generator(tmp_dir, False, self.targeted_vocab_size),
+ self.dev_filepaths(data_dir, 1, shuffled=False))
- def generate_data(self, data_dir, tmp_dir, num_shards=100):
- generator_utils.generate_dataset_and_shuffle(
- mken_wordpiece_token_generator(tmp_dir, True, self.target_vocab_size),
- self.training_filepaths(data_dir, num_shards, shuffled=False),
- mken_wordpiece_token_generator(tmp_dir, False, self.target_vocab_size),
- self.dev_filepaths(data_dir, 1, shuffled=False))
+ def feature_encoders(self, data_dir):
+ if self.is_character_level:
+ return _default_character_feature_encoders()
+ return _default_token_feature_encoders(data_dir, self.targeted_vocab_size)
def hparams(self, defaults, unused_model_hparams):
p = defaults
- vocab_size = self._encoders["inputs"].vocab_size
- p.input_modality = {"inputs": (registry.Modalities.SYMBOL, vocab_size)}
- p.target_modality = (registry.Modalities.SYMBOL, vocab_size)
- p.input_space_id = problem.SpaceID.MK_TOK
- p.target_space_id = problem.SpaceID.EN_TOK
+ if self.is_character_level:
+ source_vocab_size = 256
+ target_vocab_size = 256
+ else:
+ source_vocab_size = self._encoders["inputs"].vocab_size
+ target_vocab_size = self._encoders["targets"].vocab_size
+ p.input_modality = {"inputs": (registry.Modalities.SYMBOL,
+ source_vocab_size)}
+ p.target_modality = (registry.Modalities.SYMBOL, target_vocab_size)
+ p.input_space_id = self.input_space_id
+ p.target_space_id = self.target_space_id
+ if self.is_character_level:
+ p.loss_multiplier = 2.0
-# End-of-sentence marker.
-EOS = text_encoder.EOS_TOKEN
+
+# Generic generators used later for multiple problems.
def character_generator(source_path, target_path, character_vocab, eos=None):
@@ -233,29 +253,7 @@ def bi_vocabs_token_generator(source_path,
source, target = source_file.readline(), target_file.readline()
-def _get_wmt_ende_dataset(directory, filename):
- """Extract the WMT en-de corpus `filename` to directory unless it's there."""
- train_path = os.path.join(directory, filename)
- if not (tf.gfile.Exists(train_path + ".de") and
- tf.gfile.Exists(train_path + ".en")):
- # We expect that this file has been downloaded from:
- # https://drive.google.com/open?id=0B_bZck-ksdkpM25jRUN2X2UxMm8 and placed
- # in `directory`.
- corpus_file = os.path.join(directory, FLAGS.ende_bpe_path)
- with tarfile.open(corpus_file, "r:gz") as corpus_tar:
- corpus_tar.extractall(directory)
- return train_path
-
-
-def ende_bpe_token_generator(tmp_dir, train):
- """Instance of token generator for the WMT en->de task, training set."""
- dataset_path = ("train.tok.clean.bpe.32000"
- if train else "newstest2013.tok.bpe.32000")
- train_path = _get_wmt_ende_dataset(tmp_dir, dataset_path)
- token_path = os.path.join(tmp_dir, "vocab.bpe.32000")
- token_vocab = text_encoder.TokenTextEncoder(vocab_filename=token_path)
- return token_generator(train_path + ".en", train_path + ".de", token_vocab,
- EOS)
+# Dataset URLs.
_ENDE_TRAIN_DATASETS = [
@@ -336,6 +334,34 @@ def ende_bpe_token_generator(tmp_dir, train):
]]
+# Generators.
+
+
+def _get_wmt_ende_dataset(directory, filename):
+ """Extract the WMT en-de corpus `filename` to directory unless it's there."""
+ train_path = os.path.join(directory, filename)
+ if not (tf.gfile.Exists(train_path + ".de") and
+ tf.gfile.Exists(train_path + ".en")):
+ # We expect that this file has been downloaded from:
+ # https://drive.google.com/open?id=0B_bZck-ksdkpM25jRUN2X2UxMm8 and placed
+ # in `directory`.
+ corpus_file = os.path.join(directory, FLAGS.ende_bpe_path)
+ with tarfile.open(corpus_file, "r:gz") as corpus_tar:
+ corpus_tar.extractall(directory)
+ return train_path
+
+
+def ende_bpe_token_generator(tmp_dir, train):
+ """Instance of token generator for the WMT en->de task, training set."""
+ dataset_path = ("train.tok.clean.bpe.32000"
+ if train else "newstest2013.tok.bpe.32000")
+ train_path = _get_wmt_ende_dataset(tmp_dir, dataset_path)
+ token_path = os.path.join(tmp_dir, "vocab.bpe.32000")
+ token_vocab = text_encoder.TokenTextEncoder(vocab_filename=token_path)
+ return token_generator(train_path + ".en", train_path + ".de", token_vocab,
+ EOS)
+
+
def _compile_data(tmp_dir, datasets, filename):
"""Concatenate all `datasets` and save to `filename`."""
filename = os.path.join(tmp_dir, filename)
@@ -386,6 +412,35 @@ def ende_wordpiece_token_generator(tmp_dir, train, vocab_size):
symbolizer_vocab, EOS)
+@registry.register_problem("wmt_ende_tokens_8k")
+class WMTEnDeTokens8k(WMTProblem):
+ """Problem spec for WMT En-De translation."""
+
+ @property
+ def targeted_vocab_size(self):
+ return 2**13 # 8192
+
+ @property
+ def train_generator(self):
+ return ende_wordpiece_token_generator
+
+ @property
+ def input_space_id(self):
+ return problem.SpaceID.EN_TOK
+
+ @property
+ def target_space_id(self):
+ return problem.SpaceID.DE_TOK
+
+
+@registry.register_problem("wmt_ende_tokens_32k")
+class WMTEnDeTokens32k(WMTEnDeTokens8k):
+
+ @property
+ def targeted_vocab_size(self):
+ return 2**15 # 32768
+
+
def ende_character_generator(tmp_dir, train):
character_vocab = text_encoder.ByteTextEncoder()
datasets = _ENDE_TRAIN_DATASETS if train else _ENDE_TEST_DATASETS
@@ -395,8 +450,29 @@ def ende_character_generator(tmp_dir, train):
character_vocab, EOS)
-def zhen_wordpiece_token_generator(tmp_dir, train, source_vocab_size,
- target_vocab_size):
+@registry.register_problem("wmt_ende_characters")
+class WMTEnDeCharacters(WMTProblem):
+ """Problem spec for WMT En-De translation."""
+
+ @property
+ def is_character_level(self):
+ return True
+
+ @property
+ def train_generator(self):
+ return ende_character_generator
+
+ @property
+ def input_space_id(self):
+ return problem.SpaceID.EN_CHR
+
+ @property
+ def target_space_id(self):
+ return problem.SpaceID.DE_CHR
+
+
+def zhen_wordpiece_token_bigenerator(tmp_dir, train, source_vocab_size,
+ target_vocab_size):
"""Wordpiece generator for the WMT'17 zh-en dataset."""
datasets = _ZHEN_TRAIN_DATASETS if train else _ZHEN_TEST_DATASETS
source_datasets = [[item[0], [item[1][0]]] for item in datasets]
@@ -413,6 +489,53 @@ def zhen_wordpiece_token_generator(tmp_dir, train, source_vocab_size,
source_vocab, target_vocab, EOS)
+def zhen_wordpiece_token_generator(tmp_dir, train, vocab_size):
+ return zhen_wordpiece_token_bigenerator(tmp_dir, train,
+ vocab_size, vocab_size)
+
+
+@registry.register_problem("wmt_zhen_tokens_8k")
+class WMTZhEnTokens8k(WMTProblem):
+ """Problem spec for WMT Zh-En translation."""
+
+ @property
+ def targeted_vocab_size(self):
+ return 2**13 # 8192
+
+ @property
+ def train_generator(self):
+ return zhen_wordpiece_token_generator
+
+ @property
+ def input_space_id(self):
+ return problem.SpaceID.ZH_TOK
+
+ @property
+ def target_space_id(self):
+ return problem.SpaceID.EN_TOK
+
+ def feature_encoders(self, data_dir):
+ vocab_size = self.targeted_vocab_size
+ source_vocab_filename = os.path.join(data_dir,
+ "tokens.vocab.zh.%d" % vocab_size)
+ target_vocab_filename = os.path.join(data_dir,
+ "tokens.vocab.en.%d" % vocab_size)
+ source_token = text_encoder.SubwordTextEncoder(source_vocab_filename)
+ target_token = text_encoder.SubwordTextEncoder(target_vocab_filename)
+ return {
+ "inputs": source_token,
+ "targets": target_token,
+ }
+
+
+@registry.register_problem("wmt_zhen_tokens_32k")
+class WMTZhEnTokens32k(WMTZhEnTokens8k):
+
+ @property
+ def targeted_vocab_size(self):
+ return 2**15 # 32768
+
+
def enfr_wordpiece_token_generator(tmp_dir, train, vocab_size):
"""Instance of token generator for the WMT en->fr task."""
symbolizer_vocab = generator_utils.get_or_generate_vocab(
@@ -424,6 +547,35 @@ def enfr_wordpiece_token_generator(tmp_dir, train, vocab_size):
symbolizer_vocab, EOS)
+@registry.register_problem("wmt_enfr_tokens_8k")
+class WMTEnFrTokens8k(WMTProblem):
+ """Problem spec for WMT En-Fr translation."""
+
+ @property
+ def targeted_vocab_size(self):
+ return 2**13 # 8192
+
+ @property
+ def train_generator(self):
+ return enfr_wordpiece_token_generator
+
+ @property
+ def input_space_id(self):
+ return problem.SpaceID.EN_TOK
+
+ @property
+ def target_space_id(self):
+ return problem.SpaceID.FR_TOK
+
+
+@registry.register_problem("wmt_enfr_tokens_32k")
+class WMTEnFrTokens32k(WMTEnFrTokens8k):
+
+ @property
+ def targeted_vocab_size(self):
+ return 2**15 # 32768
+
+
def enfr_character_generator(tmp_dir, train):
"""Instance of character generator for the WMT en->fr task."""
character_vocab = text_encoder.ByteTextEncoder()
@@ -433,6 +585,28 @@ def enfr_character_generator(tmp_dir, train):
return character_generator(data_path + ".lang1", data_path + ".lang2",
character_vocab, EOS)
+
+@registry.register_problem("wmt_enfr_characters")
+class WMTEnFrCharacters(WMTProblem):
+ """Problem spec for WMT En-Fr translation."""
+
+ @property
+ def is_character_level(self):
+ return True
+
+ @property
+ def train_generator(self):
+ return enfr_character_generator
+
+ @property
+ def input_space_id(self):
+ return problem.SpaceID.EN_CHR
+
+ @property
+ def target_space_id(self):
+ return problem.SpaceID.FR_CHR
+
+
def mken_wordpiece_token_generator(tmp_dir, train, vocab_size):
"""Wordpiece generator for the SETimes Mk-En dataset."""
datasets = _MKEN_TRAIN_DATASETS if train else _MKEN_TEST_DATASETS
@@ -447,6 +621,27 @@ def mken_wordpiece_token_generator(tmp_dir, train, vocab_size):
symbolizer_vocab, EOS)
+@registry.register_problem("setimes_mken_tokens_32k")
+class SETimesMkEnTokens32k(WMTProblem):
+ """Problem spec for SETimes Mk-En translation."""
+
+ @property
+ def targeted_vocab_size(self):
+ return 2**15 # 32768
+
+ @property
+ def train_generator(self):
+ return mken_wordpiece_token_generator
+
+ @property
+ def input_space_id(self):
+ return problem.SpaceID.MK_TOK
+
+ @property
+ def target_space_id(self):
+ return problem.SpaceID.EN_TOK
+
+
def parsing_character_generator(tmp_dir, train):
character_vocab = text_encoder.ByteTextEncoder()
filename = "parsing_%s" % ("train" if train else "dev")
diff --git a/tensor2tensor/models/attention_lm.py b/tensor2tensor/models/attention_lm.py
index 947dc9306..752de038e 100644
--- a/tensor2tensor/models/attention_lm.py
+++ b/tensor2tensor/models/attention_lm.py
@@ -101,8 +101,6 @@ def attention_lm_decoder(decoder_input,
y: a Tensors
"""
x = decoder_input
- # Summaries don't work in multi-problem setting yet.
- summaries = "problems" not in hparams.values() or len(hparams.problems) == 1
with tf.variable_scope(name):
for layer in xrange(hparams.num_hidden_layers):
with tf.variable_scope("layer_%d" % layer):
@@ -117,7 +115,6 @@ def attention_lm_decoder(decoder_input,
hparams.hidden_size,
hparams.num_heads,
hparams.attention_dropout,
- summaries=summaries,
name="decoder_self_attention"))
x = residual_fn(x,
common_layers.conv_hidden_relu(
diff --git a/tensor2tensor/models/attention_lm_moe.py b/tensor2tensor/models/attention_lm_moe.py
index 952ff1a71..2754e8366 100644
--- a/tensor2tensor/models/attention_lm_moe.py
+++ b/tensor2tensor/models/attention_lm_moe.py
@@ -69,7 +69,6 @@ def residual_fn(x, y):
hparams.hidden_size,
hparams.num_heads,
hparams.attention_dropout,
- summaries=True,
name="decoder_self_attention")
x = dp(residual_fn, x, y)
with tf.variable_scope("ffn"):
diff --git a/tensor2tensor/models/common_attention.py b/tensor2tensor/models/common_attention.py
index 49cd40285..6aa8a2a07 100644
--- a/tensor2tensor/models/common_attention.py
+++ b/tensor2tensor/models/common_attention.py
@@ -280,13 +280,13 @@ def attention_image_summary(attn, image_shapes=None):
(query_rows, query_cols, query_channels,
memory_rows, memory_cols, memory_channels).
"""
- num_heads = attn.get_shape().as_list()[1]
+ num_heads = tf.shape(attn)[1]
# [batch, query_length, memory_length, num_heads]
image = tf.transpose(attn, [0, 2, 3, 1])
image = tf.pow(image, 0.2) # for high-dynamic-range
# Each head will correspond to one of RGB.
# pad the heads to be a multiple of 3
- image = tf.pad(image, [[0, 0], [0, 0], [0, 0], [0, -num_heads % 3]])
+ image = tf.pad(image, [[0, 0], [0, 0], [0, 0], [0, tf.mod(-num_heads, 3)]])
image = split_last_dimension(image, 3)
image = tf.reduce_max(image, 4)
if image_shapes is not None:
@@ -312,7 +312,6 @@ def dot_product_attention(q,
v,
bias,
dropout_rate=0.0,
- summaries=False,
image_shapes=None,
name=None):
"""dot-product attention.
@@ -323,7 +322,6 @@ def dot_product_attention(q,
v: a Tensor with shape [batch, heads, length_kv, depth_v]
bias: bias Tensor (see attention_bias())
dropout_rate: a floating point number
- summaries: a boolean
image_shapes: optional tuple of integer scalars.
see comments for attention_image_summary()
name: an optional string
@@ -340,11 +338,99 @@ def dot_product_attention(q,
weights = tf.nn.softmax(logits, name="attention_weights")
# dropping out the attention links for each of the heads
weights = tf.nn.dropout(weights, 1.0 - dropout_rate)
- if summaries and not tf.get_variable_scope().reuse:
+ if not tf.get_variable_scope().reuse:
attention_image_summary(weights, image_shapes)
return tf.matmul(weights, v)
+def masked_local_attention_1d(
+ q, k, v, block_length=128, name=None):
+ """Attention to the source position and a neigborhood to the left of it.
+
+  The sequence is divided into blocks of length block_length.
+ Attention for a given query position can only see memory positions
+ less than or equal to the query position, in the corresponding block
+ and the previous block.
+
+  A target position can never attend to source positions greater than itself,
+  i.e. masking to the right is always applied.
+
+ Args:
+ q: a Tensor with shape [batch, heads, length, depth_k]
+ k: a Tensor with shape [batch, heads, length, depth_k]
+ v: a Tensor with shape [batch, heads, length, depth_v]
+ block_length: an integer
+ name: an optional string
+
+ Returns:
+ a Tensor of shape [batch, heads, length, depth_v]
+ """
+ with tf.variable_scope(name, default_name="local_attention_1d",
+ values=[q, k, v]):
+ v_shape = v.get_shape()
+ batch = tf.shape(q)[0]
+ heads = tf.shape(q)[1]
+ length = tf.shape(q)[2]
+ # If (length < 2 * block_length), then we use only one block.
+ block_length = tf.where(tf.less(length, block_length * 2),
+ length, block_length)
+ depth_k = tf.shape(q)[3]
+ depth_v = tf.shape(v)[3]
+ original_length = length
+ padding_size = tf.mod(-length, block_length)
+ length += padding_size
+ padding = [[0, 0], [0, 0], [0, padding_size], [0, 0]]
+ q = tf.pad(q, padding)
+ k = tf.pad(k, padding)
+ v = tf.pad(v, padding)
+ num_blocks = tf.div(length, block_length)
+
+ # compute attention for the first query block.
+ first_q = tf.slice(q, [0, 0, 0, 0], [-1, -1, block_length, -1])
+ first_k = tf.slice(k, [0, 0, 0, 0], [-1, -1, block_length, -1])
+ first_v = tf.slice(v, [0, 0, 0, 0], [-1, -1, block_length, -1])
+ first_output = dot_product_attention(
+ first_q, first_k, first_v, attention_bias_lower_triangle(block_length),
+ name="fist_block")
+
+ # compute attention for all subsequent query blocks.
+ q = tf.reshape(q, [batch, heads, num_blocks, block_length, depth_k])
+ k = tf.reshape(k, [batch, heads, num_blocks, block_length, depth_k])
+ v = tf.reshape(v, [batch, heads, num_blocks, block_length, depth_v])
+
+ def local(x):
+ """Create a local version of the keys or values."""
+ prev_block = tf.slice(
+ x, [0, 0, 0, 0, 0], [-1, -1, num_blocks - 1, -1, -1])
+ cur_block = tf.slice(
+ x, [0, 0, 1, 0, 0], [-1, -1, -1, -1, -1])
+ return tf.concat([prev_block, cur_block], 3)
+ local_k = local(k)
+ local_v = local(v)
+ tail_q = tf.slice(q, [0, 0, 1, 0, 0], [-1, -1, -1, -1, -1])
+
+ local_length = tf.shape(local_k)[3]
+
+ # [batch, heads, num_blocks - 1, block_length, local_length]
+ attention = tf.matmul(tail_q, local_k, transpose_b=True)
+
+ # make sure source_pos <= target_pos
+ good_part = tf.matrix_band_part(
+ tf.ones([block_length, local_length]), -1, tf.to_int64(block_length))
+ mask = (1.0 - good_part) * -1e9
+ attention += tf.reshape(mask, [1, 1, 1, block_length, local_length])
+ attention = tf.nn.softmax(attention)
+ # TODO(noam): figure out how to show a summary for the remaining blocks.
+ # The naive way currently causes errors due to empty tensors.
+ # output: [batch, heads, num_blocks-1, block_length, depth_v]
+ output = tf.matmul(attention, local_v)
+ output = tf.reshape(output, [batch, heads, -1, depth_v])
+ output = tf.concat([first_output, output], axis=2)
+ output = tf.slice(output, [0, 0, 0, 0], [-1, -1, original_length, -1])
+ output.set_shape(v_shape)
+ return output
+
+
def multihead_attention(query_antecedent,
memory_antecedent,
bias,
@@ -353,8 +439,9 @@ def multihead_attention(query_antecedent,
output_depth,
num_heads,
dropout_rate,
- summaries=False,
image_shapes=None,
+ attention_type="dot_product",
+ block_length=128,
name=None):
"""Multihead scaled-dot-product attention with input/output transformations.
@@ -367,9 +454,10 @@ def multihead_attention(query_antecedent,
output_depth: an integer
num_heads: an integer dividing total_key_depth and total_value_depth
dropout_rate: a floating point number
- summaries: a boolean
image_shapes: optional tuple of integer scalars.
see comments for attention_image_summary()
+ attention_type: a string, either "dot_product" or "local_mask_right"
+    block_length: an integer - relevant only for "local_mask_right"
name: an optional string
Returns:
@@ -414,8 +502,12 @@ def multihead_attention(query_antecedent,
v = split_heads(v, num_heads)
key_depth_per_head = total_key_depth // num_heads
q *= key_depth_per_head**-0.5
- x = dot_product_attention(
- q, k, v, bias, dropout_rate, summaries, image_shapes)
+ if attention_type == "dot_product":
+ x = dot_product_attention(
+ q, k, v, bias, dropout_rate, image_shapes)
+ else:
+ assert attention_type == "local_mask_right"
+ x = masked_local_attention_1d(q, k, v, block_length=block_length)
x = combine_heads(x)
x = common_layers.conv1d(x, output_depth, 1, name="output_transform")
return x
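With the `summaries` argument removed and the `attention_type` switch added, callers pick the block-local masked attention entirely through `multihead_attention`. A minimal sketch of a self-attention call on the new code path (TF 1.x graph mode; the shapes and hyperparameter values below are illustrative, not taken from this diff):

```python
# A minimal sketch, assuming an input of shape [batch, length, hidden_size].
import tensorflow as tf
from tensor2tensor.models import common_attention

x = tf.zeros([2, 4096, 512])  # [batch, length, hidden_size]

# Block-local masked self-attention: each position attends only to positions
# at or before itself within its own block and the previous block, so the
# attention logits stay O(length * block_length) instead of O(length**2).
y = common_attention.multihead_attention(
    x,     # query_antecedent
    None,  # memory_antecedent=None -> self-attention
    None,  # bias is unused on the "local_mask_right" path
    total_key_depth=512,
    total_value_depth=512,
    output_depth=512,
    num_heads=8,
    dropout_rate=0.0,
    attention_type="local_mask_right",
    block_length=256,
    name="local_self_attention")
```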
diff --git a/tensor2tensor/models/common_hparams.py b/tensor2tensor/models/common_hparams.py
index f067b724e..ff856968b 100644
--- a/tensor2tensor/models/common_hparams.py
+++ b/tensor2tensor/models/common_hparams.py
@@ -72,6 +72,9 @@ def basic_params1():
# setting the max length in a minibatch. 0 means default behavior,
# max_length = hparams.batch_size * length_multiplier
max_length=0,
+ # If set to True, drop sequences longer than max_length during eval.
+ # This affects the validity of the evaluation metrics.
+ eval_drop_long_sequences=int(False),
# in SymbolModality, share the output embeddings and the softmax
# variables.
# You can also share the input embeddings with the output embeddings
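The new flag follows this file's convention of storing boolean hyperparameters as ints, so it can be flipped in code or through an hparams override string; the trainer change further down consumes it when building the eval input pipeline. A small sketch, with illustrative values:

```python
# A minimal sketch, assuming the basic_params1() defaults above; the values
# here are illustrative. Booleans are stored as ints, matching the other
# flags in this file.
from tensor2tensor.models import common_hparams

hparams = common_hparams.basic_params1()
hparams.max_length = 256
hparams.eval_drop_long_sequences = int(True)

if bool(hparams.eval_drop_long_sequences):
  print("eval will drop examples longer than", hparams.max_length, "tokens")
```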
diff --git a/tensor2tensor/models/common_layers.py b/tensor2tensor/models/common_layers.py
index 1e7050570..638535aa2 100644
--- a/tensor2tensor/models/common_layers.py
+++ b/tensor2tensor/models/common_layers.py
@@ -777,7 +777,7 @@ def moe_layer(data_parallelism,
xs_2d = dp(tf.reshape, xs, [[-1, model_hidden_size]] * dp.n)
# Call the MoE
moe_out_2d, importance, load, _, _ = moe.Eval(
- dp.devices, xs_2d, train, identifiers=None, summaries=True)
+ dp.devices, xs_2d, train, identifiers=None)
# Reshape the output to the original shape.
moe_out = dp(tf.reshape, moe_out_2d, dp(tf.shape, xs))
# These losses encourage equal load on the different experts.
@@ -785,7 +785,7 @@ def moe_layer(data_parallelism,
return moe_out, loss
-def simple_attention(target, source, bias=None, summaries=True):
+def simple_attention(target, source, bias=None):
"""A simple attention function.
Args:
@@ -795,7 +795,6 @@ def simple_attention(target, source, bias=None, summaries=True):
`[batch, source_timesteps_1, source_timesteps_2, depth]`
bias: an optional `Tensor` with shape `[batch, timesteps, 1, 1]` used
to mask the attention to not attend to padding of input.
- summaries: Boolean, whether to output summaries.
Returns:
a `Tensor` with same shape as `target`
@@ -814,7 +813,7 @@ def simple_attention(target, source, bias=None, summaries=True):
if bias is not None:
attention += tf.expand_dims(tf.squeeze(bias, axis=[2, 3]), axis=1)
attention = tf.nn.softmax(attention)
- if summaries and not tf.get_variable_scope().reuse:
+ if not tf.get_variable_scope().reuse:
tf.summary.image("attention", tf.expand_dims(attention, 3), max_outputs=5)
attended = tf.matmul(attention, source)
return tf.reshape(attended, target_shape)
@@ -861,8 +860,7 @@ def multiscale_conv_sum(inputs, output_size, dilation_rates_and_kernel_sizes,
def multiscale_conv_and_attention(x,
padding,
hparams,
- source=None,
- summaries=True):
+ source=None):
"""A common part of t2t layers.
First, do a linear multiscale convolution
@@ -875,7 +873,6 @@ def multiscale_conv_and_attention(x,
padding: a padding type
hparams: hyperparameters for model
source: optional source tensor for attention. (encoder output)
- summaries: Boolean, whether to output summaries.
Returns:
a Tensor.
@@ -893,7 +890,7 @@ def multiscale_conv_and_attention(x,
x = conv(x, hparams.hidden_size, (1, 1))
x = noam_norm(x + conv_sum)
if source is not None:
- x = noam_norm(x + simple_attention(x, source, summaries=summaries))
+ x = noam_norm(x + simple_attention(x, source))
return x
@@ -930,8 +927,7 @@ def conv_with_pools(inputs, output_size, kernel_size, pool_sizes, pooling_type,
def conv_with_pools_and_attention(x,
padding,
hparams,
- source=None,
- summaries=True):
+ source=None):
"""A common part of t2t layers.
First, do conv_with_pools
@@ -944,7 +940,6 @@ def conv_with_pools_and_attention(x,
padding: a padding type
hparams: hyperparameters for model
source: optional source tensor for attention. (encoder output)
- summaries: Boolean, whether to output summaries.
Returns:
a Tensor.
@@ -959,7 +954,7 @@ def conv_with_pools_and_attention(x,
conv_sum += x
x = noam_norm(conv_sum)
if source is not None:
- x = noam_norm(x + simple_attention(x, source, summaries=summaries))
+ x = noam_norm(x + simple_attention(x, source))
return x
@@ -1057,7 +1052,6 @@ def attention_1d_v0(source,
transform_source=True,
transform_target=True,
transform_output=True,
- summaries=True,
name=None):
"""multi-headed attention.
@@ -1075,7 +1069,6 @@ def attention_1d_v0(source,
transform_source: a boolean
transform_target: a boolean
transform_output: a boolean
- summaries: a boolean
name: an optional string
Returns:
@@ -1116,7 +1109,7 @@ def _maybe_transform(t, size, should_transform, name):
mask = (1.0 - mask) * -1e9
attention += mask
attention = tf.nn.softmax(attention)
- if summaries and not tf.get_variable_scope().reuse:
+ if not tf.get_variable_scope().reuse:
# Compute a color image summary.
image = tf.reshape(attention,
[batch, num_heads, target_length, source_length])
@@ -1162,7 +1155,6 @@ def conv_hidden_relu(inputs,
output_size,
kernel_size=(1, 1),
second_kernel_size=(1, 1),
- summaries=True,
dropout=0.0,
**kwargs):
"""Hidden layer with RELU activation followed by linear projection."""
@@ -1183,7 +1175,7 @@ def conv_hidden_relu(inputs,
**kwargs)
if dropout != 0.0:
h = tf.nn.dropout(h, 1.0 - dropout)
- if summaries and not tf.get_variable_scope().reuse:
+ if not tf.get_variable_scope().reuse:
tf.summary.histogram("hidden_density_logit",
relu_density_logit(
h, list(range(inputs.shape.ndims - 1))))
diff --git a/tensor2tensor/models/long_answer.py b/tensor2tensor/models/long_answer.py
new file mode 100644
index 000000000..7bb6a4a55
--- /dev/null
+++ b/tensor2tensor/models/long_answer.py
@@ -0,0 +1,274 @@
+# Copyright 2017 The Tensor2Tensor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Model to generate long answers to short questions.
+
+E.g. wiki_32k title->article dataset.
+
+Variant on attention_lm_moe.py
+ - prepend the inputs to the targets.
+ - use masked local attention to avoid quadratic space and time blowup for
+ long sequences.
+
+This model is still highly experimental and under rapid iteration.
+
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Dependency imports
+
+from six.moves import xrange # pylint: disable=redefined-builtin
+
+from tensor2tensor.models import common_attention
+from tensor2tensor.models import common_hparams
+from tensor2tensor.models import common_layers
+from tensor2tensor.utils import registry
+from tensor2tensor.utils import t2t_model
+
+import tensorflow as tf
+
+
+@registry.register_model
+class LongAnswer(t2t_model.T2TModel):
+ """Attention net. See file docstring."""
+
+ def model_fn_body_sharded(self, sharded_features):
+ # Remove dropout if not training
+ hparams = self._hparams
+ dp = self._data_parallelism
+ targets = sharded_features["targets"]
+ targets = dp(tf.squeeze, targets, 2)
+ inputs = sharded_features["inputs"]
+ inputs = dp(tf.squeeze, inputs, 2)
+
+ decoder_input = dp(long_answer_prepare_decoder, inputs, targets, hparams)
+
+ def residual_fn(x, y):
+ return common_layers.layer_norm(x + tf.nn.dropout(
+ y, 1.0 - hparams.residual_dropout))
+
+ x = dp(tf.nn.dropout, decoder_input, 1.0 - hparams.residual_dropout)
+ extra_loss = 0.0
+ for layer in xrange(hparams.num_hidden_layers):
+ with tf.variable_scope("layer_%d" % layer):
+ with tf.variable_scope("attention"):
+ y = dp(common_attention.multihead_attention,
+ x,
+ None,
+ None,
+ hparams.attention_key_channels or hparams.hidden_size,
+ hparams.attention_value_channels or hparams.hidden_size,
+ hparams.hidden_size,
+ hparams.num_heads,
+ hparams.attention_dropout,
+ attention_type="local_mask_right",
+ block_length=hparams.block_length,
+ name="decoder_self_attention")
+ x = dp(residual_fn, x, y)
+ with tf.variable_scope("ffn"):
+ if str(layer) in hparams.moe_layers.split(","):
+ y, loss = common_layers.moe_layer(
+ dp, self._ps_devices, x,
+ hparams.mode == tf.contrib.learn.ModeKeys.TRAIN,
+ hparams.hidden_size,
+ hparams.moe_hidden_size, hparams.moe_n1, hparams.moe_n2,
+ hparams.moe_loss_coef)
+ extra_loss += loss
+ else:
+ y = dp(common_layers.conv_hidden_relu,
+ x,
+ hparams.filter_size,
+ hparams.hidden_size,
+ dropout=hparams.relu_dropout)
+ x = dp(residual_fn, x, y)
+ x = dp(long_answer_output, x, inputs)
+ return x, extra_loss
+
+
+def long_answer_prepare_decoder(inputs, targets, hparams):
+ """Prepare one shard of the model for the decoder.
+
+ Args:
+ inputs: a Tensor.
+ targets: a Tensor.
+ hparams: run hyperparameters
+
+ Returns:
+ decoder_input: a Tensor, bottom of decoder stack
+ """
+ decoder_input = tf.concat([
+ length_embedding(targets, hparams), inputs,
+ common_layers.shift_left_3d(targets)], 1)
+ if hparams.pos == "timing":
+ decoder_input = common_attention.add_timing_signal_1d(decoder_input)
+ return decoder_input
+
+
+def length_embedding(targets, hparams):
+ """An embedding indicating approximate target length.
+
+ This is a bit of a hack, where we want to be able to request a particular
+ target length during inference.
+ During training, we sometimes provide a target length.
+ During eval, we never provide a target length.
+
+ Args:
+ targets: a Tensor.
+ hparams: run hyperparameters
+
+ Returns:
+ a Tensor with shape [batch, 1, hparams.hidden_size]
+ """
+ # encode the approx target length in case we want to specify it
+ # during inference.
+ batch = tf.shape(targets)[0]
+ padded_target_length = tf.shape(targets)[1]
+ if hparams.mode == tf.contrib.learn.ModeKeys.TRAIN:
+ lengths = padded_target_length * tf.to_int32(
+ tf.less(tf.random_uniform([batch]),
+ hparams.answer_length_prob_train))
+ elif hparams.mode == tf.contrib.learn.ModeKeys.EVAL:
+ lengths = 0
+ else:
+ assert hparams.mode == tf.contrib.learn.ModeKeys.INFER
+ lengths = hparams.answer_length_infer
+ lengths = tf.to_int32(tf.log(tf.to_float(lengths + 1)))
+ lengths = tf.zeros([batch], dtype=tf.int32) + lengths
+ ret = tf.gather(
+ tf.get_variable("answer_length", [100, hparams.hidden_size]), lengths)
+ return tf.expand_dims(ret, 1)
+
+
+def long_answer_output(x, inputs):
+ """Strip initial part corresponding to the inputs and the length embedding."""
+ x = tf.slice(x, [0, tf.shape(inputs)[1] + 1, 0], [-1, -1, -1])
+ x = tf.expand_dims(x, 2)
+ return x
+
+
+@registry.register_hparams
+def long_answer_base():
+ """Set of hyperparameters.
+
+ Returns:
+ a hparams object
+ """
+ hparams = common_hparams.basic_params1()
+ hparams.hidden_size = 1024
+ hparams.batch_size = 8192
+ hparams.max_length = 8192
+ hparams.dropout = 0.0
+ hparams.batching_mantissa_bits = 3
+ hparams.clip_grad_norm = 0. # i.e. no gradient clipping
+ hparams.optimizer_adam_epsilon = 1e-9
+ hparams.learning_rate_decay_scheme = "noam"
+ hparams.learning_rate = 0.1
+ hparams.learning_rate_warmup_steps = 1000
+ hparams.initializer_gain = 1.0
+ hparams.num_hidden_layers = 4
+ hparams.initializer = "uniform_unit_scaling"
+ hparams.weight_decay = 0.0
+ hparams.optimizer_adam_beta1 = 0.9
+ hparams.optimizer_adam_beta2 = 0.98
+ hparams.num_sampled_classes = 0
+ hparams.label_smoothing = 0.0
+ hparams.shared_embedding_and_softmax_weights = int(True)
+ hparams.sampling_method = "random"
+ hparams.add_hparam("filter_size", 2048) # Add new ones like this.
+ # comma-separated list of layer numbers.
+ # At each of these layers, we replace the ffn with a mixture of experts.
+ hparams.add_hparam("moe_layers", "2")
+ # If moe_n2 is None, then use a flat MoE with moe_n1 experts.
+ # If moe_n2 is an integer, then use a hierarchical MoE
+ # consisting of moe_n1 groups of moe_n2 experts each.
+ hparams.add_hparam("moe_n1", 64)
+ hparams.add_hparam("moe_n2", 0)
+ hparams.add_hparam("moe_hidden_size", 2048)
+ hparams.add_hparam("moe_loss_coef", 1e-2)
+ # attention-related flags
+ hparams.add_hparam("num_heads", 8)
+ hparams.add_hparam("attention_key_channels", 0)
+ hparams.add_hparam("attention_value_channels", 0)
+ # All hyperparameters ending in "dropout" are automatically set to 0.0
+ # when not in training mode.
+ hparams.add_hparam("attention_dropout", 0.0)
+ hparams.add_hparam("relu_dropout", 0.0)
+ hparams.add_hparam("residual_dropout", 0.0)
+ hparams.add_hparam("pos", "timing") # timing, none
+ hparams.add_hparam("block_length", 512)
+ hparams.add_hparam("answer_length_prob_train", 0.5)
+ hparams.add_hparam("answer_length_infer", 1000)
+  # We cannot handle long sequences at this point, so drop them during eval.
+ # This affects evaluation metrics.
+ # TODO(noam): find a different workaround
+ hparams.eval_drop_long_sequences = int(True)
+ return hparams
+
+
+@registry.register_hparams
+def long_answer_tiny():
+ """Cheap model for validation.
+
+ Returns:
+ an hparams object.
+ """
+ hparams = long_answer_base()
+ hparams.num_hidden_layers = 3
+ hparams.hidden_size = 512
+ hparams.filter_size = 1024
+ hparams.moe_layers = "2"
+ hparams.moe_hidden_size = 1024
+ hparams.block_length = 128
+ hparams.moe_n1 = 8
+ hparams.batch_size = 2048
+ hparams.max_length = 2048
+ return hparams
+
+
+@registry.register_hparams
+def long_answer_small():
+ """Cheap model for single-gpu training.
+
+ Returns:
+ an hparams object.
+ """
+ hparams = long_answer_base()
+ hparams.num_hidden_layers = 4
+ hparams.hidden_size = 512
+ hparams.filter_size = 2048
+ hparams.moe_n1 = 128
+ hparams.moe_layers = "2"
+ hparams.moe_hidden_size = 2048
+ return hparams
+
+
+@registry.register_hparams
+def long_answer_large():
+ """Large model for distributed training.
+
+ Returns:
+ an hparams object.
+ """
+ hparams = long_answer_base()
+ hparams.num_hidden_layers = 5
+ hparams.moe_layers = "3"
+ hparams.hidden_size = 1024
+ hparams.filter_size = 4096
+ hparams.moe_hidden_size = 4096
+ hparams.moe_n1 = 128
+ hparams.block_length = 1024
+ return hparams
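The length embedding above conditions the decoder on a coarse, log-scale bucket of the requested answer length: zero during eval, the padded target length (with probability `answer_length_prob_train`) during training, and `answer_length_infer` at inference time. A worked sketch of the bucketing itself, with illustrative numbers:

```python
# A minimal sketch of the bucket computation in length_embedding above:
# bucket = floor(log(length + 1)), which then indexes a [100, hidden_size]
# embedding table. tf.log is the natural logarithm.
import math

def length_bucket(length):
  return int(math.log(length + 1))

print(length_bucket(0))     # 0 -> eval mode, no length requested
print(length_bucket(100))   # 4
print(length_bucket(1000))  # 6 -> the default answer_length_infer
print(length_bucket(8191))  # 9 -> around max_length in long_answer_base
```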
diff --git a/tensor2tensor/models/models.py b/tensor2tensor/models/models.py
index 0ca11996e..2cf639426 100644
--- a/tensor2tensor/models/models.py
+++ b/tensor2tensor/models/models.py
@@ -26,6 +26,7 @@
from tensor2tensor.models import attention_lm_moe
from tensor2tensor.models import bluenet
from tensor2tensor.models import bytenet
+from tensor2tensor.models import long_answer
from tensor2tensor.models import lstm
from tensor2tensor.models import modalities
from tensor2tensor.models import multimodel
diff --git a/tensor2tensor/models/multimodel.py b/tensor2tensor/models/multimodel.py
index 6f12db86d..bf06dfd65 100644
--- a/tensor2tensor/models/multimodel.py
+++ b/tensor2tensor/models/multimodel.py
@@ -138,7 +138,6 @@ def flatten(inputs):
hparams.hidden_size,
hparams.num_heads,
hparams.attention_dropout,
- summaries=False,
name="decoder_self_attention")
z = dp(common_attention.multihead_attention,
y,
@@ -149,7 +148,6 @@ def flatten(inputs):
hparams.hidden_size,
hparams.num_heads,
hparams.attention_dropout,
- summaries=False,
name="encdec_attention")
x = dp(residual_fn3, x, y, z, hparams)
with tf.variable_scope("ffn"):
diff --git a/tensor2tensor/models/slicenet.py b/tensor2tensor/models/slicenet.py
index 43913eab1..2ad4c89d1 100644
--- a/tensor2tensor/models/slicenet.py
+++ b/tensor2tensor/models/slicenet.py
@@ -64,8 +64,7 @@ def attention(targets_shifted, inputs_encoded, norm_fn, hparams, bias=None):
hparams.hidden_size,
hparams.num_heads,
hparams.attention_dropout,
- name="self_attention",
- summaries=False)
+ name="self_attention")
qv = common_attention.multihead_attention(
qv,
inputs_encoded,
@@ -75,12 +74,11 @@ def attention(targets_shifted, inputs_encoded, norm_fn, hparams, bias=None):
hparams.hidden_size,
hparams.num_heads,
hparams.attention_dropout,
- name="encdec_attention",
- summaries=False)
+ name="encdec_attention")
return tf.expand_dims(qv, 2)
elif hparams.attention_type == "simple":
targets_with_attention = common_layers.simple_attention(
- targets_timed, inputs_encoded, bias=bias, summaries=False)
+ targets_timed, inputs_encoded, bias=bias)
return norm_fn(targets_shifted + targets_with_attention, name="attn_norm")
diff --git a/tensor2tensor/models/transformer.py b/tensor2tensor/models/transformer.py
index b341d6fe0..b24f7fa50 100644
--- a/tensor2tensor/models/transformer.py
+++ b/tensor2tensor/models/transformer.py
@@ -60,8 +60,6 @@ def residual_fn(x, y):
return common_layers.layer_norm(x + tf.nn.dropout(
y, 1.0 - hparams.residual_dropout))
- # encoder_input = tf.squeeze(encoder_input, 2)
- # decoder_input = tf.squeeze(decoder_input, 2)
encoder_input = tf.nn.dropout(encoder_input, 1.0 - hparams.residual_dropout)
decoder_input = tf.nn.dropout(decoder_input, 1.0 - hparams.residual_dropout)
encoder_output = transformer_encoder(encoder_input, residual_fn,
@@ -145,8 +143,6 @@ def transformer_encoder(encoder_input,
y: a Tensors
"""
x = encoder_input
- # Summaries don't work in multi-problem setting yet.
- summaries = "problems" not in hparams.values() or len(hparams.problems) == 1
with tf.variable_scope(name):
for layer in xrange(hparams.num_hidden_layers):
with tf.variable_scope("layer_%d" % layer):
@@ -161,7 +157,6 @@ def transformer_encoder(encoder_input,
hparams.hidden_size,
hparams.num_heads,
hparams.attention_dropout,
- summaries=summaries,
name="encoder_self_attention"))
x = residual_fn(x, transformer_ffn_layer(x, hparams))
return x
@@ -191,8 +186,6 @@ def transformer_decoder(decoder_input,
y: a Tensors
"""
x = decoder_input
- # Summaries don't work in multi-problem setting yet.
- summaries = "problems" not in hparams.values() or len(hparams.problems) == 1
with tf.variable_scope(name):
for layer in xrange(hparams.num_hidden_layers):
with tf.variable_scope("layer_%d" % layer):
@@ -207,7 +200,6 @@ def transformer_decoder(decoder_input,
hparams.hidden_size,
hparams.num_heads,
hparams.attention_dropout,
- summaries=summaries,
name="decoder_self_attention"))
x = residual_fn(
x,
@@ -220,7 +212,6 @@ def transformer_decoder(decoder_input,
hparams.hidden_size,
hparams.num_heads,
hparams.attention_dropout,
- summaries=summaries,
name="encdec_attention"))
x = residual_fn(x, transformer_ffn_layer(x, hparams))
return x
diff --git a/tensor2tensor/models/transformer_alternative.py b/tensor2tensor/models/transformer_alternative.py
index aed074d56..280dbc713 100644
--- a/tensor2tensor/models/transformer_alternative.py
+++ b/tensor2tensor/models/transformer_alternative.py
@@ -140,8 +140,6 @@ def alt_transformer_decoder(decoder_input,
"""Alternative decoder."""
x = decoder_input
- # Summaries don't work in multi-problem setting yet.
- summaries = "problems" not in hparams.values() or len(hparams.problems) == 1
with tf.variable_scope(name):
for layer in xrange(hparams.num_hidden_layers):
with tf.variable_scope("layer_%d" % layer):
@@ -155,7 +153,6 @@ def alt_transformer_decoder(decoder_input,
hparams.hidden_size,
hparams.num_heads,
hparams.attention_dropout,
- summaries=summaries,
name="encdec_attention")
x_ = residual_fn(x_, composite_layer(x_, mask, hparams))
diff --git a/tensor2tensor/models/transformer_test.py b/tensor2tensor/models/transformer_test.py
index ca099c653..997b5d172 100644
--- a/tensor2tensor/models/transformer_test.py
+++ b/tensor2tensor/models/transformer_test.py
@@ -38,6 +38,7 @@ def _testTransformer(self, net):
hparams = transformer.transformer_tiny()
p_hparams = problem_hparams.test_problem_hparams(hparams, vocab_size,
vocab_size)
+ hparams.problems = [p_hparams]
inputs = -1 + np.random.random_integers(
vocab_size, size=(batch_size, input_length, 1, 1))
targets = -1 + np.random.random_integers(
diff --git a/tensor2tensor/utils/data_reader.py b/tensor2tensor/utils/data_reader.py
index a3e9835ac..cb84b9e3e 100644
--- a/tensor2tensor/utils/data_reader.py
+++ b/tensor2tensor/utils/data_reader.py
@@ -181,10 +181,13 @@ def input_pipeline(data_file_pattern, capacity, mode):
"""Input pipeline, returns a dictionary of tensors from queues."""
# Read from image TFRecords if the file has "image" in its name.
if data_file_pattern and "image" in data_file_pattern:
+ label_key = "image/class/label"
+ if "fsns" in data_file_pattern:
+ label_key = "image/unpadded_label"
data_fields = {
"image/encoded": tf.FixedLenFeature((), tf.string),
"image/format": tf.FixedLenFeature((), tf.string),
- "image/class/label": tf.VarLenFeature(tf.int64)
+ label_key: tf.VarLenFeature(tf.int64)
}
data_items_to_decoders = {
"inputs":
@@ -193,7 +196,7 @@ def input_pipeline(data_file_pattern, capacity, mode):
format_key="image/format",
channels=1 if "mnist" in data_file_pattern else 3),
"targets":
- tf.contrib.slim.tfexample_decoder.Tensor("image/class/label"),
+ tf.contrib.slim.tfexample_decoder.Tensor(label_key),
}
elif data_file_pattern and "audio" in data_file_pattern:
data_type = tf.int64 if "timit" in data_file_pattern else tf.float32
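The label key is now chosen from the data file pattern, so FSNS-style records carrying `image/unpadded_label` decode through the same pipeline as the other image datasets. A small sketch of the selection logic in isolation (the example file patterns are hypothetical):

```python
# A minimal sketch of the label-key selection added above.
def image_label_key(data_file_pattern):
  if "fsns" in data_file_pattern:
    return "image/unpadded_label"
  return "image/class/label"

assert image_label_key("/tmp/t2t_data/image_fsns-train*") == "image/unpadded_label"
assert image_label_key("/tmp/t2t_data/image_mnist-train*") == "image/class/label"
```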
diff --git a/tensor2tensor/utils/trainer_utils.py b/tensor2tensor/utils/trainer_utils.py
old mode 100755
new mode 100644
index 66a01487c..f7d3010a9
--- a/tensor2tensor/utils/trainer_utils.py
+++ b/tensor2tensor/utils/trainer_utils.py
@@ -550,6 +550,13 @@ def nth_model(n):
optimizer=opt,
colocate_gradients_with_ops=True)
+ # Remove summaries that will fail to run because they are in conditionals.
+ # TODO(cwhipkey): Test with this code removed, later in 2017.
+ summaries = tf.get_collection_ref(tf.GraphKeys.SUMMARIES)
+ for i in range(len(summaries)-1, -1, -1):
+ if summaries[i].name.startswith("cond_"):
+ del summaries[i]
+
tf.logging.info("Global model_fn finished.")
return run_info, total_loss, train_op
@@ -1037,7 +1044,10 @@ def input_fn():
capacity *= num_datashards
examples = data_reader.input_pipeline(data_file_patterns[n],
capacity, mode)
- drop_long_sequences = mode == tf.contrib.learn.ModeKeys.TRAIN
+ if mode == tf.contrib.learn.ModeKeys.TRAIN:
+ drop_long_sequences = True
+ else:
+ drop_long_sequences = hparams.eval_drop_long_sequences
batch_size_multiplier = hparams.problems[n].batch_size_multiplier
feature_map = data_reader.batch_examples(
examples,
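Pruning the `cond_`-prefixed summaries works directly on the graph's summary collection, which `tf.get_collection_ref` returns as a mutable list; iterating backwards keeps the indices valid while deleting. A standalone sketch of the same pattern on a fresh graph (the summary names are hypothetical; a name scope stands in for a real `tf.cond` branch):

```python
# A minimal sketch of the summary pruning added in nth_model above (TF 1.x).
import tensorflow as tf

tf.summary.scalar("loss", tf.constant(1.0))
with tf.name_scope("cond_1"):
  # Stands in for a summary created inside a tf.cond branch.
  tf.summary.scalar("inside_conditional", tf.constant(2.0))

summaries = tf.get_collection_ref(tf.GraphKeys.SUMMARIES)
for i in range(len(summaries) - 1, -1, -1):
  if summaries[i].name.startswith("cond_"):
    del summaries[i]

print([s.name for s in summaries])  # only the top-level "loss" summary remains
```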
diff --git a/tensor2tensor/utils/usr_dir.py b/tensor2tensor/utils/usr_dir.py
index ed5623c8e..0a2d0d15c 100644
--- a/tensor2tensor/utils/usr_dir.py
+++ b/tensor2tensor/utils/usr_dir.py
@@ -12,16 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-"""Utility to load code from an external directory supplied by user."""
+"""Utility to load code from an external user-supplied directory."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import importlib
import os
import sys
-import importlib
+
+# Dependency imports
+
import tensorflow as tf
def import_usr_dir(usr_dir):
- """Import user module, if provided."""
+ """Import module at usr_dir, if provided."""
if not usr_dir:
return
dir_path = os.path.expanduser(usr_dir)