This repository has been archived by the owner on Jul 7, 2023. It is now read-only.

Commit b951c79

Add the recent group normalization to common layers.

PiperOrigin-RevId: 191769014

Lukasz Kaiser authored and Ryan Sepassi committed Apr 5, 2018
1 parent b39d152 commit b951c79
Showing 2 changed files with 32 additions and 1 deletion.
25 changes: 24 additions & 1 deletion tensor2tensor/layers/common_layers.py
@@ -577,7 +577,7 @@ def layer_norm_compute(x, epsilon, scale, bias):
 def layer_norm(x, filters=None, epsilon=1e-6, name=None, reuse=None):
   """Layer normalize the tensor x, averaging over the last dimension."""
   if filters is None:
-    filters = x.get_shape()[-1]
+    filters = shape_list(x)[-1]
   with tf.variable_scope(
       name, default_name="layer_norm", values=[x], reuse=reuse):
     scale = tf.get_variable(
@@ -592,6 +592,27 @@ def layer_norm(x, filters=None, epsilon=1e-6, name=None, reuse=None):
     return result


+def group_norm(x, filters=None, num_groups=8, epsilon=1e-5):
+  """Group normalization as in https://arxiv.org/abs/1803.08494."""
+  x_shape = shape_list(x)
+  if filters is None:
+    filters = x_shape[-1]
+  assert len(x_shape) == 4
+  assert filters % num_groups == 0
+  # Prepare variables.
+  scale = tf.get_variable(
+      "group_norm_scale", [filters], initializer=tf.ones_initializer())
+  bias = tf.get_variable(
+      "group_norm_bias", [filters], initializer=tf.zeros_initializer())
+  epsilon, scale, bias = [tf.cast(t, x.dtype) for t in [epsilon, scale, bias]]
+  # Reshape and compute group norm.
+  x = tf.reshape(x, x_shape[:-1] + [num_groups, filters // num_groups])
+  # Calculate mean and variance over height, width and the channels within
+  # each group (but not across groups).
+  mean, variance = tf.nn.moments(x, [1, 2, 4], keep_dims=True)
+  norm_x = (x - mean) * tf.rsqrt(variance + epsilon)
+  return tf.reshape(norm_x, x_shape) * scale + bias


 def noam_norm(x, epsilon=1.0, name=None):
   """One version of layer normalization."""
   with tf.name_scope(name, default_name="noam_norm", values=[x]):
@@ -605,6 +626,8 @@ def apply_norm(x, norm_type, depth, epsilon):
"""Apply Normalization."""
if norm_type == "layer":
return layer_norm(x, filters=depth, epsilon=epsilon)
if norm_type == "group":
return group_norm(x, filters=depth, epsilon=epsilon)
if norm_type == "batch":
return tf.layers.batch_normalization(x, epsilon=epsilon)
if norm_type == "noam":
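For context, a minimal usage sketch of the new layer (assuming a TF 1.x graph-mode session and that tensor2tensor is on the import path; the shapes below are illustrative, not part of the commit):

import numpy as np
import tensorflow as tf

from tensor2tensor.layers import common_layers

# NHWC input; the channel count (16) must be divisible by num_groups (8),
# per the asserts in group_norm.
x = tf.constant(np.random.rand(2, 4, 4, 16), dtype=tf.float32)
y = common_layers.group_norm(x, num_groups=8)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())  # initializes scale and bias
  print(sess.run(y).shape)  # (2, 4, 4, 16)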
8 changes: 8 additions & 0 deletions tensor2tensor/layers/common_layers_test.py
@@ -236,6 +236,14 @@ def testLayerNorm(self):
       res = session.run(y)
     self.assertEqual(res.shape, (5, 7, 11))

+  def testGroupNorm(self):
+    x = np.random.rand(5, 7, 3, 16)
+    with self.test_session() as session:
+      y = common_layers.group_norm(tf.constant(x, dtype=tf.float32))
+      session.run(tf.global_variables_initializer())
+      res = session.run(y)
+    self.assertEqual(res.shape, (5, 7, 3, 16))
+
   def testConvLSTM(self):
     x = np.random.rand(5, 7, 11, 13)
     with self.test_session() as session:
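The new "group" branch in apply_norm can be exercised the same way as the existing norm types (a sketch continuing the example above; depth must equal the channel dimension, and num_groups keeps its default of 8):

# Dispatches to group_norm(x, filters=16, epsilon=1e-5).
y = common_layers.apply_norm(x, "group", depth=16, epsilon=1e-5)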
