From 64c7f6a03c87ee211f2c0a061aeb148da66ae867 Mon Sep 17 00:00:00 2001
From: Martin Popel
Date: Fri, 30 Mar 2018 11:27:47 +0200
Subject: [PATCH] Fix transformer decoding when using attention other than
 dot_product.

Note that _beam_decode_slow should not be wrapped in variable_scope(name),
unlike _fast_decode, which must be.
Fixes #674 and allows decoding with transformer_relative.
---
 tensor2tensor/models/transformer.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/tensor2tensor/models/transformer.py b/tensor2tensor/models/transformer.py
index 8a913bb44..e3edecfce 100644
--- a/tensor2tensor/models/transformer.py
+++ b/tensor2tensor/models/transformer.py
@@ -225,6 +225,13 @@ def _beam_decode(self, features, decode_length, beam_size, top_beams, alpha):
           None if using greedy decoding (beam_size=1)
       }
     """
+    if self._hparams.self_attention_type != "dot_product":
+      # Caching is not guaranteed to work with attention types other than
+      # dot_product.
+      # TODO(petershaw): Support fast decoding when using relative
+      # position representations, i.e. "dot_product_relative" attention.
+      return self._beam_decode_slow(features, decode_length, beam_size,
+                                    top_beams, alpha)
     with tf.variable_scope(self.name):
       return self._fast_decode(
           features, decode_length, beam_size, top_beams, alpha)
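
Post-patch note (not part of the patch): the standalone Python sketch below
illustrates the dispatch pattern the hunk introduces, under the assumption
that the cached fast decoder is only valid for plain dot-product
self-attention. The class and method names here are hypothetical stand-ins,
not tensor2tensor's real class structure.

    class BeamDecodeDispatch(object):
      """Hypothetical stand-in for the Transformer model's _beam_decode."""

      def __init__(self, self_attention_type):
        self.self_attention_type = self_attention_type

      def beam_decode(self, features, decode_length, beam_size, top_beams,
                      alpha):
        if self.self_attention_type != "dot_product":
          # Attention variants such as "dot_product_relative" cannot reuse the
          # decoder cache, so fall back to the slow (uncached) beam search.
          return self._beam_decode_slow(features, decode_length, beam_size,
                                        top_beams, alpha)
        # Plain dot-product attention: the cached fast decoder is safe to use.
        return self._fast_decode(features, decode_length, beam_size, top_beams,
                                 alpha)

      def _beam_decode_slow(self, features, decode_length, beam_size,
                            top_beams, alpha):
        return {"outputs": None, "scores": None}  # placeholder

      def _fast_decode(self, features, decode_length, beam_size, top_beams,
                       alpha):
        return {"outputs": None, "scores": None}  # placeholder

Because the slow path already creates its own variable scope internally, the
new early return is placed before the tf.variable_scope(self.name) wrapper,
which is why only _fast_decode remains wrapped in it.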