Change summary generation to work better in multi-model case.
PiperOrigin-RevId: 162429483
T2T Team authored and Ryan Sepassi committed Jul 19, 2017
1 parent d3502cb commit f703629
Showing 10 changed files with 27 additions and 59 deletions.
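In short, the per-call "summaries" boolean that every model threaded into the attention and convolution helpers is removed; those helpers now emit their image/histogram summaries whenever the enclosing variable scope is not being reused, and the multi-problem case is handled once in trainer_utils.py by pruning summaries that end up inside conditionals. A minimal, self-contained sketch of the new gating pattern (TF 1.x style; toy_attention and its shapes are hypothetical, only the reuse check mirrors the diff below):

import tensorflow as tf

def toy_attention(scores, name="toy_attention"):
  """Hypothetical layer following the post-commit pattern: no summaries arg."""
  with tf.variable_scope(name):
    weights = tf.nn.softmax(scores, name="attention_weights")
    # Emit the summary only when variables are not being reused, so building
    # the same layer again under reuse=True adds no duplicate summary ops.
    if not tf.get_variable_scope().reuse:
      tf.summary.image("attention", tf.expand_dims(weights, -1), max_outputs=1)
    return weights

scores = tf.placeholder(tf.float32, shape=[1, 4, 4])
with tf.variable_scope("model"):
  first = toy_attention(scores)           # summary is written
with tf.variable_scope("model", reuse=True):
  second = toy_attention(scores)          # reuse is inherited: summary skipped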
3 changes: 0 additions & 3 deletions tensor2tensor/models/attention_lm.py
@@ -101,8 +101,6 @@ def attention_lm_decoder(decoder_input,
y: a Tensors
"""
x = decoder_input
# Summaries don't work in multi-problem setting yet.
summaries = "problems" not in hparams.values() or len(hparams.problems) == 1
with tf.variable_scope(name):
for layer in xrange(hparams.num_hidden_layers):
with tf.variable_scope("layer_%d" % layer):
@@ -117,7 +115,6 @@ def attention_lm_decoder(decoder_input,
hparams.hidden_size,
hparams.num_heads,
hparams.attention_dropout,
summaries=summaries,
name="decoder_self_attention"))
x = residual_fn(x,
common_layers.conv_hidden_relu(
1 change: 0 additions & 1 deletion tensor2tensor/models/attention_lm_moe.py
@@ -69,7 +69,6 @@ def residual_fn(x, y):
hparams.hidden_size,
hparams.num_heads,
hparams.attention_dropout,
summaries=True,
name="decoder_self_attention")
x = dp(residual_fn, x, y)
with tf.variable_scope("ffn"):
17 changes: 5 additions & 12 deletions tensor2tensor/models/common_attention.py
@@ -312,7 +312,6 @@ def dot_product_attention(q,
v,
bias,
dropout_rate=0.0,
summaries=False,
image_shapes=None,
name=None):
"""dot-product attention.
@@ -323,7 +322,6 @@ def dot_product_attention(q,
v: a Tensor with shape [batch, heads, length_kv, depth_v]
bias: bias Tensor (see attention_bias())
dropout_rate: a floating point number
summaries: a boolean
image_shapes: optional tuple of integer scalars.
see comments for attention_image_summary()
name: an optional string
@@ -340,13 +338,13 @@ def dot_product_attention(q,
weights = tf.nn.softmax(logits, name="attention_weights")
# dropping out the attention links for each of the heads
weights = tf.nn.dropout(weights, 1.0 - dropout_rate)
if summaries and not tf.get_variable_scope().reuse:
if not tf.get_variable_scope().reuse:
attention_image_summary(weights, image_shapes)
return tf.matmul(weights, v)


def masked_local_attention_1d(
q, k, v, block_length=128, summaries=True, name=None):
q, k, v, block_length=128, name=None):
"""Attention to the source position and a neigborhood to the left of it.
The sequence is divided into blocks of length block_size.
@@ -362,7 +360,6 @@ def masked_local_attention_1d(
k: a Tensor with shape [batch, heads, length, depth_k]
v: a Tensor with shape [batch, heads, length, depth_v]
block_length: an integer
summaries: a boolean
name: an optional string
Returns:
@@ -394,7 +391,7 @@ def masked_local_attention_1d(
first_v = tf.slice(v, [0, 0, 0, 0], [-1, -1, block_length, -1])
first_output = dot_product_attention(
first_q, first_k, first_v, attention_bias_lower_triangle(block_length),
summaries=summaries, name="fist_block")
name="fist_block")

# compute attention for all subsequent query blocks.
q = tf.reshape(q, [batch, heads, num_blocks, block_length, depth_k])
@@ -442,7 +439,6 @@ def multihead_attention(query_antecedent,
output_depth,
num_heads,
dropout_rate,
summaries=False,
image_shapes=None,
attention_type="dot_product",
block_length=128,
@@ -458,7 +454,6 @@ def multihead_attention(query_antecedent,
output_depth: an integer
num_heads: an integer dividing total_key_depth and total_value_depth
dropout_rate: a floating point number
summaries: a boolean
image_shapes: optional tuple of integer scalars.
see comments for attention_image_summary()
attention_type: a string, either "dot_product" or "local_mask_right"
@@ -509,12 +504,10 @@ def multihead_attention(query_antecedent,
q *= key_depth_per_head**-0.5
if attention_type == "dot_product":
x = dot_product_attention(
q, k, v, bias, dropout_rate, summaries, image_shapes)
q, k, v, bias, dropout_rate, image_shapes)
else:
assert attention_type == "local_mask_right"
x = masked_local_attention_1d(q, k, v,
block_length=block_length,
summaries=summaries)
x = masked_local_attention_1d(q, k, v, block_length=block_length)
x = combine_heads(x)
x = common_layers.conv1d(x, output_depth, 1, name="output_transform")
return x
26 changes: 9 additions & 17 deletions tensor2tensor/models/common_layers.py
@@ -777,15 +777,15 @@ def moe_layer(data_parallelism,
xs_2d = dp(tf.reshape, xs, [[-1, model_hidden_size]] * dp.n)
# Call the MoE
moe_out_2d, importance, load, _, _ = moe.Eval(
dp.devices, xs_2d, train, identifiers=None, summaries=True)
dp.devices, xs_2d, train, identifiers=None)
# Reshape the output to the original shape.
moe_out = dp(tf.reshape, moe_out_2d, dp(tf.shape, xs))
# These losses encourage equal load on the different experts.
loss = loss_coef * (eu.CVSquared(importance) + eu.CVSquared(load))
return moe_out, loss


def simple_attention(target, source, bias=None, summaries=True):
def simple_attention(target, source, bias=None):
"""A simple attention function.
Args:
@@ -795,7 +793,6 @@ def simple_attention(target, source, bias=None, summaries=True):
`[batch, source_timesteps_1, source_timesteps_2, depth]`
bias: an optional `Tensor` with shape `[batch, timesteps, 1, 1]` used
to mask the attention to not attend to padding of input.
summaries: Boolean, whether to output summaries.
Returns:
a `Tensor` with same shape as `target`
@@ -814,7 +813,7 @@ def simple_attention(target, source, bias=None, summaries=True):
if bias is not None:
attention += tf.expand_dims(tf.squeeze(bias, axis=[2, 3]), axis=1)
attention = tf.nn.softmax(attention)
if summaries and not tf.get_variable_scope().reuse:
if not tf.get_variable_scope().reuse:
tf.summary.image("attention", tf.expand_dims(attention, 3), max_outputs=5)
attended = tf.matmul(attention, source)
return tf.reshape(attended, target_shape)
@@ -861,8 +860,7 @@ def multiscale_conv_sum(inputs, output_size, dilation_rates_and_kernel_sizes,
def multiscale_conv_and_attention(x,
padding,
hparams,
source=None,
summaries=True):
source=None):
"""A common part of t2t layers.
First, do a linear multiscale convolution
@@ -875,7 +873,6 @@ def multiscale_conv_and_attention(x,
padding: a padding type
hparams: hyperparameters for model
source: optional source tensor for attention. (encoder output)
summaries: Boolean, whether to output summaries.
Returns:
a Tensor.
@@ -893,7 +890,7 @@ def multiscale_conv_and_attention(x,
x = conv(x, hparams.hidden_size, (1, 1))
x = noam_norm(x + conv_sum)
if source is not None:
x = noam_norm(x + simple_attention(x, source, summaries=summaries))
x = noam_norm(x + simple_attention(x, source))
return x


@@ -930,8 +927,7 @@ def conv_with_pools(inputs, output_size, kernel_size, pool_sizes, pooling_type,
def conv_with_pools_and_attention(x,
padding,
hparams,
source=None,
summaries=True):
source=None):
"""A common part of t2t layers.
First, do conv_with_pools
@@ -944,7 +940,6 @@ def conv_with_pools_and_attention(x,
padding: a padding type
hparams: hyperparameters for model
source: optional source tensor for attention. (encoder output)
summaries: Boolean, whether to output summaries.
Returns:
a Tensor.
@@ -959,7 +954,7 @@ def conv_with_pools_and_attention(x,
conv_sum += x
x = noam_norm(conv_sum)
if source is not None:
x = noam_norm(x + simple_attention(x, source, summaries=summaries))
x = noam_norm(x + simple_attention(x, source))
return x


@@ -1057,7 +1052,6 @@ def attention_1d_v0(source,
transform_source=True,
transform_target=True,
transform_output=True,
summaries=True,
name=None):
"""multi-headed attention.
@@ -1075,7 +1069,6 @@ def attention_1d_v0(source,
transform_source: a boolean
transform_target: a boolean
transform_output: a boolean
summaries: a boolean
name: an optional string
Returns:
@@ -1116,7 +1109,7 @@ def _maybe_transform(t, size, should_transform, name):
mask = (1.0 - mask) * -1e9
attention += mask
attention = tf.nn.softmax(attention)
if summaries and not tf.get_variable_scope().reuse:
if not tf.get_variable_scope().reuse:
# Compute a color image summary.
image = tf.reshape(attention,
[batch, num_heads, target_length, source_length])
@@ -1162,7 +1155,6 @@ def conv_hidden_relu(inputs,
output_size,
kernel_size=(1, 1),
second_kernel_size=(1, 1),
summaries=True,
dropout=0.0,
**kwargs):
"""Hidden layer with RELU activation followed by linear projection."""
@@ -1183,7 +1175,7 @@ def conv_hidden_relu(inputs,
**kwargs)
if dropout != 0.0:
h = tf.nn.dropout(h, 1.0 - dropout)
if summaries and not tf.get_variable_scope().reuse:
if not tf.get_variable_scope().reuse:
tf.summary.histogram("hidden_density_logit",
relu_density_logit(
h, list(range(inputs.shape.ndims - 1))))
1 change: 0 additions & 1 deletion tensor2tensor/models/long_answer.py
@@ -75,7 +75,6 @@ def residual_fn(x, y):
hparams.hidden_size,
hparams.num_heads,
hparams.attention_dropout,
summaries=True,
attention_type="local_mask_right",
block_length=hparams.block_length,
name="decoder_self_attention")
5 changes: 1 addition & 4 deletions tensor2tensor/models/multimodel.py
@@ -138,7 +138,6 @@ def flatten(inputs):
hparams.hidden_size,
hparams.num_heads,
hparams.attention_dropout,
summaries=False,
name="decoder_self_attention")
z = dp(common_attention.multihead_attention,
y,
@@ -149,7 +148,6 @@ def flatten(inputs):
hparams.hidden_size,
hparams.num_heads,
hparams.attention_dropout,
summaries=False,
name="encdec_attention")
x = dp(residual_fn3, x, y, z, hparams)
with tf.variable_scope("ffn"):
@@ -164,8 +162,7 @@ def flatten(inputs):
x,
hparams.filter_size,
hparams.hidden_size,
dropout=hparams.dropout,
summaries=False)
dropout=hparams.dropout)
x = dp(residual_fn2, x, y, hparams)

x = dp(tf.expand_dims, x, 2)
8 changes: 3 additions & 5 deletions tensor2tensor/models/slicenet.py
@@ -64,8 +64,7 @@ def attention(targets_shifted, inputs_encoded, norm_fn, hparams, bias=None):
hparams.hidden_size,
hparams.num_heads,
hparams.attention_dropout,
name="self_attention",
summaries=False)
name="self_attention")
qv = common_attention.multihead_attention(
qv,
inputs_encoded,
@@ -75,12 +74,11 @@ def attention(targets_shifted, inputs_encoded, norm_fn, hparams, bias=None):
hparams.hidden_size,
hparams.num_heads,
hparams.attention_dropout,
name="encdec_attention",
summaries=False)
name="encdec_attention")
return tf.expand_dims(qv, 2)
elif hparams.attention_type == "simple":
targets_with_attention = common_layers.simple_attention(
targets_timed, inputs_encoded, bias=bias, summaries=False)
targets_timed, inputs_encoded, bias=bias)
return norm_fn(targets_shifted + targets_with_attention, name="attn_norm")


15 changes: 2 additions & 13 deletions tensor2tensor/models/transformer.py
@@ -143,8 +143,6 @@ def transformer_encoder(encoder_input,
y: a Tensors
"""
x = encoder_input
# Summaries don't work in multi-problem setting yet.
summaries = len(hparams.problems) < 2
with tf.variable_scope(name):
for layer in xrange(hparams.num_hidden_layers):
with tf.variable_scope("layer_%d" % layer):
@@ -159,7 +157,6 @@ def transformer_encoder(encoder_input,
hparams.hidden_size,
hparams.num_heads,
hparams.attention_dropout,
summaries=summaries,
name="encoder_self_attention"))
x = residual_fn(x, transformer_ffn_layer(x, hparams))
return x
@@ -189,8 +186,6 @@ def transformer_decoder(decoder_input,
y: a Tensors
"""
x = decoder_input
# Summaries don't work in multi-problem setting yet.
summaries = len(hparams.problems) < 2
with tf.variable_scope(name):
for layer in xrange(hparams.num_hidden_layers):
with tf.variable_scope("layer_%d" % layer):
@@ -205,7 +200,6 @@ def transformer_decoder(decoder_input,
hparams.hidden_size,
hparams.num_heads,
hparams.attention_dropout,
summaries=summaries,
name="decoder_self_attention"))
x = residual_fn(
x,
@@ -218,7 +212,6 @@ def transformer_decoder(decoder_input,
hparams.hidden_size,
hparams.num_heads,
hparams.attention_dropout,
summaries=summaries,
name="encdec_attention"))
x = residual_fn(x, transformer_ffn_layer(x, hparams))
return x
@@ -234,15 +227,12 @@ def transformer_ffn_layer(x, hparams):
Returns:
a Tensor of shape [batch_size, length, hparams.hidden_size]
"""
# Summaries don't work in multi-problem setting yet.
summaries = len(hparams.problems) < 2
if hparams.ffn_layer == "conv_hidden_relu":
return common_layers.conv_hidden_relu(
x,
hparams.filter_size,
hparams.hidden_size,
dropout=hparams.relu_dropout,
summaries=summaries)
dropout=hparams.relu_dropout)
elif hparams.ffn_layer == "parameter_attention":
return common_attention.parameter_attention(
x,
@@ -260,8 +250,7 @@ def transformer_ffn_layer(x, hparams):
kernel_size=(3, 1),
second_kernel_size=(31, 1),
padding="LEFT",
dropout=hparams.relu_dropout,
summaries=summaries)
dropout=hparams.relu_dropout)
else:
assert hparams.ffn_layer == "none"
return x
3 changes: 0 additions & 3 deletions tensor2tensor/models/transformer_alternative.py
@@ -140,8 +140,6 @@ def alt_transformer_decoder(decoder_input,
"""Alternative decoder."""
x = decoder_input

# Summaries don't work in multi-problem setting yet.
summaries = "problems" not in hparams.values() or len(hparams.problems) == 1
with tf.variable_scope(name):
for layer in xrange(hparams.num_hidden_layers):
with tf.variable_scope("layer_%d" % layer):
@@ -155,7 +153,6 @@ def alt_transformer_decoder(decoder_input,
hparams.hidden_size,
hparams.num_heads,
hparams.attention_dropout,
summaries=summaries,
name="encdec_attention")

x_ = residual_fn(x_, composite_layer(x_, mask, hparams))
7 changes: 7 additions & 0 deletions tensor2tensor/utils/trainer_utils.py
@@ -550,6 +550,13 @@ def nth_model(n):
optimizer=opt,
colocate_gradients_with_ops=True)

# Remove summaries that will fail to run because they are in conditionals.
# TODO(cwhipkey): Test with this code removed, later in 2017.
summaries = tf.get_collection_ref(tf.GraphKeys.SUMMARIES)
for i in range(len(summaries)-1, -1, -1):
if summaries[i].name.startswith("cond_"):
del summaries[i]

tf.logging.info("Global model_fn finished.")
return run_info, total_loss, train_op

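The block added above is the trainer-side half of the change: summary ops created inside conditionals (identified here by a "cond_" name prefix) will fail to run unconditionally, so they are dropped from the SUMMARIES collection before merging. A minimal, self-contained sketch of that idiom, with a hypothetical helper name; only the loop body mirrors the diff:

import tensorflow as tf

def prune_conditional_summaries():
  """Drop summary ops whose names say they were built inside a conditional."""
  summaries = tf.get_collection_ref(tf.GraphKeys.SUMMARIES)
  # Walk backwards so deleting entries does not shift the remaining indices.
  for i in range(len(summaries) - 1, -1, -1):
    if summaries[i].name.startswith("cond_"):
      del summaries[i]

# Usage sketch: call once the full graph (including any tf.cond branches) has
# been built, right before merging, so that tf.summary.merge_all() only sees
# summaries that can run on every step.
prune_conditional_summaries()
merged = tf.summary.merge_all()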
