From f703629068ae3d9f984044e241f2686f481a948a Mon Sep 17 00:00:00 2001 From: T2T Team Date: Tue, 18 Jul 2017 17:47:09 -0700 Subject: [PATCH] Change summary generation to work better in multi-model case. PiperOrigin-RevId: 162429483 --- tensor2tensor/models/attention_lm.py | 3 --- tensor2tensor/models/attention_lm_moe.py | 1 - tensor2tensor/models/common_attention.py | 17 ++++-------- tensor2tensor/models/common_layers.py | 26 +++++++------------ tensor2tensor/models/long_answer.py | 1 - tensor2tensor/models/multimodel.py | 5 +--- tensor2tensor/models/slicenet.py | 8 +++--- tensor2tensor/models/transformer.py | 15 ++--------- .../models/transformer_alternative.py | 3 --- tensor2tensor/utils/trainer_utils.py | 7 +++++ 10 files changed, 27 insertions(+), 59 deletions(-) diff --git a/tensor2tensor/models/attention_lm.py b/tensor2tensor/models/attention_lm.py index 947dc9306..752de038e 100644 --- a/tensor2tensor/models/attention_lm.py +++ b/tensor2tensor/models/attention_lm.py @@ -101,8 +101,6 @@ def attention_lm_decoder(decoder_input, y: a Tensors """ x = decoder_input - # Summaries don't work in multi-problem setting yet. - summaries = "problems" not in hparams.values() or len(hparams.problems) == 1 with tf.variable_scope(name): for layer in xrange(hparams.num_hidden_layers): with tf.variable_scope("layer_%d" % layer): @@ -117,7 +115,6 @@ def attention_lm_decoder(decoder_input, hparams.hidden_size, hparams.num_heads, hparams.attention_dropout, - summaries=summaries, name="decoder_self_attention")) x = residual_fn(x, common_layers.conv_hidden_relu( diff --git a/tensor2tensor/models/attention_lm_moe.py b/tensor2tensor/models/attention_lm_moe.py index 952ff1a71..2754e8366 100644 --- a/tensor2tensor/models/attention_lm_moe.py +++ b/tensor2tensor/models/attention_lm_moe.py @@ -69,7 +69,6 @@ def residual_fn(x, y): hparams.hidden_size, hparams.num_heads, hparams.attention_dropout, - summaries=True, name="decoder_self_attention") x = dp(residual_fn, x, y) with tf.variable_scope("ffn"): diff --git a/tensor2tensor/models/common_attention.py b/tensor2tensor/models/common_attention.py index b0d0403cd..6aa8a2a07 100644 --- a/tensor2tensor/models/common_attention.py +++ b/tensor2tensor/models/common_attention.py @@ -312,7 +312,6 @@ def dot_product_attention(q, v, bias, dropout_rate=0.0, - summaries=False, image_shapes=None, name=None): """dot-product attention. @@ -323,7 +322,6 @@ def dot_product_attention(q, v: a Tensor with shape [batch, heads, length_kv, depth_v] bias: bias Tensor (see attention_bias()) dropout_rate: a floating point number - summaries: a boolean image_shapes: optional tuple of integer scalars. see comments for attention_image_summary() name: an optional string @@ -340,13 +338,13 @@ def dot_product_attention(q, weights = tf.nn.softmax(logits, name="attention_weights") # dropping out the attention links for each of the heads weights = tf.nn.dropout(weights, 1.0 - dropout_rate) - if summaries and not tf.get_variable_scope().reuse: + if not tf.get_variable_scope().reuse: attention_image_summary(weights, image_shapes) return tf.matmul(weights, v) def masked_local_attention_1d( - q, k, v, block_length=128, summaries=True, name=None): + q, k, v, block_length=128, name=None): """Attention to the source position and a neigborhood to the left of it. The sequence is divided into blocks of length block_size. 
@@ -362,7 +360,6 @@ def masked_local_attention_1d( k: a Tensor with shape [batch, heads, length, depth_k] v: a Tensor with shape [batch, heads, length, depth_v] block_length: an integer - summaries: a boolean name: an optional string Returns: @@ -394,7 +391,7 @@ def masked_local_attention_1d( first_v = tf.slice(v, [0, 0, 0, 0], [-1, -1, block_length, -1]) first_output = dot_product_attention( first_q, first_k, first_v, attention_bias_lower_triangle(block_length), - summaries=summaries, name="fist_block") + name="fist_block") # compute attention for all subsequent query blocks. q = tf.reshape(q, [batch, heads, num_blocks, block_length, depth_k]) @@ -442,7 +439,6 @@ def multihead_attention(query_antecedent, output_depth, num_heads, dropout_rate, - summaries=False, image_shapes=None, attention_type="dot_product", block_length=128, @@ -458,7 +454,6 @@ def multihead_attention(query_antecedent, output_depth: an integer num_heads: an integer dividing total_key_depth and total_value_depth dropout_rate: a floating point number - summaries: a boolean image_shapes: optional tuple of integer scalars. see comments for attention_image_summary() attention_type: a string, either "dot_product" or "local_mask_right" @@ -509,12 +504,10 @@ def multihead_attention(query_antecedent, q *= key_depth_per_head**-0.5 if attention_type == "dot_product": x = dot_product_attention( - q, k, v, bias, dropout_rate, summaries, image_shapes) + q, k, v, bias, dropout_rate, image_shapes) else: assert attention_type == "local_mask_right" - x = masked_local_attention_1d(q, k, v, - block_length=block_length, - summaries=summaries) + x = masked_local_attention_1d(q, k, v, block_length=block_length) x = combine_heads(x) x = common_layers.conv1d(x, output_depth, 1, name="output_transform") return x diff --git a/tensor2tensor/models/common_layers.py b/tensor2tensor/models/common_layers.py index 1e7050570..638535aa2 100644 --- a/tensor2tensor/models/common_layers.py +++ b/tensor2tensor/models/common_layers.py @@ -777,7 +777,7 @@ def moe_layer(data_parallelism, xs_2d = dp(tf.reshape, xs, [[-1, model_hidden_size]] * dp.n) # Call the MoE moe_out_2d, importance, load, _, _ = moe.Eval( - dp.devices, xs_2d, train, identifiers=None, summaries=True) + dp.devices, xs_2d, train, identifiers=None) # Reshape the output to the original shape. moe_out = dp(tf.reshape, moe_out_2d, dp(tf.shape, xs)) # These losses encourage equal load on the different experts. @@ -785,7 +785,7 @@ def moe_layer(data_parallelism, return moe_out, loss -def simple_attention(target, source, bias=None, summaries=True): +def simple_attention(target, source, bias=None): """A simple attention function. Args: @@ -795,7 +795,6 @@ def simple_attention(target, source, bias=None, summaries=True): `[batch, source_timesteps_1, source_timesteps_2, depth]` bias: an optional `Tensor` with shape `[batch, timesteps, 1, 1]` used to mask the attention to not attend to padding of input. - summaries: Boolean, whether to output summaries. 
Returns: a `Tensor` with same shape as `target` @@ -814,7 +813,7 @@ def simple_attention(target, source, bias=None, summaries=True): if bias is not None: attention += tf.expand_dims(tf.squeeze(bias, axis=[2, 3]), axis=1) attention = tf.nn.softmax(attention) - if summaries and not tf.get_variable_scope().reuse: + if not tf.get_variable_scope().reuse: tf.summary.image("attention", tf.expand_dims(attention, 3), max_outputs=5) attended = tf.matmul(attention, source) return tf.reshape(attended, target_shape) @@ -861,8 +860,7 @@ def multiscale_conv_sum(inputs, output_size, dilation_rates_and_kernel_sizes, def multiscale_conv_and_attention(x, padding, hparams, - source=None, - summaries=True): + source=None): """A common part of t2t layers. First, do a linear multiscale convolution @@ -875,7 +873,6 @@ def multiscale_conv_and_attention(x, padding: a padding type hparams: hyperparameters for model source: optional source tensor for attention. (encoder output) - summaries: Boolean, whether to output summaries. Returns: a Tensor. @@ -893,7 +890,7 @@ def multiscale_conv_and_attention(x, x = conv(x, hparams.hidden_size, (1, 1)) x = noam_norm(x + conv_sum) if source is not None: - x = noam_norm(x + simple_attention(x, source, summaries=summaries)) + x = noam_norm(x + simple_attention(x, source)) return x @@ -930,8 +927,7 @@ def conv_with_pools(inputs, output_size, kernel_size, pool_sizes, pooling_type, def conv_with_pools_and_attention(x, padding, hparams, - source=None, - summaries=True): + source=None): """A common part of t2t layers. First, do conv_with_pools @@ -944,7 +940,6 @@ def conv_with_pools_and_attention(x, padding: a padding type hparams: hyperparameters for model source: optional source tensor for attention. (encoder output) - summaries: Boolean, whether to output summaries. Returns: a Tensor. @@ -959,7 +954,7 @@ def conv_with_pools_and_attention(x, conv_sum += x x = noam_norm(conv_sum) if source is not None: - x = noam_norm(x + simple_attention(x, source, summaries=summaries)) + x = noam_norm(x + simple_attention(x, source)) return x @@ -1057,7 +1052,6 @@ def attention_1d_v0(source, transform_source=True, transform_target=True, transform_output=True, - summaries=True, name=None): """multi-headed attention. @@ -1075,7 +1069,6 @@ def attention_1d_v0(source, transform_source: a boolean transform_target: a boolean transform_output: a boolean - summaries: a boolean name: an optional string Returns: @@ -1116,7 +1109,7 @@ def _maybe_transform(t, size, should_transform, name): mask = (1.0 - mask) * -1e9 attention += mask attention = tf.nn.softmax(attention) - if summaries and not tf.get_variable_scope().reuse: + if not tf.get_variable_scope().reuse: # Compute a color image summary. 
image = tf.reshape(attention, [batch, num_heads, target_length, source_length]) @@ -1162,7 +1155,6 @@ def conv_hidden_relu(inputs, output_size, kernel_size=(1, 1), second_kernel_size=(1, 1), - summaries=True, dropout=0.0, **kwargs): """Hidden layer with RELU activation followed by linear projection.""" @@ -1183,7 +1175,7 @@ def conv_hidden_relu(inputs, **kwargs) if dropout != 0.0: h = tf.nn.dropout(h, 1.0 - dropout) - if summaries and not tf.get_variable_scope().reuse: + if not tf.get_variable_scope().reuse: tf.summary.histogram("hidden_density_logit", relu_density_logit( h, list(range(inputs.shape.ndims - 1)))) diff --git a/tensor2tensor/models/long_answer.py b/tensor2tensor/models/long_answer.py index 15067e120..7bb6a4a55 100644 --- a/tensor2tensor/models/long_answer.py +++ b/tensor2tensor/models/long_answer.py @@ -75,7 +75,6 @@ def residual_fn(x, y): hparams.hidden_size, hparams.num_heads, hparams.attention_dropout, - summaries=True, attention_type="local_mask_right", block_length=hparams.block_length, name="decoder_self_attention") diff --git a/tensor2tensor/models/multimodel.py b/tensor2tensor/models/multimodel.py index ee079fa6d..bf06dfd65 100644 --- a/tensor2tensor/models/multimodel.py +++ b/tensor2tensor/models/multimodel.py @@ -138,7 +138,6 @@ def flatten(inputs): hparams.hidden_size, hparams.num_heads, hparams.attention_dropout, - summaries=False, name="decoder_self_attention") z = dp(common_attention.multihead_attention, y, @@ -149,7 +148,6 @@ def flatten(inputs): hparams.hidden_size, hparams.num_heads, hparams.attention_dropout, - summaries=False, name="encdec_attention") x = dp(residual_fn3, x, y, z, hparams) with tf.variable_scope("ffn"): @@ -164,8 +162,7 @@ def flatten(inputs): x, hparams.filter_size, hparams.hidden_size, - dropout=hparams.dropout, - summaries=False) + dropout=hparams.dropout) x = dp(residual_fn2, x, y, hparams) x = dp(tf.expand_dims, x, 2) diff --git a/tensor2tensor/models/slicenet.py b/tensor2tensor/models/slicenet.py index 43913eab1..2ad4c89d1 100644 --- a/tensor2tensor/models/slicenet.py +++ b/tensor2tensor/models/slicenet.py @@ -64,8 +64,7 @@ def attention(targets_shifted, inputs_encoded, norm_fn, hparams, bias=None): hparams.hidden_size, hparams.num_heads, hparams.attention_dropout, - name="self_attention", - summaries=False) + name="self_attention") qv = common_attention.multihead_attention( qv, inputs_encoded, @@ -75,12 +74,11 @@ def attention(targets_shifted, inputs_encoded, norm_fn, hparams, bias=None): hparams.hidden_size, hparams.num_heads, hparams.attention_dropout, - name="encdec_attention", - summaries=False) + name="encdec_attention") return tf.expand_dims(qv, 2) elif hparams.attention_type == "simple": targets_with_attention = common_layers.simple_attention( - targets_timed, inputs_encoded, bias=bias, summaries=False) + targets_timed, inputs_encoded, bias=bias) return norm_fn(targets_shifted + targets_with_attention, name="attn_norm") diff --git a/tensor2tensor/models/transformer.py b/tensor2tensor/models/transformer.py index 0bb01c0f8..b24f7fa50 100644 --- a/tensor2tensor/models/transformer.py +++ b/tensor2tensor/models/transformer.py @@ -143,8 +143,6 @@ def transformer_encoder(encoder_input, y: a Tensors """ x = encoder_input - # Summaries don't work in multi-problem setting yet. 
- summaries = len(hparams.problems) < 2 with tf.variable_scope(name): for layer in xrange(hparams.num_hidden_layers): with tf.variable_scope("layer_%d" % layer): @@ -159,7 +157,6 @@ def transformer_encoder(encoder_input, hparams.hidden_size, hparams.num_heads, hparams.attention_dropout, - summaries=summaries, name="encoder_self_attention")) x = residual_fn(x, transformer_ffn_layer(x, hparams)) return x @@ -189,8 +186,6 @@ def transformer_decoder(decoder_input, y: a Tensors """ x = decoder_input - # Summaries don't work in multi-problem setting yet. - summaries = len(hparams.problems) < 2 with tf.variable_scope(name): for layer in xrange(hparams.num_hidden_layers): with tf.variable_scope("layer_%d" % layer): @@ -205,7 +200,6 @@ def transformer_decoder(decoder_input, hparams.hidden_size, hparams.num_heads, hparams.attention_dropout, - summaries=summaries, name="decoder_self_attention")) x = residual_fn( x, @@ -218,7 +212,6 @@ def transformer_decoder(decoder_input, hparams.hidden_size, hparams.num_heads, hparams.attention_dropout, - summaries=summaries, name="encdec_attention")) x = residual_fn(x, transformer_ffn_layer(x, hparams)) return x @@ -234,15 +227,12 @@ def transformer_ffn_layer(x, hparams): Returns: a Tensor of shape [batch_size, length, hparams.hidden_size] """ - # Summaries don't work in multi-problem setting yet. - summaries = len(hparams.problems) < 2 if hparams.ffn_layer == "conv_hidden_relu": return common_layers.conv_hidden_relu( x, hparams.filter_size, hparams.hidden_size, - dropout=hparams.relu_dropout, - summaries=summaries) + dropout=hparams.relu_dropout) elif hparams.ffn_layer == "parameter_attention": return common_attention.parameter_attention( x, @@ -260,8 +250,7 @@ def transformer_ffn_layer(x, hparams): kernel_size=(3, 1), second_kernel_size=(31, 1), padding="LEFT", - dropout=hparams.relu_dropout, - summaries=summaries) + dropout=hparams.relu_dropout) else: assert hparams.ffn_layer == "none" return x diff --git a/tensor2tensor/models/transformer_alternative.py b/tensor2tensor/models/transformer_alternative.py index aed074d56..280dbc713 100644 --- a/tensor2tensor/models/transformer_alternative.py +++ b/tensor2tensor/models/transformer_alternative.py @@ -140,8 +140,6 @@ def alt_transformer_decoder(decoder_input, """Alternative decoder.""" x = decoder_input - # Summaries don't work in multi-problem setting yet. - summaries = "problems" not in hparams.values() or len(hparams.problems) == 1 with tf.variable_scope(name): for layer in xrange(hparams.num_hidden_layers): with tf.variable_scope("layer_%d" % layer): @@ -155,7 +153,6 @@ def alt_transformer_decoder(decoder_input, hparams.hidden_size, hparams.num_heads, hparams.attention_dropout, - summaries=summaries, name="encdec_attention") x_ = residual_fn(x_, composite_layer(x_, mask, hparams)) diff --git a/tensor2tensor/utils/trainer_utils.py b/tensor2tensor/utils/trainer_utils.py index 08faeed2c..f7d3010a9 100644 --- a/tensor2tensor/utils/trainer_utils.py +++ b/tensor2tensor/utils/trainer_utils.py @@ -550,6 +550,13 @@ def nth_model(n): optimizer=opt, colocate_gradients_with_ops=True) + # Remove summaries that will fail to run because they are in conditionals. + # TODO(cwhipkey): Test with this code removed, later in 2017. + summaries = tf.get_collection_ref(tf.GraphKeys.SUMMARIES) + for i in range(len(summaries)-1, -1, -1): + if summaries[i].name.startswith("cond_"): + del summaries[i] + tf.logging.info("Global model_fn finished.") return run_info, total_loss, train_op
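
Illustration only (not part of the patch and not applied by it): a minimal TF 1.x sketch of the mechanism that replaces the removed `summaries` flags. Each layer now writes its image or histogram summary only when its variable scope is not being reused, so when the same model body is built several times over shared variables (the multi-problem / multi-model case) each summary is emitted exactly once. The layer, scope, and tensor names below are hypothetical.

  import tensorflow as tf

  def attention_layer(x, name="attention"):
    # Write the image summary only on the first (non-reused) construction of
    # this scope; reused builds of the same layer skip it.
    with tf.variable_scope(name):
      weights = tf.nn.softmax(tf.matmul(x, x, transpose_b=True))
      if not tf.get_variable_scope().reuse:
        tf.summary.image("attention_weights",
                         tf.expand_dims(weights, -1), max_outputs=1)
      return tf.matmul(weights, x)

  with tf.variable_scope("body") as scope:
    x = tf.random_normal([2, 8, 16])
    y_problem_0 = attention_layer(x)  # writes the image summary
    scope.reuse_variables()
    y_problem_1 = attention_layer(x)  # reused scope: summary is skipped

The trainer_utils.py hunk above complements this on the trainer side by pruning summaries whose names start with "cond_": those are created inside tf.cond branches and cannot be evaluated unconditionally by the merged summary op.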