[Embedding] Fix incorrect frequency in shared-embedding. (#931)
Signed-off-by: lixy9474 <lxy268263@alibaba-inc.com>
lixy9474 authored Sep 19, 2023
1 parent 821d5e8 commit fe194b0
Showing 4 changed files with 115 additions and 8 deletions.
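Background for the fix: in a shared-embedding setup, several sparse lookups read the same EmbeddingVariable, and each lookup yields its own id-frequency ("counts") tensor. Before this commit the variable kept a single `_counts_tensor` slot, so every lookup overwrote the counts of the previous one, and the frequencies written to the checkpoint came out too low. A minimal plain-Python sketch (illustrative names only, not DeepRec code) of why the single slot loses counts and the per-lookup dict does not:

    # Old bookkeeping: one slot, each lookup clobbers the previous counts.
    def save_freqs_old(lookups):
        counts_tensor = None
        for _indices, counts in lookups:
            counts_tensor = counts
        return counts_tensor

    # New bookkeeping: one entry per lookup, keyed by its indices object.
    def save_freqs_new(lookups):
        counts_tensor = {}
        for indices, counts in lookups:
            counts_tensor[indices] = counts
        return counts_tensor

    lookups = [("sp1_ids", {0: 3, 1: 2, 2: 1}),   # ids [0,0,0,1,1,2]
               ("sp2_ids", {3: 3, 4: 2, 1: 1})]   # ids [3,3,3,4,4,1]
    print(save_freqs_old(lookups))   # only sp2's counts survive
    print(save_freqs_new(lookups))   # both lookups' counts are kept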
74 changes: 74 additions & 0 deletions tensorflow/python/ops/embedding_variable_ops_test.py
@@ -2816,5 +2816,79 @@ def testSetInitializedWithRestore(self):
       result = sess.run(var._is_initialized_op)
       self.assertEqual(True, result)
 
+  def testCountsTensor(self):
+    os.environ["TF_RECORD_FREQ"] = "1"
+    checkpoint_directory = self.get_temp_dir()
+    ckpt_path = os.path.join(checkpoint_directory, "model.ckpt")
+    with ops.Graph().as_default() as g, ops.device('/cpu:0'):
+      var = variable_scope.get_embedding_variable("var_1",
+                                                  embedding_dim=3)
+      sp1 = sparse_tensor.SparseTensor(
+          indices=[[0,0],[1,0],[2,0],[3,0],[4,0],[5,0]],
+          values=math_ops.cast([0,0,0,1,1,2], dtypes.int64),
+          dense_shape=[6, 1])
+      sp2 = sparse_tensor.SparseTensor(
+          indices=[[0,0],[1,0],[2,0],[3,0],[4,0],[5,0]],
+          values=math_ops.cast([3,3,3,4,4,1], dtypes.int64),
+          dense_shape=[6, 1])
+      emb1 = embedding_ops.embedding_lookup_sparse(var, sp1, None)
+      emb2 = embedding_ops.embedding_lookup_sparse(var, sp2, None)
+      emb = emb1 + emb2
+      fun = math_ops.multiply(emb, 2.0, name='multiply')
+      loss = math_ops.reduce_sum(fun, name='reduce_sum')
+      gs = training_util.get_or_create_global_step()
+      opt = adagrad_decay.AdagradDecayOptimizer(0.1, gs)
+      g_v = opt.compute_gradients(loss)
+      train_op = opt.apply_gradients(g_v)
+      saver = saver_module.Saver()
+      init = variables.global_variables_initializer()
+      with self.test_session(graph=g) as sess:
+        sess.run([init])
+        sess.run(train_op)
+        saver.save(sess, ckpt_path)
+
+      for name, shape in checkpoint_utils.list_variables(ckpt_path):
+        if name == "var_1-freqs":
+          value = checkpoint_utils.load_variable(ckpt_path, name)
+          self.assertAllEqual(value, [3, 3, 1, 3, 2])
+
+  def testCountsTensorWithGradientDescent(self):
+    os.environ["TF_RECORD_FREQ"] = "1"
+    checkpoint_directory = self.get_temp_dir()
+    ckpt_path = os.path.join(checkpoint_directory, "model.ckpt")
+    with ops.Graph().as_default() as g, ops.device('/cpu:0'):
+      var = variable_scope.get_embedding_variable("var_1",
+                                                  embedding_dim=3)
+      sp1 = sparse_tensor.SparseTensor(
+          indices=[[0,0],[1,0],[2,0],[3,0],[4,0],[5,0]],
+          values=math_ops.cast([0,0,0,1,1,2], dtypes.int64),
+          dense_shape=[6, 1])
+      sp2 = sparse_tensor.SparseTensor(
+          indices=[[0,0],[1,0],[2,0],[3,0],[4,0],[5,0]],
+          values=math_ops.cast([3,3,3,4,4,1], dtypes.int64),
+          dense_shape=[6, 1])
+      emb1 = embedding_ops.embedding_lookup_sparse(var, sp1, None)
+      emb2 = embedding_ops.embedding_lookup_sparse(var, sp2, None)
+      emb = emb1 + emb2
+      fun = math_ops.multiply(emb, 2.0, name='multiply')
+      loss = math_ops.reduce_sum(fun, name='reduce_sum')
+      gs = training_util.get_or_create_global_step()
+      opt = gradient_descent.GradientDescentOptimizer(0.1)
+      g_v = opt.compute_gradients(loss)
+      train_op = opt.apply_gradients(g_v)
+      saver = saver_module.Saver()
+      init = variables.global_variables_initializer()
+      with self.test_session(graph=g) as sess:
+        sess.run([init])
+        sess.run(train_op)
+        saver.save(sess, ckpt_path)
+
+      for name, shape in checkpoint_utils.list_variables(ckpt_path):
+        if name == "var_1-freqs":
+          value = checkpoint_utils.load_variable(ckpt_path, name)
+          self.assertAllEqual(value, [3, 3, 1, 3, 2])
+
+    del os.environ["TF_RECORD_FREQ"]
+
 if __name__ == "__main__":
   googletest.main()
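For reference, the expected `var_1-freqs` value [3, 3, 1, 3, 2] in both tests follows from counting ids across the two lookups. A quick plain-Python check, under the assumption that the checkpoint stores one frequency per id in ascending id order:

    from collections import Counter

    ids_sp1 = [0, 0, 0, 1, 1, 2]   # values fed through sp1
    ids_sp2 = [3, 3, 3, 4, 4, 1]   # values fed through sp2

    freq = Counter(ids_sp1) + Counter(ids_sp2)
    # id 0 -> 3, id 1 -> 2 + 1 = 3, id 2 -> 1, id 3 -> 3, id 4 -> 2
    print([freq[i] for i in sorted(freq)])   # [3, 3, 1, 3, 2]

Before the fix, only one lookup's counts tensor survived, which is why the saved frequencies missed the other lookup's contribution.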
4 changes: 2 additions & 2 deletions tensorflow/python/ops/kv_variable_ops.py
@@ -368,7 +368,7 @@ def _init_from_args(self,
     self._dtype = initial_value.dtype.base_dtype
     self._constraint = constraint
     self._gather_op = None
-    self._counts_tensor = None
+    self._counts_tensor = {}
     if self._is_primary:
       self._slot_num = 0
     else:
@@ -850,7 +850,7 @@ def sparse_read(self, indices, name=None, ev_init_value=None, counts=None):
                                                      default_value,
                                                      counts, is_inference=True,
                                                      name=name)
-      self._counts_tensor = counts
+      self._counts_tensor[indices] = counts
     else:
       value = gen_kv_variable_ops.kv_resource_gather(self._handle,
                                                      indices,
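Keying the dict by the lookup's `indices` tensor works because graph-mode `Tensor` objects hash by identity, so each `sparse_read` call files its counts under its own indices tensor instead of clobbering a shared slot. A minimal sketch of that bookkeeping, assuming TF 1.x graph mode (DeepRec is 1.15-based):

    import tensorflow as tf  # assumes TF 1.x graph mode

    g = tf.Graph()
    with g.as_default():
        ids_a = tf.constant([0, 1, 2], dtype=tf.int64)   # first lookup's ids
        ids_b = tf.constant([3, 4, 1], dtype=tf.int64)   # second lookup's ids
        counts_tensor = {}                               # mirrors the new dict
        counts_tensor[ids_a] = tf.constant([3, 2, 1], dtype=tf.int64)
        counts_tensor[ids_b] = tf.constant([3, 2, 1], dtype=tf.int64)
        assert len(counts_tensor) == 2   # neither entry clobbers the other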
15 changes: 13 additions & 2 deletions tensorflow/python/training/gradient_descent.py
@@ -71,12 +71,23 @@ def _resource_apply_dense(self, grad, handle):
   def _resource_apply_sparse_duplicate_indices(self, grad, handle, indices):
     if isinstance(handle, kv_variable_ops.EmbeddingVariable):
       global_step = training_util.get_or_create_global_step()
-      if handle.need_counts() and handle._counts_tensor is not None:
+      if handle.need_counts() and len(handle._counts_tensor.keys()) != 0:
+        if indices.op.type == "ConcatV2":
+          total_counts = []
+          for tensor in indices.op.inputs:
+            if tensor.op.type == "Reshape":
+              indices_tensor = tensor.op.inputs[0]
+              total_counts.append(handle._counts_tensor[indices_tensor])
+          from tensorflow.python.ops import array_ops
+          counts_tensor = array_ops.concat(total_counts, 0)
+        elif indices.op.type == "Reshape":
+          indices_tensor = indices.op.inputs[0]
+          counts_tensor = handle._counts_tensor[indices_tensor]
         return training_ops.kv_resource_sparse_apply_gradient_descent_with_counts(
             handle.handle, math_ops.cast(self._learning_rate_tensor,
                                          grad.dtype.base_dtype),
             grad, indices, global_step,
-            handle._counts_tensor, use_locking=self._use_locking)
+            counts_tensor, use_locking=self._use_locking)
       else:
         return training_ops.kv_resource_sparse_apply_gradient_descent(
             handle.handle, math_ops.cast(self._learning_rate_tensor,
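About the `ConcatV2`/`Reshape` pattern match: with a shared embedding, the indices tensor the optimizer receives is built by reshaping each lookup's (already deduplicated) ids and concatenating them. Walking `indices.op.inputs` back through each `Reshape` therefore recovers the per-lookup indices tensors that key `_counts_tensor`, and concatenating their counts in the same order keeps counts aligned with indices. A standalone sketch of the traversal, assuming TF 1.x graph mode (not the DeepRec kernels themselves):

    import tensorflow as tf  # assumes TF 1.x graph mode

    g = tf.Graph()
    with g.as_default():
        ids_a = tf.constant([0, 1, 2], dtype=tf.int64)
        ids_b = tf.constant([3, 4, 1], dtype=tf.int64)
        # Shape the optimizer sees for a shared embedding:
        # each lookup's ids reshaped, then concatenated.
        indices = tf.concat([tf.reshape(ids_a, [-1]),
                             tf.reshape(ids_b, [-1])], 0)

        counts = {ids_a: tf.constant([3, 2, 1], dtype=tf.int64),
                  ids_b: tf.constant([3, 2, 1], dtype=tf.int64)}

        # Same traversal as the fix: ConcatV2 -> Reshape -> original ids tensor.
        total_counts = [counts[t.op.inputs[0]]
                        for t in indices.op.inputs if t.op.type == "Reshape"]
        counts_tensor = tf.concat(total_counts, 0)

    with tf.Session(graph=g) as sess:
        print(sess.run(counts_tensor))   # [3 2 1 3 2 1]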
30 changes: 26 additions & 4 deletions tensorflow/python/training/optimizer.py
@@ -93,6 +93,18 @@ def _deduplicate_indexed_slices_with_counts(values, indices):
       array_ops.shape(unique_indices)[0])
   return (summed_values, unique_indices, indices_counts)
 
+def _deduplicate_indexed_slices_with_counts_reduction(values, indices, counts):
+  """Sums `values` associated with any non-unique `indices`,
+  and likewise sums the `counts` of each group of duplicate indices."""
+  unique_indices, new_index_positions = array_ops.unique(indices)
+  summed_values = math_ops.unsorted_segment_sum(
+      values, new_index_positions,
+      array_ops.shape(unique_indices)[0])
+  summed_counts = math_ops.unsorted_segment_sum(
+      counts, new_index_positions,
+      array_ops.shape(unique_indices)[0])
+  return (summed_values, unique_indices, summed_counts)
+
 def _var_key(var):
   # TODO(ashankar): Consolidate handling for eager and graph
   if hasattr(var, "op"):
@@ -1088,14 +1100,24 @@ def _resource_apply_sparse_duplicate_indices(self, grad, handle, indices):
     """
     from tensorflow.python.ops import kv_variable_ops
    if isinstance(handle, kv_variable_ops.EmbeddingVariable) and handle.need_counts():
-      if handle._counts_tensor is None:
+      if len(handle._counts_tensor.keys()) == 0:
         summed_grad, unique_indices, indices_counts = \
             _deduplicate_indexed_slices_with_counts(
                 values=grad, indices=indices)
       else:
-        summed_grad, unique_indices = _deduplicate_indexed_slices(
-            values=grad, indices=indices)
-        indices_counts = handle._counts_tensor
+        if indices.op.type == "ConcatV2":
+          total_counts = []
+          for tensor in indices.op.inputs:
+            if tensor.op.type == "Reshape":
+              indices_tensor = tensor.op.inputs[0]
+              total_counts.append(handle._counts_tensor[indices_tensor])
+          counts_tensor = array_ops.concat(total_counts, 0)
+        elif indices.op.type == "Reshape":
+          indices_tensor = indices.op.inputs[0]
+          counts_tensor = handle._counts_tensor[indices_tensor]
+        summed_grad, unique_indices, indices_counts = \
+            _deduplicate_indexed_slices_with_counts_reduction(
+                grad, indices, counts_tensor)
       return self._resource_apply_sparse(
           summed_grad, handle, unique_indices, indices_counts)
     else:
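The new `_deduplicate_indexed_slices_with_counts_reduction` is needed because an id can occur in more than one lookup (id 1 appears in both sp1 and sp2 in the tests above), so the concatenated indices still contain duplicates: the gradient rows and the counts must both be segment-summed over the unique indices. A NumPy stand-in for the two `unsorted_segment_sum` calls, using the counts from the tests:

    import numpy as np

    indices = np.array([0, 1, 2, 3, 4, 1])   # per-lookup unique ids, concatenated
    counts = np.array([3, 2, 1, 3, 2, 1])    # matching per-lookup counts

    unique, positions = np.unique(indices, return_inverse=True)
    summed_counts = np.zeros(len(unique), dtype=counts.dtype)
    np.add.at(summed_counts, positions, counts)   # unsorted_segment_sum equivalent

    print(unique)          # [0 1 2 3 4]
    print(summed_counts)   # [3 3 1 3 2] -- matches var_1-freqs in the tests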
