This repository has been archived by the owner on Jul 7, 2023. It is now read-only.

Commit

Merge pull request #678 from martinpopel/transforer-relative-decode-fix
Fix transformer decoding when using attention other than dot_product.
lukaszkaiser authored Apr 3, 2018
2 parents 957b5cd + 64c7f6a commit b99fac2
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions tensor2tensor/models/transformer.py
@@ -225,6 +225,13 @@ def _beam_decode(self, features, decode_length, beam_size, top_beams, alpha):
          None if using greedy decoding (beam_size=1)
      }
    """
    if self._hparams.self_attention_type != "dot_product":
      # Caching is not guaranteed to work with attention types other than
      # dot_product.
      # TODO(petershaw): Support fast decoding when using relative
      # position representations, i.e. "dot_product_relative" attention.
      return self._beam_decode_slow(features, decode_length, beam_size,
                                    top_beams, alpha)
    with tf.variable_scope(self.name):
      return self._fast_decode(
          features, decode_length, beam_size, top_beams, alpha)
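
For context, here is a minimal illustrative sketch (not code from this repository; the function and variable names are hypothetical) of why the fast, cached decode path assumes plain dot-product attention: during incremental decoding each layer keeps a growing cache of keys and values, and the query for the new position only attends over that cache. Relative-position variants such as "dot_product_relative" also need a term per (query position, key position) pair that this cache layout does not carry, which is why the commit falls back to _beam_decode_slow for those attention types.

# Illustrative sketch only (assumed names, plain numpy, not the tensor2tensor API):
# one incremental decode step of single-head dot-product attention with a
# key/value cache.
import numpy as np

def incremental_dot_product_step(query, cache_k, cache_v, new_k, new_v):
  """Attend the current query over the cache extended by the new step."""
  cache_k = np.concatenate([cache_k, new_k[None, :]], axis=0)  # [t, d]
  cache_v = np.concatenate([cache_v, new_v[None, :]], axis=0)  # [t, d]
  logits = cache_k @ query                   # [t], depends only on content
  weights = np.exp(logits - logits.max())
  weights /= weights.sum()                   # softmax over cached positions
  context = weights @ cache_v                # [d]
  return context, cache_k, cache_v

# A relative-position attention would additionally need a bias or embedding
# for each (query position, key position) pair, which the plain key/value
# cache above does not represent, hence the fallback to the slow decode path.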
