Skip to content
This repository has been archived by the owner on Jul 7, 2023. It is now read-only.

Commit

Permalink
Update comment on shape in SymbolModality
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 191759697
  • Loading branch information
Ryan Sepassi committed Apr 5, 2018
1 parent fc9335c commit b39d152
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion tensor2tensor/layers/modalities.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,12 @@ def _get_weights(self, hidden_dim=None):

def bottom_simple(self, x, name, reuse):
with tf.variable_scope(name, reuse=reuse):
# Squeeze out the channels dimension.
# Ensure the inputs are 3-D
if len(x.get_shape()) == 4:
x = tf.squeeze(x, axis=3)
while len(x.get_shape()) < 3:
x = tf.expand_dims(x, axis=-1)

var = self._get_weights()
x = common_layers.dropout_no_scaling(
x, 1.0 - self._model_hparams.symbol_dropout)
Expand Down

0 comments on commit b39d152

Please sign in to comment.