From b88c13b1b121e0924068564943d74bd8a3406383 Mon Sep 17 00:00:00 2001 From: Lukasz Kaiser Date: Wed, 5 Jul 2017 19:55:29 -0700 Subject: [PATCH] Correct text encoder, MultiModel, other merges. PiperOrigin-RevId: 161036600 --- setup.py | 2 +- .../data_generators/generator_utils.py | 17 +- tensor2tensor/data_generators/text_encoder.py | 17 +- tensor2tensor/models/bluenet.py | 402 ++++++++++++++++-- tensor2tensor/models/common_layers.py | 40 +- tensor2tensor/models/common_layers_test.py | 4 +- tensor2tensor/models/lstm.py | 206 ++++++++- tensor2tensor/models/lstm_test.py | 24 ++ tensor2tensor/models/multimodel.py | 258 ++++++----- tensor2tensor/models/multimodel_test.py | 3 +- tensor2tensor/utils/metrics.py | 22 +- 11 files changed, 796 insertions(+), 199 deletions(-) diff --git a/setup.py b/setup.py index ba3ea532a..254631d9f 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setup( name='tensor2tensor', - version='1.0.9', + version='1.0.10', description='Tensor2Tensor', author='Google Inc.', author_email='no-reply@google.com', diff --git a/tensor2tensor/data_generators/generator_utils.py b/tensor2tensor/data_generators/generator_utils.py index 8c2d75fbe..a5d4816b7 100644 --- a/tensor2tensor/data_generators/generator_utils.py +++ b/tensor2tensor/data_generators/generator_utils.py @@ -29,7 +29,7 @@ from six.moves import xrange # pylint: disable=redefined-builtin import six.moves.urllib_request as urllib # Imports urllib on Python2, urllib.request on Python3 -from tensor2tensor.data_generators.text_encoder import SubwordTextEncoder +from tensor2tensor.data_generators import text_encoder from tensor2tensor.data_generators.tokenizer import Tokenizer import tensorflow as tf @@ -218,15 +218,18 @@ def gunzip_file(gz_path, new_path): ] -def get_or_generate_vocab(tmp_dir, vocab_filename, vocab_size): - """Generate a vocabulary from the datasets listed in _DATA_FILE_URLS.""" +def get_or_generate_vocab(tmp_dir, vocab_filename, vocab_size, sources=None): + """Generate a vocabulary from the datasets in sources (_DATA_FILE_URLS).""" vocab_filepath = os.path.join(tmp_dir, vocab_filename) if os.path.exists(vocab_filepath): - vocab = SubwordTextEncoder(vocab_filepath) + tf.logging.info("Found vocab file: %s", vocab_filepath) + vocab = text_encoder.SubwordTextEncoder(vocab_filepath) return vocab + sources = sources or _DATA_FILE_URLS + tf.logging.info("Generating vocab from: %s", str(sources)) tokenizer = Tokenizer() - for source in _DATA_FILE_URLS: + for source in sources: url = source[0] filename = os.path.basename(url) read_type = "r:gz" if "tgz" in filename else "r" @@ -259,9 +262,9 @@ def get_or_generate_vocab(tmp_dir, vocab_filename, vocab_size): break line = line.strip() file_byte_budget -= len(line) - _ = tokenizer.encode(line) + _ = tokenizer.encode(text_encoder.native_to_unicode(line)) - vocab = SubwordTextEncoder.build_to_target_size( + vocab = text_encoder.SubwordTextEncoder.build_to_target_size( vocab_size, tokenizer.token_counts, 1, 1e3) vocab.store_to_file(vocab_filepath) return vocab diff --git a/tensor2tensor/data_generators/text_encoder.py b/tensor2tensor/data_generators/text_encoder.py index 7b00a85d2..0a05cb721 100644 --- a/tensor2tensor/data_generators/text_encoder.py +++ b/tensor2tensor/data_generators/text_encoder.py @@ -36,10 +36,10 @@ # Conversion between Unicode and UTF-8, if required (on Python2) -_native_to_unicode = (lambda s: s.decode("utf-8")) if PY2 else (lambda s: s) +native_to_unicode = (lambda s: s.decode("utf-8")) if PY2 else (lambda s: s) -_unicode_to_native = (lambda s: 
s.encode("utf-8")) if PY2 else (lambda s: s) +unicode_to_native = (lambda s: s.encode("utf-8")) if PY2 else (lambda s: s) # Reserved tokens for things like padding and EOS symbols. @@ -220,7 +220,7 @@ def encode(self, raw_text): a list of integers in the range [0, vocab_size) """ return self._tokens_to_subtokens(self._tokenizer.encode( - _native_to_unicode(raw_text))) + native_to_unicode(raw_text))) def decode(self, subtokens): """Converts a sequence of subtoken ids to a native string. @@ -230,7 +230,7 @@ def decode(self, subtokens): Returns: a native string """ - return _unicode_to_native(self._tokenizer.decode( + return unicode_to_native(self._tokenizer.decode( self._subtokens_to_tokens(subtokens))) @property @@ -335,6 +335,9 @@ def bisect(min_val, max_val): else: other_subtokenizer = bisect(min_val, present_count - 1) + if other_subtokenizer is None: + return subtokenizer + if (abs(other_subtokenizer.vocab_size - target_size) < abs(subtokenizer.vocab_size - target_size)): return other_subtokenizer @@ -449,13 +452,13 @@ def _load_from_file(self, filename): subtoken_strings = [] with tf.gfile.Open(filename) as f: for line in f: - subtoken_strings.append(_native_to_unicode(line.strip()[1:-1])) + subtoken_strings.append(native_to_unicode(line.strip()[1:-1])) self._init_from_list(subtoken_strings) def store_to_file(self, filename): with tf.gfile.Open(filename, "w") as f: for subtoken_string in self._all_subtoken_strings: - f.write("'" + _unicode_to_native(subtoken_string) + "'\n") + f.write("'" + unicode_to_native(subtoken_string) + "'\n") def _escape_token(self, token): r"""Escape away underscores and OOV characters and append '_'. @@ -524,7 +527,7 @@ def get_token_counts(cls, text_filepattern, corpus_max_lines): with tf.gfile.Open(text_filename) as f: for line in f: # The tokenizer updates token_counts in encode() - tok.encode(_native_to_unicode(line.strip())) + tok.encode(native_to_unicode(line.strip())) lines_read += 1 if corpus_max_lines > 0 and lines_read > corpus_max_lines: return tok.token_counts diff --git a/tensor2tensor/models/bluenet.py b/tensor2tensor/models/bluenet.py index 8f4c89eac..c0533ee42 100644 --- a/tensor2tensor/models/bluenet.py +++ b/tensor2tensor/models/bluenet.py @@ -18,8 +18,12 @@ from __future__ import division from __future__ import print_function +import collections + # Dependency imports +import numpy as np + from six.moves import xrange # pylint: disable=redefined-builtin from tensor2tensor.models import common_hparams @@ -30,6 +34,328 @@ import tensorflow as tf +# var: 1d tensor, raw weights for each choice +# tempered_var: raw weights with temperature applied +# inv_t: inverse of the temperature to use when normalizing `var` +# normalized: same shape as var, but where each item is between 0 and 1, and +# the sum is 1 +SelectionWeights = collections.namedtuple( + "SelectionWeights", ["var", "tempered_var", "inv_t", "normalized"]) + + +def create_selection_weights(name, + type_, + shape, + inv_t=1, + initializer=tf.zeros_initializer(), + regularizer=None, + names=None): + """Create a SelectionWeights tuple. + + Args: + name: Name for the underlying variable containing the unnormalized weights. + type_: "softmax" or "sigmoid" or ("softmax_topk", k) where k is an int. + shape: Shape for the variable. + inv_t: Inverse of the temperature to use in normalization. + initializer: Initializer for the variable, passed to `tf.get_variable`. + regularizer: Regularizer for the variable. A callable which accepts + `tempered_var` and `normalized`. 
+    names: Name of each selection.
+
+  Returns:
+    The created SelectionWeights tuple.
+
+  Raises:
+    ValueError: if type_ is not in the supported range.
+  """
+  var = tf.get_variable(name, shape, initializer=initializer)
+
+  if callable(inv_t):
+    inv_t = inv_t(var)
+  if inv_t == 1:
+    tempered_var = var
+  else:
+    tempered_var = var * inv_t
+
+  if type_ == "softmax":
+    weights = tf.nn.softmax(tempered_var)
+  elif type_ == "sigmoid":
+    weights = tf.nn.sigmoid(tempered_var)
+  elif isinstance(type_, (list, tuple)) and type_[0] == "softmax_topk":
+    assert len(shape) == 1
+    # TODO(rshin): Change this to select without replacement?
+    selection = tf.multinomial(tf.expand_dims(var, axis=0), type_[1])
+    selection = tf.squeeze(selection, axis=0)  # [k] selected classes.
+    to_run = tf.one_hot(selection, shape[0])  # [k x nmodules] one-hot.
+    # [nmodules], 0=not run, 1=run.
+    to_run = tf.minimum(tf.reduce_sum(to_run, axis=0), 1)
+    weights = tf.nn.softmax(tempered_var - 1e9 * (1.0 - to_run))
+  else:
+    raise ValueError("Unknown type: %s" % type_)
+
+  if regularizer is not None:
+    loss = regularizer(tempered_var, weights)
+    if loss is not None:
+      tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, loss)
+
+  if names is not None:
+    tf.get_collection_ref("selection_weight_names/" + var.name).extend(
+        names.flatten() if isinstance(names, np.ndarray) else names)
+    tf.add_to_collection("selection_weight_names_tensor/" + var.name,
+                         tf.constant(names))
+
+  return SelectionWeights(
+      var=var,
+      tempered_var=tempered_var,
+      inv_t=inv_t,
+      normalized=weights)
+
+
+def kernel_premultiplier(max_kernel_size, kernel_sizes, input_channels,
+                         kernel_selection_weights, channel_selection_weights):
+  """Get weights to multiply the kernel with, before convolving.
+
+  Args:
+    max_kernel_size: (int, int) tuple giving the largest kernel size.
+    kernel_sizes: A list of (height, width) pairs of integers, containing
+      different kernel sizes to use.
+    input_channels: A list of (begin, end) pairs of integers, which describe
+      which channels in the input to use.
+    kernel_selection_weights: SelectionWeights object to use for choosing
+      among kernel sizes.
+    channel_selection_weights: SelectionWeights object to use for choosing
+      among which input channels to use.
+
+  Returns:
+    The multiplier.
+  """
+  kernel_weights = []
+  for kernel_i, (h, w) in enumerate(kernel_sizes):
+    top = (max_kernel_size[0] - h) // 2
+    bot = max_kernel_size[0] - h - top
+    left = (max_kernel_size[1] - w) // 2
+    right = max_kernel_size[1] - w - left
+    kernel_weight = tf.fill((h, w),
+                            kernel_selection_weights.normalized[kernel_i])
+    if top != 0 or bot != 0 or left != 0 or right != 0:
+      kernel_weight = tf.pad(kernel_weight, [[top, bot], [left, right]])
+    kernel_weights.append(kernel_weight)
+  kernel_weight = tf.add_n(kernel_weights)
+
+  channel_weights = []
+  min_channel = np.min(input_channels)
+  max_channel = np.max(input_channels)
+  for channel_i, (begin, end) in enumerate(input_channels):
+    channel_weight = tf.pad(
+        tf.fill((end - begin,),
+                channel_selection_weights.normalized[channel_i]),
+        [[begin - min_channel, max_channel - end]])
+    channel_weights.append(channel_weight)
+  channel_weight = tf.add_n(channel_weights)
+
+  multiplier = (tf.reshape(kernel_weight, max_kernel_size + (1, 1)) *
+                tf.reshape(channel_weight, (1, 1, -1, 1)))
+  return multiplier
+
+
+def make_subseparable_kernel(
+    kernel_size,
+    input_channels,
+    filters,
+    separability,
+    kernel_initializer,
+    kernel_regularizer):
+  """Make a kernel to do subseparable convolution with `tf.nn.conv2d`.
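+
+  With separability 1 the kernel is a single dense [h, w, in, out] variable.
+  With separability 0 or -1 it is assembled from a depthwise and a pointwise
+  kernel, multiplied together so the result can still be fed to one
+  `tf.nn.conv2d` call. Group-wise separability (abs(separability) >= 2) is
+  not implemented yet.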
+ + Args: + kernel_size: (height, width) tuple. + input_channels: Number of input channels. + filters: Number of output channels. + separability: Integer denoting separability. + kernel_initializer: Initializer to use for the kernel. + kernel_regularizer: Regularizer to use for the kernel. + + Returns: + A 4D tensor. + """ + if separability == 1: + # Non-separable convolution + return tf.get_variable( + "kernel", + kernel_size + (input_channels, filters), + initializer=kernel_initializer, + regularizer=kernel_regularizer) + + elif separability == 0 or separability == -1: + # Separable convolution + # TODO(rshin): Check initialization is as expected, as these are not 4D. + depthwise_kernel = tf.get_variable( + "depthwise_kernel", + kernel_size + (input_channels,), + initializer=kernel_initializer, + regularizer=kernel_regularizer) + + pointwise_kernel = tf.get_variable( + "pointwise_kernel", + (input_channels, filters), + initializer=kernel_initializer, + regularizer=kernel_regularizer) + + expanded_depthwise_kernel = tf.transpose( + tf.scatter_nd( + indices=tf.tile( + tf.expand_dims( + tf.range(0, input_channels), axis=1), [1, 2]), + updates=tf.transpose(depthwise_kernel, (2, 0, 1)), + shape=(input_channels, input_channels) + kernel_size), (2, 3, 0, 1)) + + return tf.reshape( + tf.matmul( + tf.reshape(expanded_depthwise_kernel, (-1, input_channels)), + pointwise_kernel), kernel_size + (input_channels, filters)) + + elif separability >= 2: + assert filters % separability == 0, (filters, separability) + assert input_channels % separability == 0, (filters, separability) + + raise NotImplementedError + + elif separability <= -2: + separability *= -1 + assert filters % separability == 0, (filters, separability) + assert input_channels % separability == 0, (filters, separability) + + raise NotImplementedError + + +def multi_subseparable_conv( + inputs, + filters, + kernel_sizes, + input_channels, + separabilities, + kernel_selection_weights=None, + channel_selection_weights=None, + separability_selection_weights=None, + kernel_selection_weights_params=None, + channel_selection_weights_params=None, + separability_selection_weights_params=None, + kernel_initializer=None, + kernel_regularizer=None, + scope=None): + """Simultaneously compute different kinds of convolutions on subsets of input. + + Args: + inputs: 4D tensor containing the input, in NHWC format. + filters: Integer, number of output channels. + kernel_sizes: A list of (height, width) pairs of integers, containing + different kernel sizes to use. + input_channels: A list of (begin, end) pairs of integers, which describe + which channels in the input to use. + separabilities: An integer or a list, how separable are the convolutions. + kernel_selection_weights: SelectionWeights object to use for choosing + among kernel sizes. + channel_selection_weights: SelectionWeights object to use for choosing + among which input channels to use. + separability_selection_weights: SelectionWeights object to use for choosing + separability. + kernel_selection_weights_params: dict with up to three keys + - initializer + - regularizer + - inv_t + channel_selection_weights_params: dict with up to three keys + - initializer + - regularizer + - inv_t + separability_selection_weights_params: dict with up to three keys + - initializer + - regularizer + - inv_t + kernel_initializer: Initializer to use for kernels. + kernel_regularizer: Regularizer to use for kernels. + scope: the scope to use. + + Returns: + Result of convolution. 
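+
+  Example (illustrative only; sizes are made up):
+    Soft-select between 3x3 and 5x5 kernels, and between separable (0) and
+    non-separable (1) convolutions, over all 64 input channels:
+
+      y = multi_subseparable_conv(x, 64, kernel_sizes=[(3, 3), (5, 5)],
+                                  input_channels=[(0, 64)],
+                                  separabilities=[0, 1])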
+ """ + kernel_selection_weights_params = kernel_selection_weights_params or {} + channel_selection_weights_params = channel_selection_weights_params or {} + if separability_selection_weights_params is None: + separability_selection_weights_params = {} + + # Get input image size. + input_shape = inputs.get_shape().as_list() + assert len(input_shape) == 4 + in_channels = input_shape[3] + assert in_channels is not None + + max_kernel_size = tuple(np.max(kernel_sizes, axis=0)) + max_num_channels = np.max(input_channels) - np.min(input_channels) + + with tf.variable_scope(scope or "selection_weights"): + if kernel_selection_weights is None: + kernel_selection_weights = create_selection_weights( + "kernels", + "softmax", (len(kernel_sizes),), + names=[ + "kernel_h{}_w{}".format(h, w) for h, w in kernel_sizes + ], + **kernel_selection_weights_params) + + if channel_selection_weights is None: + channel_selection_weights = create_selection_weights( + "channels", + "softmax", (len(input_channels),), + names=[ + "channels_{}_{}".format(c1, c2) for c1, c2 in input_channels + ], + **channel_selection_weights_params) + + if separability_selection_weights is None: + separability_selection_weights = create_selection_weights( + "separability", + "softmax", (len(separabilities),), + names=[ + "separability_{}".format(s) for s in separabilities + ], + **separability_selection_weights_params) + + kernels = [] + for separability in separabilities: + with tf.variable_scope("separablity_{}".format(separability)): + kernel = make_subseparable_kernel( + max_kernel_size, + max_num_channels, + filters, + separability, + kernel_initializer, + kernel_regularizer) + + premultiplier = kernel_premultiplier( + max_kernel_size, kernel_sizes, input_channels, + kernel_selection_weights, + channel_selection_weights) + + kernels.append(kernel * premultiplier) + + kernel = tf.add_n([ + separability_selection_weights.normalized[i] * k + for i, k in enumerate(kernels) + ]) + + if np.min(input_channels) != 0 or np.max(input_channels) != in_channels: + inputs = inputs[:, :, :, np.min(input_channels):np.max(input_channels)] + + return tf.nn.conv2d( + inputs, + filter=kernel, + strides=[1, 1, 1, 1], + padding="SAME", + data_format="NHWC", + name="conv2d") + + def conv_module(kw, kh, sep, div): def convfn(x, hparams): return common_layers.subseparable_conv( @@ -39,6 +365,13 @@ def convfn(x, hparams): return convfn +def multi_conv_module(kernel_sizes, seps): + def convfn(x, hparams): + return multi_subseparable_conv(x, hparams.hidden_size, kernel_sizes, + [(0, hparams.hidden_size)], seps) + return convfn + + def layernorm_module(x, hparams): return common_layers.layer_norm(x, hparams.hidden_size, name="layer_norm") @@ -75,47 +408,46 @@ def shakeshake_binary_module(x, y, hparams): def run_binary_modules(modules, cur1, cur2, hparams): """Run binary modules.""" - selection_var = tf.get_variable("selection", [len(modules)], - initializer=tf.zeros_initializer()) - inv_t = 100.0 * common_layers.inverse_exp_decay( - hparams.anneal_until, min_value=0.01) - selected_weights = tf.nn.softmax(selection_var * inv_t) + selection_weights = create_selection_weights( + "selection", + "softmax", + shape=[len(modules)], + inv_t=100.0 * common_layers.inverse_exp_decay( + hparams.anneal_until, min_value=0.01)) all_res = [modules[n](cur1, cur2, hparams) for n in xrange(len(modules))] all_res = tf.concat([tf.expand_dims(r, axis=0) for r in all_res], axis=0) - res = all_res * tf.reshape(selected_weights, [-1, 1, 1, 1, 1]) + res = all_res * 
tf.reshape(selection_weights.normalized, [-1, 1, 1, 1, 1]) return tf.reduce_sum(res, axis=0) def run_unary_modules_basic(modules, cur, hparams): """Run unary modules.""" - selection_var = tf.get_variable("selection", [len(modules)], - initializer=tf.zeros_initializer()) - inv_t = 100.0 * common_layers.inverse_exp_decay( - hparams.anneal_until, min_value=0.01) - selected_weights = tf.nn.softmax(selection_var * inv_t) + selection_weights = create_selection_weights( + "selection", + "softmax", + shape=[len(modules)], + inv_t=100.0 * common_layers.inverse_exp_decay( + hparams.anneal_until, min_value=0.01)) all_res = [modules[n](cur, hparams) for n in xrange(len(modules))] all_res = tf.concat([tf.expand_dims(r, axis=0) for r in all_res], axis=0) - res = all_res * tf.reshape(selected_weights, [-1, 1, 1, 1, 1]) + res = all_res * tf.reshape(selection_weights.normalized, [-1, 1, 1, 1, 1]) return tf.reduce_sum(res, axis=0) def run_unary_modules_sample(modules, cur, hparams, k): """Run modules, sampling k.""" - selection_var = tf.get_variable("selection", [len(modules)], - initializer=tf.zeros_initializer()) - selection = tf.multinomial(tf.expand_dims(selection_var, axis=0), k) - selection = tf.squeeze(selection, axis=0) # [k] selected classes. - to_run = tf.one_hot(selection, len(modules)) # [k x nmodules] one-hot. - to_run = tf.reduce_sum(to_run, axis=0) # [nmodules], 0=not run, 1=run. - all_res = [tf.cond(tf.less(to_run[n], 0.1), + selection_weights = create_selection_weights( + "selection", + ("softmax_topk", k), + shape=[len(modules)], + inv_t=100.0 * common_layers.inverse_exp_decay( + hparams.anneal_until, min_value=0.01)) + all_res = [tf.cond(tf.less(selection_weights.normalized[n], 1e-6), lambda: tf.zeros_like(cur), lambda i=n: modules[i](cur, hparams)) for n in xrange(len(modules))] - inv_t = 100.0 * common_layers.inverse_exp_decay( - hparams.anneal_until, min_value=0.01) - selected_weights = tf.nn.softmax(selection_var * inv_t - 1e9 * (1.0 - to_run)) all_res = tf.concat([tf.expand_dims(r, axis=0) for r in all_res], axis=0) - res = all_res * tf.reshape(selected_weights, [-1, 1, 1, 1, 1]) + res = all_res * tf.reshape(selection_weights.normalized, [-1, 1, 1, 1, 1]) return tf.reduce_sum(res, axis=0) @@ -138,10 +470,10 @@ class BlueNet(t2t_model.T2TModel): def model_fn_body(self, features): hparams = self._hparams - conv_modules = [conv_module(kw, kw, sep, div) - for kw in [3, 5, 7] - for sep in [0, 1] - for div in [1]] + [identity_module] + # TODO(rshin): Give identity_module lower weight by default. 
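+    # How the module mixing works (rough sketch of the code above): every
+    # run_*_modules call builds its own SelectionWeights under the current
+    # variable scope, along the lines of
+    #   sw = create_selection_weights(
+    #       "selection", "softmax", shape=[len(modules)],
+    #       inv_t=100.0 * common_layers.inverse_exp_decay(
+    #           hparams.anneal_until, min_value=0.01))
+    # and the module outputs are summed, weighted by sw.normalized. As inv_t
+    # anneals upward, the softmax sharpens toward a discrete module choice.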
+    multi_conv = multi_conv_module(
+        kernel_sizes=[(3, 3), (5, 5), (7, 7)], seps=[0, 1])
+    conv_modules = [multi_conv, identity_module]
     activation_modules = [identity_module,
                           lambda x, _: tf.nn.relu(x),
                           lambda x, _: tf.nn.elu(x),
@@ -166,20 +498,24 @@ def run_unary(x, name):
         x.set_shape(x_shape)
         return tf.nn.dropout(x, 1.0 - hparams.dropout), batch_deviation(x)
 
-    cur1, cur2, extra_loss = inputs, inputs, 0.0
+    cur1, cur2, cur3, extra_loss = inputs, inputs, inputs, 0.0
     cur_shape = inputs.get_shape()
     for i in xrange(hparams.num_hidden_layers):
      with tf.variable_scope("layer_%d" % i):
        cur1, loss1 = run_unary(cur1, "unary1")
        cur2, loss2 = run_unary(cur2, "unary2")
-        extra_loss += (loss1 + loss2) / float(hparams.num_hidden_layers)
+        cur3, loss3 = run_unary(cur3, "unary3")
+        extra_loss += (loss1 + loss2 + loss3) / float(hparams.num_hidden_layers)
        with tf.variable_scope("binary1"):
          next1 = run_binary_modules(binary_modules, cur1, cur2, hparams)
          next1.set_shape(cur_shape)
        with tf.variable_scope("binary2"):
-          next2 = run_binary_modules(binary_modules, cur1, cur2, hparams)
+          next2 = run_binary_modules(binary_modules, cur1, cur3, hparams)
          next2.set_shape(cur_shape)
-        cur1, cur2 = next1, next2
+        with tf.variable_scope("binary3"):
+          next3 = run_binary_modules(binary_modules, cur2, cur3, hparams)
+          next3.set_shape(cur_shape)
+        cur1, cur2, cur3 = next1, next2, next3
 
     anneal = common_layers.inverse_exp_decay(hparams.anneal_until)
     extra_loss *= hparams.batch_deviation_loss_factor * anneal
@@ -193,7 +529,7 @@ def bluenet_base():
   hparams.batch_size = 4096
   hparams.hidden_size = 256
   hparams.dropout = 0.2
-  hparams.symbol_dropout = 0.2
+  hparams.symbol_dropout = 0.5
   hparams.label_smoothing = 0.1
   hparams.clip_grad_norm = 2.0
   hparams.num_hidden_layers = 8
@@ -211,7 +547,7 @@ def bluenet_base():
   hparams.optimizer_adam_beta2 = 0.997
   hparams.add_hparam("imagenet_use_2d", True)
   hparams.add_hparam("anneal_until", 40000)
-  hparams.add_hparam("batch_deviation_loss_factor", 0.001)
+  hparams.add_hparam("batch_deviation_loss_factor", 5.0)
   return hparams
 
 
diff --git a/tensor2tensor/models/common_layers.py b/tensor2tensor/models/common_layers.py
index 2e2b74268..7a6ce96fb 100644
--- a/tensor2tensor/models/common_layers.py
+++ b/tensor2tensor/models/common_layers.py
@@ -293,7 +293,6 @@ def conv_internal(conv_fn, inputs, filters, kernel_size, **kwargs):
   static_shape = inputs.get_shape()
   if not static_shape or len(static_shape) != 4:
     raise ValueError("Inputs to conv must have statically known rank 4.")
-  inputs.set_shape([static_shape[0], None, None, static_shape[3]])
   # Add support for left padding.
   if "padding" in kwargs and kwargs["padding"] == "LEFT":
     dilation_rate = (1, 1)
@@ -307,9 +306,9 @@ def conv_internal(conv_fn, inputs, filters, kernel_size, **kwargs):
       width_padding = 0 if static_shape[2] == 1 else cond_padding
       padding = [[0, 0], [height_padding, 0], [width_padding, 0], [0, 0]]
       inputs = tf.pad(inputs, padding)
+      # Set middle two dimensions to None to keep convolution from complaining.
+      inputs.set_shape([static_shape[0], None, None, static_shape[3]])
       kwargs["padding"] = "VALID"
-  # Special argument we use to force 2d kernels (see below).
- force2d = kwargs.get("force2d", True) def conv2d_kernel(kernel_size_arg, name_suffix): """Call conv2d but add suffix to name.""" @@ -329,17 +328,7 @@ def conv2d_kernel(kernel_size_arg, name_suffix): kwargs["force2d"] = original_force2d return result - # Manually setting the shape to be unknown in the middle two dimensions so - # that the `tf.cond` below won't throw an error based on the convolution - # kernels being too large for the data. - inputs._shape = tf.TensorShape([static_shape[0], None, None, static_shape[3]]) # pylint: disable=protected-access - if kernel_size[1] == 1 or force2d: - # Avoiding the cond below can speed up graph and gradient construction. - return conv2d_kernel(kernel_size, "single") - return tf.cond( - tf.equal(tf.shape(inputs)[2], - 1), lambda: conv2d_kernel((kernel_size[0], 1), "small"), - lambda: conv2d_kernel(kernel_size, "std")) + return conv2d_kernel(kernel_size, "single") def conv(inputs, filters, kernel_size, **kwargs): @@ -566,20 +555,8 @@ def pool(inputs, window_size, pooling_type, padding, strides=(1, 1)): inputs = tf.pad(inputs, padding_) inputs.set_shape([static_shape[0], None, None, static_shape[3]]) padding = "VALID" - window_size_small = (window_size[0], 1) - strides_small = (strides[0], 1) - # Manually setting the shape to be unknown in the middle two dimensions so - # that the `tf.cond` below won't throw an error based on the convolution - # kernels being too large for the data. - inputs._shape = tf.TensorShape( # pylint: disable=protected-access - [static_shape[0], None, None, static_shape[3]]) - return tf.cond( - tf.equal(tf.shape(inputs)[2], 1), - lambda: tf.nn.pool( # pylint: disable=g-long-lambda - inputs, window_size_small, pooling_type, padding, - strides=strides_small), - lambda: tf.nn.pool( # pylint: disable=g-long-lambda - inputs, window_size, pooling_type, padding, strides=strides)) + + return tf.nn.pool(inputs, window_size, pooling_type, padding, strides=strides) def conv_block_downsample(x, @@ -1308,7 +1285,7 @@ def pad_with_zeros(logits, labels): logits, labels = pad_to_same_length(logits, labels) if len(labels.shape.as_list()) == 3: # 2-d labels. 
     logits, labels = pad_to_same_length(logits, labels, axis=2)
-  return labels
+  return logits, labels
 
 
 def weights_nonzero(labels):
@@ -1374,8 +1351,9 @@ def padded_cross_entropy(logits,
   confidence = 1.0 - label_smoothing
   vocab_size = tf.shape(logits)[-1]
   with tf.name_scope("padded_cross_entropy", [logits, labels]):
-    pad_labels = pad_with_zeros(logits, labels)
-    xent = smoothing_cross_entropy(logits, pad_labels, vocab_size, confidence)
+    pad_logits, pad_labels = pad_with_zeros(logits, labels)
+    xent = smoothing_cross_entropy(pad_logits, pad_labels,
+                                   vocab_size, confidence)
     weights = weights_fn(pad_labels)
     if not reduce_sum:
       return xent * weights, weights
diff --git a/tensor2tensor/models/common_layers_test.py b/tensor2tensor/models/common_layers_test.py
index 091f272d6..8d2b4dec1 100644
--- a/tensor2tensor/models/common_layers_test.py
+++ b/tensor2tensor/models/common_layers_test.py
@@ -277,13 +277,13 @@ def testShiftLeft(self):
     self.assertAllEqual(actual, expected)
 
   def testConvStride2MultiStep(self):
-    x1 = np.random.rand(5, 32, 1, 11)
+    x1 = np.random.rand(5, 32, 16, 11)
     with self.test_session() as session:
       a = common_layers.conv_stride2_multistep(
           tf.constant(x1, dtype=tf.float32), 4, 16)
       session.run(tf.global_variables_initializer())
       actual = session.run(a[0])
-    self.assertEqual(actual.shape, (5, 2, 0, 16))
+    self.assertEqual(actual.shape, (5, 2, 1, 16))
 
   def testDeconvStride2MultiStep(self):
     x1 = np.random.rand(5, 2, 1, 11)
diff --git a/tensor2tensor/models/lstm.py b/tensor2tensor/models/lstm.py
index 992c42db4..eb8b10cd2 100644
--- a/tensor2tensor/models/lstm.py
+++ b/tensor2tensor/models/lstm.py
@@ -12,19 +12,159 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""Baseline models."""
+"""RNN LSTM models."""
 
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import collections
+
 # Dependency imports
 
+from tensor2tensor.models import common_hparams
 from tensor2tensor.models import common_layers
 from tensor2tensor.utils import registry
 from tensor2tensor.utils import t2t_model
 
 import tensorflow as tf
+from tensorflow.python.util import nest
+
+
+# Named tuple tracking the cell state and the attention values.
+AttentionTuple = collections.namedtuple("AttentionTuple",
+                                        ("state", "attention"))
+
+
+class ExternalAttentionCellWrapper(tf.contrib.rnn.RNNCell):
+  """Wrapper for external attention states for an encoder-decoder setup."""
+
+  def __init__(self, cell, attn_states, attn_vec_size=None,
+               input_size=None, state_is_tuple=True, reuse=None):
+    """Create a cell with attention.
+
+    Args:
+      cell: an RNNCell, to which attention is added.
+      attn_states: External attention states, typically the encoder output,
+        in the form [batch_size, time steps, hidden size].
+      attn_vec_size: integer, the number of convolutional features calculated
+        on attention state and a size of the hidden layer built from
+        base cell state. Equal to attn_size by default.
+      input_size: integer, the size of a hidden linear layer,
+        built from inputs and attention. Derived from the input tensor
+        by default.
+      state_is_tuple: If True, accepted and returned states are n-tuples,
+        where `n = len(cells)`. Must be set to True; otherwise a ValueError
+        is raised.
+      reuse: (optional) Python boolean describing whether to reuse variables
+        in an existing scope. If not `True`, and the existing scope already
+        has the given variables, an error is raised.
+    Raises:
+      TypeError: if cell is not an RNNCell.
+ ValueError: if the flag `state_is_tuple` is `False` or if shape of + `attn_states` is not 3 or if innermost dimension (hidden size) is None. + """ + super(ExternalAttentionCellWrapper, self).__init__(_reuse=reuse) + if not state_is_tuple: + raise ValueError("Only tuple state is supported") + + self._cell = cell + self._input_size = input_size + + # Validate attn_states shape. + attn_shape = attn_states.get_shape() + if not attn_shape or len(attn_shape) != 3: + raise ValueError("attn_shape must be rank 3") + + self._attn_states = attn_states + self._attn_size = attn_shape[2].value + if self._attn_size is None: + raise ValueError("Hidden size of attn_states cannot be None") + + self._attn_vec_size = attn_vec_size + if self._attn_vec_size is None: + self._attn_vec_size = self._attn_size + + self._reuse = reuse + + @property + def state_size(self): + return AttentionTuple(self._cell.state_size, self._attn_size) + + @property + def output_size(self): + return self._attn_size + + def combine_state(self, previous_state): + """Combines previous state (from encoder) with internal attention values. + + You must use this function to derive the initial state passed into + this cell as it expects a named tuple (AttentionTuple). + + Args: + previous_state: State from another block that will be fed into this cell; + Must have same structure as the state of the cell wrapped by this. + Returns: + Combined state (AttentionTuple). + """ + batch_size = self._attn_states.get_shape()[0].value + if batch_size is None: + batch_size = tf.shape(self._attn_states)[0] + zeroed_state = self.zero_state(batch_size, self._attn_states.dtype) + return AttentionTuple(previous_state, zeroed_state.attention) + + def call(self, inputs, state): + """Long short-term memory cell with attention (LSTMA).""" + + if not isinstance(state, AttentionTuple): + raise TypeError("State must be of type AttentionTuple") + + state, attns = state + attn_states = self._attn_states + attn_length = attn_states.get_shape()[1].value + if attn_length is None: + attn_length = tf.shape(attn_states)[1] + + input_size = self._input_size + if input_size is None: + input_size = inputs.get_shape().as_list()[1] + if attns is not None: + inputs = tf.layers.dense(tf.concat([inputs, attns], axis=1), input_size) + lstm_output, new_state = self._cell(inputs, state) + + new_state_cat = tf.concat(nest.flatten(new_state), 1) + new_attns = self._attention(new_state_cat, attn_states, attn_length) + + with tf.variable_scope("attn_output_projection"): + output = tf.layers.dense(tf.concat([lstm_output, new_attns], axis=1), + self._attn_size) + + new_state = AttentionTuple(new_state, new_attns) + + return output, new_state + + def _attention(self, query, attn_states, attn_length): + conv2d = tf.nn.conv2d + reduce_sum = tf.reduce_sum + softmax = tf.nn.softmax + tanh = tf.tanh + + with tf.variable_scope("attention"): + k = tf.get_variable( + "attn_w", [1, 1, self._attn_size, self._attn_vec_size]) + v = tf.get_variable("attn_v", [self._attn_vec_size, 1]) + hidden = tf.reshape(attn_states, + [-1, attn_length, 1, self._attn_size]) + hidden_features = conv2d(hidden, k, [1, 1, 1, 1], "SAME") + y = tf.layers.dense(query, self._attn_vec_size) + y = tf.reshape(y, [-1, 1, 1, self._attn_vec_size]) + s = reduce_sum(v * tanh(hidden_features + y), [2, 3]) + a = softmax(s) + d = reduce_sum( + tf.reshape(a, [-1, attn_length, 1, 1]) * hidden, [1, 2]) + new_attns = tf.reshape(d, [-1, self._attn_size]) + + return new_attns def lstm(inputs, hparams, train, name, initial_state=None): @@ 
-45,6 +185,29 @@ def dropout_lstm_cell(): time_major=False) +def lstm_attention_decoder(inputs, hparams, train, name, + initial_state, attn_states): + """Run LSTM cell with attention on inputs of shape [batch x time x size].""" + + def dropout_lstm_cell(): + return tf.contrib.rnn.DropoutWrapper( + tf.nn.rnn_cell.BasicLSTMCell(hparams.hidden_size), + input_keep_prob=1.0 - hparams.dropout * tf.to_float(train)) + + layers = [dropout_lstm_cell() for _ in range(hparams.num_hidden_layers)] + cell = ExternalAttentionCellWrapper(tf.nn.rnn_cell.MultiRNNCell(layers), + attn_states, + attn_vec_size=hparams.attn_vec_size) + initial_state = cell.combine_state(initial_state) + with tf.variable_scope(name): + return tf.nn.dynamic_rnn( + cell, + inputs, + initial_state=initial_state, + dtype=tf.float32, + time_major=False) + + def lstm_seq2seq_internal(inputs, targets, hparams, train): """The basic LSTM seq2seq model, main step used for training.""" with tf.variable_scope("lstm_seq2seq"): @@ -64,6 +227,25 @@ def lstm_seq2seq_internal(inputs, targets, hparams, train): return tf.expand_dims(decoder_outputs, axis=2) +def lstm_seq2seq_internal_attention(inputs, targets, hparams, train): + """LSTM seq2seq model with attention, main step used for training.""" + with tf.variable_scope("lstm_seq2seq_attention"): + # Flatten inputs. + inputs = common_layers.flatten4d3d(inputs) + # LSTM encoder. + encoder_outputs, final_encoder_state = lstm( + tf.reverse(inputs, axis=[1]), hparams, train, "encoder") + # LSTM decoder with attention + shifted_targets = common_layers.shift_left(targets) + decoder_outputs, _ = lstm_attention_decoder( + common_layers.flatten4d3d(shifted_targets), + hparams, + train, + "decoder", + final_encoder_state, encoder_outputs) + return tf.expand_dims(decoder_outputs, axis=2) + + @registry.register_model("baseline_lstm_seq2seq") class LSTMSeq2Seq(t2t_model.T2TModel): @@ -71,3 +253,25 @@ def model_fn_body(self, features): train = self._hparams.mode == tf.contrib.learn.ModeKeys.TRAIN return lstm_seq2seq_internal(features["inputs"], features["targets"], self._hparams, train) + + +@registry.register_model("baseline_lstm_seq2seq_attention") +class LSTMSeq2SeqAttention(t2t_model.T2TModel): + + def model_fn_body(self, features): + train = self._hparams.mode == tf.contrib.learn.ModeKeys.TRAIN + return lstm_seq2seq_internal_attention( + features["inputs"], features["targets"], self._hparams, train) + + +@registry.register_hparams +def lstm_attention(): + """hparams for LSTM with attention.""" + hparams = common_hparams.basic_params1() + hparams.batch_size = 128 + hparams.hidden_size = 128 + hparams.num_hidden_layers = 2 + + # Attention + hparams.add_hparam("attn_vec_size", hparams.hidden_size) + return hparams diff --git a/tensor2tensor/models/lstm_test.py b/tensor2tensor/models/lstm_test.py index e5bdb184b..4c4c42909 100644 --- a/tensor2tensor/models/lstm_test.py +++ b/tensor2tensor/models/lstm_test.py @@ -51,6 +51,30 @@ def testLSTMSeq2Seq(self): res = session.run(logits) self.assertEqual(res.shape, (3, 6, 1, 1, vocab_size)) + def testLSTMSeq2SeqAttention(self): + vocab_size = 9 + x = np.random.random_integers(1, high=vocab_size - 1, size=(3, 5, 1, 1)) + y = np.random.random_integers(1, high=vocab_size - 1, size=(3, 6, 1, 1)) + hparams = lstm.lstm_attention() + + p_hparams = problem_hparams.test_problem_hparams(hparams, vocab_size, + vocab_size) + x = tf.constant(x, dtype=tf.int32) + x._shape = tf.TensorShape([None, None, 1, 1]) + + with self.test_session() as session: + features = { + "inputs": x, + 
"targets": tf.constant(y, dtype=tf.int32), + } + model = lstm.LSTMSeq2SeqAttention( + hparams, tf.contrib.learn.ModeKeys.TRAIN, p_hparams) + sharded_logits, _, _ = model.model_fn(features) + logits = tf.concat(sharded_logits, 0) + session.run(tf.global_variables_initializer()) + res = session.run(logits) + self.assertEqual(res.shape, (3, 6, 1, 1, vocab_size)) + if __name__ == "__main__": tf.test.main() diff --git a/tensor2tensor/models/multimodel.py b/tensor2tensor/models/multimodel.py index 66a8491f2..60f098e5e 100644 --- a/tensor2tensor/models/multimodel.py +++ b/tensor2tensor/models/multimodel.py @@ -19,52 +19,66 @@ # Dependency imports +from tensor2tensor.models import common_attention +from tensor2tensor.models import common_hparams from tensor2tensor.models import common_layers from tensor2tensor.models import modalities from tensor2tensor.models import slicenet -from tensor2tensor.utils import expert_utils as eu from tensor2tensor.utils import registry from tensor2tensor.utils import t2t_model import tensorflow as tf -def experts(xs, moe_n1, moe_n2, hidden_size, filter_size, dp, ps, train): - """Mixture-of-Experts layer.""" - # Set up the hyperparameters for the gating networks. - primary_gating_hp = eu.NoisyTopKGatingParams() - primary_gating_hp.num_experts = moe_n1 - if moe_n2: - # Hierarchical MoE containing moe_n1 groups of moe_n2 experts. - assert moe_n2 > 1 - secondary_gating_hp = eu.NoisyTopKGatingParams() - secondary_gating_hp.num_experts = moe_n2 - else: - # Flat mixture of moe_n1 experts. - secondary_gating_hp = None - # Set up the hyperparameters for the expert networks. - # Each expert contains a hidden RELU layer of size filter_size - expert_hp = eu.FeedForwardExpertParams() - expert_hp.hidden_layer_sizes = [filter_size] - # Create the mixture of experts. - moe = eu.DistributedMixtureOfExperts(primary_gating_hp, secondary_gating_hp, - expert_hp, hidden_size, hidden_size, ps, - "moe") - # MoE expects input tensors to be 2d. Flatten out spatial dimensions. - xs_2d = dp(tf.reshape, xs, [[-1, hidden_size]] * dp.n) - # Call the MoE - moe_out_2d, importance, load, _, _ = moe.Eval( - dp.devices, xs_2d, train, summaries=False, identifiers=None) - # Reshape the output to the original shape. - moe_out = dp(tf.reshape, moe_out_2d, dp(tf.shape, xs)) - # These losses encourage equal load on the different experts. - loss = eu.CVSquared(importance) + eu.CVSquared(load) - - # Apply residual and normalize. 
-  def add_and_normalize(x, y):
-    return common_layers.layer_norm(x + y, hidden_size, name="moe_norm")
-
-  return dp(add_and_normalize, xs, moe_out), loss
+def conv_res_step(x, hparams, padding, mask):
+  """One step of convolutions and mid-residual."""
+  k = (hparams.kernel_height, hparams.kernel_width)
+  k2 = (hparams.large_kernel_size, 1)
+  dilations_and_kernels1 = [((1, 1), k), ((1, 1), k)]
+  dilations_and_kernels2 = [((1, 1), k2), ((4, 4), k2)]
+  with tf.variable_scope("conv_res_step"):
+    y = common_layers.subseparable_conv_block(
+        x, hparams.filter_size, dilations_and_kernels1,
+        padding=padding, mask=mask, separabilities=0, name="residual1")
+    y = tf.nn.dropout(y, 1.0 - hparams.dropout)
+    return common_layers.subseparable_conv_block(
+        y, hparams.hidden_size, dilations_and_kernels2,
+        padding=padding, mask=mask, separabilities=0, name="residual2")
+
+
+def residual_fn2(x, y, hparams):
+  y = tf.nn.dropout(y, 1.0 - hparams.dropout)
+  return common_layers.layer_norm(x + y)
+
+
+def residual_fn3(x, y, z, hparams):
+  y = tf.nn.dropout(y, 1.0 - hparams.dropout)
+  z = tf.nn.dropout(z, 1.0 - hparams.dropout)
+  return common_layers.layer_norm(x + y + z)
+
+
+def conv_experts(xs, hparams, dp, ps, padding, mask, layer_id):
+  """Convolutions + Mixture-of-Experts layer."""
+  del layer_id  # Unused.
+  train = hparams.mode == tf.contrib.learn.ModeKeys.TRAIN
+  conv_out = dp(conv_res_step, xs, hparams, padding, mask)
+  moe_out, loss = common_layers.moe_layer(
+      dp, ps, xs, train, hparams.hidden_size, hparams.filter_size,
+      hparams.moe_n1, hparams.moe_n2, 1.0)
+  return dp(residual_fn3, xs, moe_out, conv_out, hparams), loss
+
+
+def prepare_decoder(targets, target_space_emb):
+  """Prepare decoder."""
+  decoder_self_attention_bias = (
+      common_attention.attention_bias_lower_triangle(tf.shape(targets)[1]))
+  target_space_emb = tf.reshape(target_space_emb, [1, 1, -1])
+  target_space_emb = tf.tile(target_space_emb, [tf.shape(targets)[0], 1, 1])
+  decoder_input = common_layers.shift_left_3d(
+      targets, pad_value=target_space_emb)
+  decoder_input = common_attention.add_timing_signal_1d(decoder_input)
+  return (decoder_input, decoder_self_attention_bias)
 
 
 @registry.register_model
@@ -74,87 +88,119 @@ def model_fn_body_sharded(self, sharded_features):
     train = self._hparams.mode == tf.contrib.learn.ModeKeys.TRAIN
     dp = self._data_parallelism
     hparams = self._hparams
-    targets = sharded_features["targets"]
 
     def flatten(inputs):
       return tf.expand_dims(common_layers.flatten4d3d(inputs), axis=2)
 
     inputs = dp(flatten, sharded_features["inputs"])
-
-    # Encode inputs.
-    def encode_half(inputs, inputs_mask, hparams):
-      # Add timing and encode.
-      inputs = common_layers.add_timing_signal(inputs)
-      return slicenet.multi_conv_res(inputs, "SAME", "encoder1",
-                                     hparams.num_hidden_layers // 2,
-                                     hparams, mask=inputs_mask)
-
-    target_space_emb = dp(slicenet.embed_target_space,
-                          sharded_features["target_space_id"],
-                          hparams.hidden_size)
     inputs_pad = dp(slicenet.embedding_to_padding, inputs)
     inputs_mask = dp(lambda x: 1.0 - x, inputs_pad)
-    inputs_encoded = dp(encode_half, inputs, inputs_mask, hparams)
-    with tf.variable_scope("experts_enc"):
-      inputs_encoded, expert_loss = experts(
-          inputs_encoded, hparams.moe_n1, hparams.moe_n2, hparams.hidden_size,
-          hparams.hidden_size, dp, self._ps_devices, train)
-      expert_loss *= hparams.moe_loss_coef
-    inputs_encoded = dp(
-        slicenet.multi_conv_res, inputs_encoded, "SAME",
-        "encoder2", hparams.num_hidden_layers, hparams,
-        mask=inputs_mask)
+    inputs_encoded = dp(common_layers.add_timing_signal, inputs)
+    expert_loss = 0.0
+    for i in xrange(hparams.num_hidden_layers):
+      with tf.variable_scope("enc_layer_%d" % i):
+        inputs_encoded, moe_loss = conv_experts(
+            inputs_encoded, hparams, dp, self._ps_devices, "SAME",
+            inputs_mask, i)
+        expert_loss += tf.reduce_mean(moe_loss) * hparams.moe_loss_coef
 
     # If we're just predicting a class, there is no use for a decoder; return.
     if isinstance(hparams.problems[self._problem_idx].target_modality,
                   modalities.ClassLabelModality):
       return inputs_encoded, tf.reduce_mean(expert_loss)
 
-    # Do the middle part.
-    decoder_start, similarity_loss = dp(
-        slicenet.slicenet_middle, inputs_encoded, targets,
-        target_space_emb, inputs_mask, hparams)
-
-    # Decode.
-    decoder_half = dp(
-        slicenet.multi_conv_res,
-        decoder_start,
-        "LEFT",
-        "decoder1",
-        hparams.num_hidden_layers // 2,
-        hparams,
-        train,
-        mask=inputs_mask,
-        source=inputs_encoded)
-    with tf.variable_scope("experts_dec"):
-      decoder_half, expert_dec_loss = experts(
-          decoder_half, hparams.moe_n1, hparams.moe_n2, hparams.hidden_size,
-          hparams.hidden_size, dp, self._ps_devices, train)
-      expert_loss += expert_dec_loss * hparams.moe_loss_coef
-    decoder_final = dp(
-        slicenet.multi_conv_res,
-        decoder_half,
-        "LEFT",
-        "decoder2",
-        hparams.num_hidden_layers // 2,
-        hparams,
-        mask=inputs_mask,
-        source=inputs_encoded)
-
-    total_loss = tf.reduce_mean(expert_loss) + tf.reduce_mean(similarity_loss)
-    return decoder_final, total_loss
-
-
-@registry.register_hparams("multimodel_1p8")
-def multimodel_params1_p8():
-  """Version for eight problem runs."""
-  hparams = slicenet.slicenet_params1()
-  hparams.problem_choice = "distributed"
-  hparams.attention_type = "simple"  # TODO(lukaszkaiser): add transformer.
-  hparams.hidden_size = 1536
-  hparams.moe_n1 = 120
-  hparams.shared_embedding_and_softmax_weights = int(False)
+    # Decoder.
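+    # Shapes through the attention decoder below (illustrative):
+    #   inputs_encoded3d: [batch, input_length, hidden_size]
+    #   decoder_input:    [batch, target_length, hidden_size]; the targets
+    #     shifted left, with target_space_emb as the first-position padding.
+    #   decoder_self_attention_bias: lower-triangular bias, so position i
+    #     can only attend to positions <= i.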
+ inputs3d = dp(tf.squeeze, inputs, 2) + inputs_encoded3d = dp(tf.squeeze, inputs_encoded, 2) + encoder_padding = dp(common_attention.embedding_to_padding, inputs3d) + encoder_attention_bias = dp( + common_attention.attention_bias_ignore_padding, encoder_padding) + targets = dp(common_layers.flatten4d3d, sharded_features["targets"]) + target_space_emb = dp(slicenet.embed_target_space, + sharded_features["target_space_id"], + hparams.hidden_size) + + (decoder_input, decoder_self_attention_bias) = dp( + prepare_decoder, targets, target_space_emb) + + x = dp(tf.nn.dropout, decoder_input, 1.0 - hparams.dropout) + for layer in xrange(hparams.num_hidden_layers): + with tf.variable_scope("dec_layer_%d" % layer): + with tf.variable_scope("attention"): + y = dp(common_attention.multihead_attention, + x, + None, + decoder_self_attention_bias, + hparams.hidden_size, + hparams.hidden_size, + hparams.hidden_size, + hparams.num_heads, + hparams.attention_dropout, + summaries=False, + name="decoder_self_attention") + z = dp(common_attention.multihead_attention, + y, + inputs_encoded3d, + encoder_attention_bias, + hparams.hidden_size, + hparams.hidden_size, + hparams.hidden_size, + hparams.num_heads, + hparams.attention_dropout, + summaries=False, + name="encdec_attention") + x = dp(residual_fn3, x, y, z, hparams) + with tf.variable_scope("ffn"): + if str(layer) in hparams.moe_layers.split(","): + y, moe_loss = common_layers.moe_layer( + dp, self._ps_devices, x, train, + hparams.hidden_size, hparams.filter_size, + hparams.moe_n1, hparams.moe_n2, hparams.moe_loss_coef) + expert_loss += tf.reduce_mean(moe_loss) + else: + y = dp(common_layers.conv_hidden_relu, + x, + hparams.filter_size, + hparams.hidden_size, + dropout=hparams.dropout) + x = dp(residual_fn2, x, y, hparams) + + x = dp(tf.expand_dims, x, 2) + return x, tf.reduce_mean(expert_loss) + + +@registry.register_hparams +def multimodel_base(): + """Base parameters for MultiModel.""" + hparams = common_hparams.basic_params1() + hparams.hidden_size = 512 + hparams.batch_size = 2048 + hparams.num_hidden_layers = 4 + hparams.learning_rate_decay_scheme = "noam" + hparams.learning_rate = 0.1 + hparams.learning_rate_warmup_steps = 4000 + hparams.initializer_gain = 1.0 hparams.dropout = 0.1 - hparams.attention_dropout = 0.1 - hparams.learning_rate_decay_scheme = "exp500k" + hparams.add_hparam("filter_size", 2048) # Add new ones like this. 
+ hparams.add_hparam("large_kernel_size", 15) + hparams.add_hparam("attention_dropout", 0.1) + hparams.add_hparam("num_heads", 8) + hparams.add_hparam("moe_n1", 30) + hparams.add_hparam("moe_n2", 0) + hparams.add_hparam("moe_layers", "2") + hparams.add_hparam("moe_loss_coef", 1e-2) + hparams.add_hparam("imagenet_use_2d", int(True)) + return hparams + + +@registry.register_hparams +def multimodel_tiny(): + """Tiny parameters for MultiModel.""" + hparams = multimodel_base() + hparams.hidden_size = 128 + hparams.filter_size = 512 + hparams.batch_size = 512 + hparams.num_hidden_layers = 2 + hparams.moe_n1 = 10 + hparams.moe_layers = "0" return hparams diff --git a/tensor2tensor/models/multimodel_test.py b/tensor2tensor/models/multimodel_test.py index 72fe4a326..dbbd3fa8e 100644 --- a/tensor2tensor/models/multimodel_test.py +++ b/tensor2tensor/models/multimodel_test.py @@ -24,7 +24,6 @@ from tensor2tensor.data_generators import problem_hparams from tensor2tensor.models import multimodel -from tensor2tensor.models import slicenet import tensorflow as tf @@ -34,7 +33,7 @@ class MultiModelTest(tf.test.TestCase): def testMultiModel(self): x = np.random.random_integers(0, high=255, size=(3, 5, 4, 3)) y = np.random.random_integers(0, high=9, size=(3, 5, 1, 1)) - hparams = slicenet.slicenet_params1_tiny() + hparams = multimodel.multimodel_tiny() p_hparams = problem_hparams.image_cifar10(hparams) hparams.problems = [p_hparams] with self.test_session() as session: diff --git a/tensor2tensor/utils/metrics.py b/tensor2tensor/utils/metrics.py index ecc02fd5e..97da4cd35 100644 --- a/tensor2tensor/utils/metrics.py +++ b/tensor2tensor/utils/metrics.py @@ -37,10 +37,11 @@ def padded_accuracy_topk(predictions, weights_fn=common_layers.weights_nonzero): """Percentage of times that top-k predictions matches labels on non-0s.""" with tf.variable_scope("padded_accuracy_topk", values=[predictions, labels]): - padded_labels = common_layers.pad_with_zeros(predictions, labels) + padded_predictions, padded_labels = common_layers.pad_with_zeros( + predictions, labels) weights = weights_fn(padded_labels) - effective_k = tf.minimum(k, tf.shape(predictions)[-1]) - _, outputs = tf.nn.top_k(predictions, k=effective_k) + effective_k = tf.minimum(k, tf.shape(padded_predictions)[-1]) + _, outputs = tf.nn.top_k(padded_predictions, k=effective_k) outputs = tf.to_int32(outputs) padded_labels = tf.expand_dims(padded_labels, axis=-1) padded_labels += tf.zeros_like(outputs) # Pad to same shape. 
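
The same unpacking pattern repeats in the two hunks below: pad_with_zeros now
returns a (logits, labels) pair padded to a common length, and each metric must
run argmax or top_k on the padded predictions rather than the raw ones. A
minimal sketch of the new contract (toy shapes, not part of the patch):

    # predictions: [batch, pred_length, vocab]; labels: [batch, label_length]
    padded_predictions, padded_labels = common_layers.pad_with_zeros(
        predictions, labels)
    # Both tensors now share the same length dimension, so
    # tf.argmax(padded_predictions, axis=-1) lines up with padded_labels.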
@@ -61,9 +62,10 @@ def padded_sequence_accuracy(predictions, """Percentage of times that predictions matches labels everywhere (non-0).""" with tf.variable_scope( "padded_sequence_accuracy", values=[predictions, labels]): - padded_labels = common_layers.pad_with_zeros(predictions, labels) + padded_predictions, padded_labels = common_layers.pad_with_zeros( + predictions, labels) weights = weights_fn(padded_labels) - outputs = tf.to_int32(tf.argmax(predictions, axis=-1)) + outputs = tf.to_int32(tf.argmax(padded_predictions, axis=-1)) not_correct = tf.to_float(tf.not_equal(outputs, padded_labels)) * weights axis = list(range(1, len(outputs.get_shape()))) correct_seq = 1.0 - tf.minimum(1.0, tf.reduce_sum(not_correct, axis=axis)) @@ -84,9 +86,10 @@ def padded_accuracy(predictions, weights_fn=common_layers.weights_nonzero): """Percentage of times that predictions matches labels on non-0s.""" with tf.variable_scope("padded_accuracy", values=[predictions, labels]): - padded_labels = common_layers.pad_with_zeros(predictions, labels) + padded_predictions, padded_labels = common_layers.pad_with_zeros( + predictions, labels) weights = weights_fn(padded_labels) - outputs = tf.to_int32(tf.argmax(predictions, axis=-1)) + outputs = tf.to_int32(tf.argmax(padded_predictions, axis=-1)) return tf.to_float(tf.equal(outputs, padded_labels)), weights @@ -119,8 +122,9 @@ def fn(predictions, labels, weights, idx, weights_fn): for i, problem in enumerate(problems): name = "metrics-%s/%s" % (problem, metric_name) - weights_fn = (common_layers.weights_concatenated - if "concat" in problem else common_layers.weights_nonzero) + class_output = "image" in problem and "coco" not in problem + weights_fn = (common_layers.weights_all if class_output + else common_layers.weights_nonzero) eval_metrics[name] = functools.partial(fn, idx=i, weights_fn=weights_fn) def global_fn(predictions, labels, weights):
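
The weights_fn chosen above controls which positions count toward each metric.
A toy sketch of the difference, assuming common_layers.weights_all returns
all-ones weights while weights_nonzero masks out PAD (label id 0):

    labels = tf.constant([[3, 0, 5]])
    weights = common_layers.weights_nonzero(labels)  # ~ [[1., 0., 1.]]
    # The position with label 0 (PAD) is ignored. For class-label problems
    # ("image" in the problem name, except "coco"), 0 is a real class id,
    # so weights_all keeps every position instead.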