Commit 8d8e16a

[Embedding] Fix set initialized flag too early in restore subgraph. (#920)

Signed-off-by: lixy9474 <lxy268263@alibaba-inc.com>
lixy9474 authored Aug 11, 2023
1 parent f09e5ec commit 8d8e16a
Showing 9 changed files with 158 additions and 26 deletions.
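
In short: restoring an EmbeddingVariable previously reused the regular initializer, whose kernel unconditionally called SetInitialized(), so the variable could be flagged initialized before any checkpoint values were imported. This commit builds a dedicated restore-path initializer that skips the flag; only the actual import op marks the variable initialized. A minimal sketch of the intended behavior, mirroring the new test below (var._initializer_for_restore and var._is_initialized_op are the private attributes the test exercises):

    sess.run(var._initializer_for_restore)   # builds storage; does NOT set the flag
    assert not sess.run(var._is_initialized_op)
    saver.restore(sess, ckpt_path)           # the import op sets the flag
    assert sess.run(var._is_initialized_op)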
4 changes: 4 additions & 0 deletions tensorflow/core/framework/embedding/config.proto
@@ -56,3 +56,7 @@ enum ValuePosition {
   IN_DRAM = 0;
   NOT_IN_DRAM = 1;
 }
+
+enum IsSetInitialized {
+  NOT_SET_INITAILIZED = 0;
+}
10 changes: 6 additions & 4 deletions tensorflow/core/framework/embedding/multi_tier_storage.h
@@ -81,10 +81,12 @@ class MultiTierStorage : public Storage<K, V> {
   }
 
   void InitCache(embedding::CacheStrategy cache_strategy) override {
-    cache_ = CacheFactory::Create<K>(cache_strategy, name_);
-    eviction_manager_ = EvictionManagerCreator::Create<K, V>();
-    eviction_manager_->AddStorage(this);
-    cache_thread_pool_ = CacheThreadPoolCreator::Create();
+    if (cache_ == nullptr) {
+      cache_ = CacheFactory::Create<K>(cache_strategy, name_);
+      eviction_manager_ = EvictionManagerCreator::Create<K, V>();
+      eviction_manager_->AddStorage(this);
+      cache_thread_pool_ = CacheThreadPoolCreator::Create();
+    }
   }
 
   Status BatchCommit(const std::vector<K>& keys,
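
Why the guard: an EmbeddingVariable now carries two initializers (the training-path op and the restore-path op added below in kv_variable_ops.py), and for multi-tier storage each attaches its own kv_resource_init_cache_strategy_op to the same storage object. A hedged sketch of the call sequence the nullptr check makes safe (illustrative, session mode):

    sess.run(var._initializer_op)        # first call to InitCache(): creates the cache,
                                         # eviction manager and thread pool
    sess.run(var._init_op_for_restore)   # reaches InitCache() again; the cache_ == nullptr
                                         # guard turns the second call into a no-op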
2 changes: 2 additions & 0 deletions tensorflow/core/framework/variable.proto
@@ -74,6 +74,8 @@ message VariableDef {

   // EmbeddingVariable
   bool is_embedding_var = 91;
+
+  string initialize_op_for_restore = 92;
 }
 
 message SaveSliceInfoDef {
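
The new VariableDef field lets the restore initializer survive a MetaGraph export/import cycle: to_proto() records the op's name and _init_from_proto() resolves it again (see the kv_variable_ops.py hunks below). A hedged sketch of the round trip, assuming a session-mode graph g:

    var_def = var.to_proto()        # fills initialize_op_for_restore with the op name
    # ... export_meta_graph / import_meta_graph ...
    restore_init = g.as_graph_element(var_def.initialize_op_for_restore)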
28 changes: 14 additions & 14 deletions tensorflow/core/kernels/kv_variable_ops.cc
@@ -43,11 +43,6 @@ limitations under the License.

 namespace tensorflow {
 
-namespace {
-const int64 kEmbeddingVarUseDB = -214;
-const int64 kInitializableEmbeddingVarUseDB = -215;
-}
-
 Status MoveMatchingFiles(
     Env* env,
     const tstring& pattern,
@@ -207,6 +202,15 @@ class InitializeKvVariableOp : public OpKernel {
         (embedding_var_type ==
          embedding::EmbeddingVariableType::IMMUTABLE);
 
+    // initial_num_buckets is otherwise unused, so it is repurposed to set is_set_initialized_.
+    int64 initial_num_buckets = 0;
+    OP_REQUIRES_OK(c, c->GetAttr("initial_num_buckets", &initial_num_buckets));
+    is_set_initialized_ = true;
+    if (initial_num_buckets ==
+        embedding::IsSetInitialized::NOT_SET_INITAILIZED) {
+      is_set_initialized_ = false;
+    }
+
     int64 storage_type = 0;
     OP_REQUIRES_OK(c, c->GetAttr("storage_type", &storage_type));
     storage_type_ = static_cast<embedding::StorageType>(storage_type);
@@ -263,15 +267,10 @@
                 " should be DRAM when layout is 'compact'."));
     }
 
-    if (steps_to_live_ == kEmbeddingVarUseDB ||
-        steps_to_live_ == kInitializableEmbeddingVarUseDB) {
-      LOG(INFO) << "hashmap use db";
-      //use_db_ = true;
-    } else {
-      OP_REQUIRES(c, steps_to_live_ >= 0,
-                  errors::InvalidArgument(
+    OP_REQUIRES(c, steps_to_live_ >= 0,
+                errors::InvalidArgument(
                 "steps_to_live must >= 0, ", std::to_string(steps_to_live_)));
-    }
 
     OP_REQUIRES_OK(c, c->GetAttr("ht_type", &ht_type_));
     if (embedding::StorageType::LEVELDB == storage_type_) {
       ht_type_ = "leveldb_kv";
@@ -406,7 +405,7 @@ class InitializeKvVariableOp : public OpKernel {
       core::ScopedUnref unref_me(primary_variable);
     }
     core::ScopedUnref unref_me(ev);
-    if (steps_to_live_ != kEmbeddingVarUseDB) {
+    if (is_set_initialized_) {
       ev->SetInitialized();
     }
   }
@@ -436,6 +435,7 @@ class InitializeKvVariableOp : public OpKernel {
   bool record_freq_;
   bool record_version_;
   bool is_inference_;
+  bool is_set_initialized_;
 };
 
 #define REGISTER_KERNELS(ktype, vtype) \
65 changes: 65 additions & 0 deletions tensorflow/python/ops/embedding_variable_ops_test.py
@@ -2751,5 +2751,70 @@ def testCPUFbjOptWithBloomFilter(self):
     self.assertNotEqual(val, 1.0)
     del os.environ["TF_EMBEDDING_FBJ_OPT"]
 
+  def testSetInitializedWithoutRestore(self):
+    print("testSetInitializedWithoutRestore")
+    with ops.device("/cpu:0"):
+      var = variable_scope.get_embedding_variable("var_1",
+                                                  embedding_dim = 3)
+      emb = embedding_ops.embedding_lookup(var, math_ops.cast([1], dtypes.int64))
+      fun = math_ops.multiply(emb, 2.0, name='multiply')
+      loss = math_ops.reduce_sum(fun, name='reduce_sum')
+      gs = training_util.get_or_create_global_step()
+      opt = adagrad_decay.AdagradDecayOptimizer(0.1, gs)
+      g_v = opt.compute_gradients(loss)
+      train_op = opt.apply_gradients(g_v)
+      init = variables.global_variables_initializer()
+      saver = saver_module.Saver()
+      with self.test_session() as sess:
+        result = sess.run(var._is_initialized_op)
+        self.assertEqual(False, result)
+        sess.run([init])
+        result = sess.run(var._is_initialized_op)
+        self.assertEqual(True, result)
+
+  def testSetInitializedWithRestore(self):
+    print("testSetInitializedWithRestore")
+    checkpoint_directory = self.get_temp_dir()
+    ckpt_path = os.path.join(checkpoint_directory, "model.ckpt")
+    with ops.Graph().as_default() as g, ops.device('/cpu:0'):
+      var = variable_scope.get_embedding_variable("var_1",
+                                                  embedding_dim = 3)
+      emb = embedding_ops.embedding_lookup(var, math_ops.cast([1, 2, 3], dtypes.int64))
+      fun = math_ops.multiply(emb, 2.0, name='multiply')
+      loss = math_ops.reduce_sum(fun, name='reduce_sum')
+      gs = training_util.get_or_create_global_step()
+      opt = adagrad_decay.AdagradDecayOptimizer(0.1, gs)
+      g_v = opt.compute_gradients(loss)
+      train_op = opt.apply_gradients(g_v)
+      saver = saver_module.Saver()
+      init = variables.global_variables_initializer()
+      with self.test_session(graph=g) as sess:
+        sess.run([init])
+        sess.run(train_op)
+        saver.save(sess, ckpt_path)
+
+    with ops.Graph().as_default() as g, ops.device('/cpu:0'):
+      var = variable_scope.get_embedding_variable("var_1",
+                                                  embedding_dim = 3)
+      emb = embedding_ops.embedding_lookup(var, math_ops.cast([1, 2, 3], dtypes.int64))
+      fun = math_ops.multiply(emb, 2.0, name='multiply')
+      loss = math_ops.reduce_sum(fun, name='reduce_sum')
+      gs = training_util.get_or_create_global_step()
+      opt = adagrad_decay.AdagradDecayOptimizer(0.1, gs)
+      g_v = opt.compute_gradients(loss)
+      train_op = opt.apply_gradients(g_v)
+      saver = saver_module.Saver()
+      init = variables.global_variables_initializer()
+      with self.test_session(graph=g) as sess:
+        result = sess.run(var._is_initialized_op)
+        self.assertEqual(False, result)
+        sess.run([var._initializer_for_restore])
+        result = sess.run(var._is_initialized_op)
+        self.assertEqual(False, result)
+
+        saver.restore(sess, ckpt_path)
+        result = sess.run(var._is_initialized_op)
+        self.assertEqual(True, result)
+
 if __name__ == "__main__":
   googletest.main()
52 changes: 52 additions & 0 deletions tensorflow/python/ops/kv_variable_ops.py
@@ -434,6 +434,8 @@ def is_multi_tier(storage_type):
       with ops.control_dependencies(set_attr_ops + [self._init_op]):
         self._initializer_op = control_flow_ops.no_op()
 
+      self.create_init_op_for_restore(name, initial_value, invalid_key, rank)
+
       self._graph_element = self._handle
       self._cached_value = None
       if not context.executing_eagerly():
@@ -444,6 +446,49 @@ def is_multi_tier(storage_type):
   def export(self):
     return gen_kv_variable_ops.kv_resource_export(self._handle, Tkeys=self._invalid_key_type)
 
+
+  def create_init_op_for_restore(self, name, initial_value, invalid_key, rank):
+    with ops.control_dependencies(None if self._is_primary else [self._primary._init_op_for_restore]):
+      self._initializer_for_restore = gen_kv_variable_ops.initialize_kv_variable_v2_op(
+          self._handle,
+          self._primary._handle,
+          variables._try_guard_against_uninitialized_dependencies(name, initial_value),
+          ops.convert_to_tensor(invalid_key),
+          initial_num_buckets=config_pb2.IsSetInitialized.NOT_SET_INITAILIZED,
+          slot_num=self._slot_num,
+          shape=initial_value.get_shape()[rank:],
+          steps_to_live=self._steps_to_live,
+          emb_index=self._emb_index, block_num=self.block_num,
+          slot_index=self._slot_index,
+          ht_type=self._ht_type,
+          ht_partition_num=self._ht_partition_num,
+          filter_freq=self._filter_freq,
+          l2_weight_threshold=self._l2_weight_threshold,
+          max_element_size=self._max_element_size,
+          false_positive_probability=self._false_positive_probability,
+          counter_type=self._counter_type,
+          max_freq=99999,
+          layout=self._layout,
+          storage_type=self._storage_type,
+          storage_path=self._storage_path,
+          storage_size=self._storage_size,
+          default_value_dim=self._default_value_dim,
+          default_value_no_permission=self._default_value_no_permission,
+          record_freq=self._record_freq,
+          record_version=self._record_version,
+          embedding_variable_type=config_pb2.EmbeddingVariableType.IMMUTABLE)
+    set_attr_ops = []
+    if self._is_primary and self._is_multi_tier:
+      with ops.control_dependencies([self._initializer_for_restore]):
+        set_cache_op = gen_kv_variable_ops.kv_resource_init_cache_strategy_op(
+            self._handle,
+            cache_strategy=self._storage_cache_strategy,
+            Tkeys=self._invalid_key_type,
+            dtype=self._dtype)
+      set_attr_ops.append(set_cache_op)
+    with ops.control_dependencies(set_attr_ops + [self._initializer_for_restore]):
+      self._init_op_for_restore = control_flow_ops.no_op()
+
   def need_counts(self):
     return (self._record_freq or (self._filter_freq > 0) or self._is_multi_tier)
   @property
@@ -482,6 +527,11 @@ def _init_from_proto(self, variable_def, import_scope=None):
             cache_op = op
       elif self._initializer_op.type == "InitializeKvVariableOp":
         init_op = self._initializer_op
+
+      self._init_op_for_restore = g.as_graph_element(
+          ops.prepend_name_scope(
+              variable_def.initialize_op_for_restore,
+              import_scope=import_scope))
       self._trainable = getattr(variable_def, "trainable", True)
       if variable_def.snapshot_name:
         self._cached_value = g.as_graph_element(
@@ -842,6 +892,8 @@ def to_proto(self, export_scope=None):
       if self._save_slice_info:
         var_def.save_slice_info_def.MergeFrom(
             self._save_slice_info.to_proto(export_scope=export_scope))
+      var_def.initialize_op_for_restore = ops.strip_name_scope(
+          self._init_op_for_restore.name, export_scope)
       return var_def
     else:
       return None
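
The control dependencies in create_init_op_for_restore deliberately mirror the training path: a slot's restore initializer waits on its primary's, and for a multi-tier primary the cache-strategy op runs between the initializer and the final no_op. Roughly (an illustrative ordering, not code from the commit):

    # primary._initializer_for_restore
    #   -> kv_resource_init_cache_strategy_op      (primary and multi-tier only)
    #   -> primary._init_op_for_restore (no_op)
    #   -> slot._initializer_for_restore           (slots wait on the primary)
    #   -> slot._init_op_for_restore (no_op)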
3 changes: 1 addition & 2 deletions tensorflow/python/training/optimizer.py
@@ -243,8 +243,7 @@ def _get_processor(v):
   if v.op.type == "KvVarHandleOp":
     from tensorflow.core.framework import attr_value_pb2
     from tensorflow.core.framework.embedding import config_pb2
-    v._init_op._set_attr("embedding_variable_type",
-        attr_value_pb2.AttrValue(i=config_pb2.EmbeddingVariableType.MUTABLE))
+    slot_creator._set_init_op_embedding_type_attr(v, config_pb2.EmbeddingVariableType.MUTABLE)
     return _DenseResourceVariableProcessor(v)
   if isinstance(v, variables.Variable):
     return _RefVariableProcessor(v)
2 changes: 1 addition & 1 deletion tensorflow/python/training/saving/saveable_object_util.py
@@ -195,7 +195,7 @@ def restore(self, restored_tensors, unused_restored_shapes):
     if self.var._init_data_source is not None:
       return self.var.recover_from_init_data_source(self.var._init_data_source, self.partition_id, self.partition_num)
     else:
-      with ops.control_dependencies([self.var._initializer_op]):
+      with ops.control_dependencies([self.var._init_op_for_restore]):
         rank = self.op.initial_value.get_shape().rank - 1
         restore_op = gen_kv_variable_ops.kv_resource_import_v3(
             restored_tensors[0],
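
This is the hunk that consumes the new subgraph: the import op now waits on _init_op_for_restore instead of the full _initializer_op, so storage is allocated before checkpoint values arrive, while the initialized flag stays false until the import itself sets it. A sketch of the resulting order during saver.restore (names as in the diff):

    # 1. var._init_op_for_restore   - allocates storage; flag NOT set
    # 2. kv_resource_import_v3      - writes checkpoint values, then sets the flag
    # A concurrent is-initialized check can no longer observe True before values exist.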
18 changes: 13 additions & 5 deletions tensorflow/python/training/slot_creator.py
@@ -94,8 +94,7 @@ def _create_slot_var(primary, val, scope, validate_shape, shape, dtype, slot_con
         validate_shape=validate_shape,
         steps_to_live=primary._steps_to_live,
         ht_partition_num=primary._ht_partition_num)
-    slot._init_op._set_attr("embedding_variable_type",
-        attr_value_pb2.AttrValue(i=config_pb2.EmbeddingVariableType.MUTABLE))
+    _set_init_op_embedding_type_attr(slot, config_pb2.EmbeddingVariableType.MUTABLE)
   else:
     filter_strategy = None
     if primary._filter_freq != 0:
@@ -107,7 +106,7 @@ def _create_slot_var(primary, val, scope, validate_shape, shape, dtype, slot_con
       else:
         filter_strategy = variables.CounterFilter(filter_freq=primary._filter_freq)
     if slot_config.slot_type is config_pb2.SlotType.EMBEDDING_VARIABLE:
-      primary._init_op._set_attr("slot_num", attr_value_pb2.AttrValue(i=slot_config.slot_num))
+      _set_init_op_slot_num_attr(primary, slot_config.slot_num)
       primary._slot_num = slot_config.slot_num
       emb_index = primary._emb_index
       if primary.block_num > 1:
@@ -132,8 +131,7 @@ def _create_slot_var(primary, val, scope, validate_shape, shape, dtype, slot_con
             l2_weight_threshold=primary._l2_weight_threshold,
             filter_strategy=filter_strategy)
         )
-      slot._init_op._set_attr("embedding_variable_type",
-          attr_value_pb2.AttrValue(i=config_pb2.EmbeddingVariableType.MUTABLE))
+      _set_init_op_embedding_type_attr(slot, config_pb2.EmbeddingVariableType.MUTABLE)
   else:
     slot = variable_scope.get_variable(
         scope,
@@ -300,3 +298,13 @@ def create_zeros_slot(primary, name, dtype=None, colocate_with_primary=True, slo
   return create_slot(primary, val, name,
                      colocate_with_primary=colocate_with_primary,
                      slot_config=slot_config)
+
+def _set_init_op_embedding_type_attr(var, embedding_type):
+  var._init_op._set_attr("embedding_variable_type",
+                         attr_value_pb2.AttrValue(i=embedding_type))
+  var._initializer_for_restore._set_attr("embedding_variable_type",
+                                         attr_value_pb2.AttrValue(i=embedding_type))
+
+def _set_init_op_slot_num_attr(var, slot_num):
+  var._init_op._set_attr("slot_num", attr_value_pb2.AttrValue(i=slot_num))
+  var._initializer_for_restore._set_attr("slot_num", attr_value_pb2.AttrValue(i=slot_num))
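
These helpers exist because any attribute stamped on the training-path init op must now also be stamped on the restore-path op; otherwise a graph rebuilt for restore would initialize the variable with stale metadata. A hedged usage sketch mirroring the optimizer.py and slot_creator.py call sites above:

    _set_init_op_embedding_type_attr(slot, config_pb2.EmbeddingVariableType.MUTABLE)
    # updates both ops:
    #   slot._init_op                   (training path)
    #   slot._initializer_for_restore   (restore path)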
