Fix transformer decoding when using attention other than dot_product.
Note that _beam_decode_slow should not be wrapped in variable_scope(name),
unlike _fast_decode, which must be wrapped.
Fixes #674 and allows decoding with transformer_relative.
martinpopel committed Mar 30, 2018
1 parent 9094fc1 commit 64c7f6a
Showing 1 changed file with 7 additions and 0 deletions: tensor2tensor/models/transformer.py
@@ -225,6 +225,13 @@ def _beam_decode(self, features, decode_length, beam_size, top_beams, alpha):
               None if using greedy decoding (beam_size=1)
       }
     """
+    if self._hparams.self_attention_type != "dot_product":
+      # Caching is not guaranteed to work with attention types other than
+      # dot_product.
+      # TODO(petershaw): Support fast decoding when using relative
+      # position representations, i.e. "dot_product_relative" attention.
+      return self._beam_decode_slow(features, decode_length, beam_size,
+                                    top_beams, alpha)
     with tf.variable_scope(self.name):
       return self._fast_decode(
           features, decode_length, beam_size, top_beams, alpha)
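The scope note in the commit message can be illustrated with a minimal TensorFlow 1.x sketch. This is not part of the commit; the build_layer helper and the "transformer" scope name are hypothetical. It assumes, as the commit message suggests, that _beam_decode_slow already opens the model's variable scope internally, so wrapping it in an outer tf.variable_scope(self.name) would duplicate the scope prefix, while _fast_decode builds the decoding graph directly and therefore needs the wrapper.

    import tensorflow as tf  # TensorFlow 1.x API, as used by tensor2tensor at the time

    def build_layer():
      # Creates a variable under whatever variable scope is currently open.
      return tf.get_variable("w", shape=[4], initializer=tf.zeros_initializer())

    # Scope opened once: the variable is named "transformer/w:0", matching
    # the names created at training time.
    with tf.variable_scope("transformer"):
      w_once = build_layer()

    # Scope opened twice (an outer wrapper plus an inner open of the same
    # name): the variable becomes "transformer/transformer/w:0" and no longer
    # matches the trained weights.
    with tf.variable_scope("transformer"):
      with tf.variable_scope("transformer"):
        w_twice = build_layer()

    print(w_once.name)   # transformer/w:0
    print(w_twice.name)  # transformer/transformer/w:0

Under that assumption, a duplicated prefix would make decoding look up (or create) variables under the wrong names, which is why only the _fast_decode call keeps the wrapper in the hunk above.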
