From 160bed3fe2745c74aafd2f1a4d1568f43aabfab4 Mon Sep 17 00:00:00 2001 From: Lukasz Kaiser Date: Thu, 5 Apr 2018 11:57:04 -0700 Subject: [PATCH] Improvements to basic_conv_gen and autoencoder hparams. PiperOrigin-RevId: 191776372 --- tensor2tensor/models/research/autoencoders.py | 9 +++--- .../models/research/basic_conv_gen.py | 28 ++++++++++++------- 2 files changed, 22 insertions(+), 15 deletions(-) diff --git a/tensor2tensor/models/research/autoencoders.py b/tensor2tensor/models/research/autoencoders.py index 533ac7c30..a7c696499 100644 --- a/tensor2tensor/models/research/autoencoders.py +++ b/tensor2tensor/models/research/autoencoders.py @@ -316,8 +316,8 @@ def basic_discrete_autoencoder(): hparams = basic.basic_autoencoder() hparams.num_hidden_layers = 5 hparams.hidden_size = 64 - hparams.bottleneck_size = 2048 - hparams.bottleneck_noise = 0.2 + hparams.bottleneck_size = 4096 + hparams.bottleneck_noise = 0.1 hparams.bottleneck_warmup_steps = 3000 hparams.add_hparam("discretize_warmup_steps", 5000) return hparams @@ -327,8 +327,8 @@ def basic_discrete_autoencoder(): def residual_discrete_autoencoder(): """Residual discrete autoencoder model.""" hparams = residual_autoencoder() - hparams.bottleneck_size = 2048 - hparams.bottleneck_noise = 0.2 + hparams.bottleneck_size = 4096 + hparams.bottleneck_noise = 0.1 hparams.bottleneck_warmup_steps = 3000 hparams.add_hparam("discretize_warmup_steps", 5000) hparams.add_hparam("bottleneck_kind", "tanh_discrete") @@ -344,7 +344,6 @@ def residual_discrete_autoencoder_big(): hparams = residual_discrete_autoencoder() hparams.hidden_size = 128 hparams.max_hidden_size = 4096 - hparams.bottleneck_size = 8192 hparams.bottleneck_noise = 0.1 hparams.dropout = 0.1 hparams.residual_dropout = 0.4 diff --git a/tensor2tensor/models/research/basic_conv_gen.py b/tensor2tensor/models/research/basic_conv_gen.py index f6e34e9fb..144042896 100644 --- a/tensor2tensor/models/research/basic_conv_gen.py +++ b/tensor2tensor/models/research/basic_conv_gen.py @@ -40,26 +40,33 @@ def body(self, features): # Concat frames and down-stride. cur_frame = tf.to_float(features["inputs"]) prev_frame = tf.to_float(features["inputs_prev"]) - frames = tf.concat([cur_frame, prev_frame], axis=-1) - x = tf.layers.conv2d(frames, filters, kernel2, activation=tf.nn.relu, - strides=(2, 2), padding="SAME") + x = tf.concat([cur_frame, prev_frame], axis=-1) + for _ in xrange(hparams.num_compress_steps): + x = tf.layers.conv2d(x, filters, kernel2, activation=common_layers.belu, + strides=(2, 2), padding="SAME") + x = common_layers.layer_norm(x) + filters *= 2 # Add embedded action. - action = tf.reshape(features["action"], [-1, 1, 1, filters]) - x = tf.concat([x, action + tf.zeros_like(x)], axis=-1) + action = tf.reshape(features["action"], [-1, 1, 1, hparams.hidden_size]) + zeros = tf.zeros(common_layers.shape_list(x)[:-1] + [hparams.hidden_size]) + x = tf.concat([x, action + zeros], axis=-1) # Run a stack of convolutions. for i in xrange(hparams.num_hidden_layers): with tf.variable_scope("layer%d" % i): - y = tf.layers.conv2d(x, 2 * filters, kernel1, activation=tf.nn.relu, + y = tf.layers.conv2d(x, filters, kernel1, activation=common_layers.belu, strides=(1, 1), padding="SAME") if i == 0: x = y else: x = common_layers.layer_norm(x + y) # Up-convolve. - x = tf.layers.conv2d_transpose( - x, filters, kernel2, activation=tf.nn.relu, - strides=(2, 2), padding="SAME") + for _ in xrange(hparams.num_compress_steps): + filters //= 2 + x = tf.layers.conv2d_transpose( + x, filters, kernel2, activation=common_layers.belu, + strides=(2, 2), padding="SAME") + x = common_layers.layer_norm(x) # Reward prediction. reward_pred_h1 = tf.reduce_mean(x, axis=[1, 2], keep_dims=True) @@ -78,7 +85,7 @@ def basic_conv(): hparams = common_hparams.basic_params1() hparams.hidden_size = 64 hparams.batch_size = 8 - hparams.num_hidden_layers = 2 + hparams.num_hidden_layers = 3 hparams.optimizer = "Adam" hparams.learning_rate_constant = 0.0002 hparams.learning_rate_warmup_steps = 500 @@ -87,6 +94,7 @@ def basic_conv(): hparams.initializer = "uniform_unit_scaling" hparams.initializer_gain = 1.0 hparams.weight_decay = 0.0 + hparams.add_hparam("num_compress_steps", 2) return hparams