diff --git a/tensorflow/core/framework/embedding/cpu_hash_map_kv.h b/tensorflow/core/framework/embedding/cpu_hash_map_kv.h index f491d652f0a..750ba282285 100644 --- a/tensorflow/core/framework/embedding/cpu_hash_map_kv.h +++ b/tensorflow/core/framework/embedding/cpu_hash_map_kv.h @@ -148,7 +148,8 @@ class LocklessHashMap : public KVInterface { for (int64 j = 0; j < bucket_count; j++) { if (hash_map_dump[j].first != LocklessHashMap::EMPTY_KEY_ && hash_map_dump[j].first != LocklessHashMap::DELETED_KEY_ - && hash_map_dump[j].first % 1000 % partition_nums != partition_id) { + && hash_map_dump[j].first % kSavedPartitionNum + % partition_nums != partition_id) { key_list->emplace_back(hash_map_dump[j].first); value_ptr_list->emplace_back(hash_map_dump[j].second); } diff --git a/tensorflow/core/framework/embedding/dense_hash_map_kv.h b/tensorflow/core/framework/embedding/dense_hash_map_kv.h index e91ba35971f..8a27404b66f 100644 --- a/tensorflow/core/framework/embedding/dense_hash_map_kv.h +++ b/tensorflow/core/framework/embedding/dense_hash_map_kv.h @@ -131,7 +131,7 @@ class DenseHashMap : public KVInterface { } for (int i = 0; i< partition_num_; i++) { for (const auto it : hash_map_dump[i].hash_map) { - if (it.first % 1000 % partition_nums != partition_id) { + if (it.first % kSavedPartitionNum % partition_nums != partition_id) { key_list->push_back(it.first); value_ptr_list->push_back(it.second); } diff --git a/tensorflow/core/framework/embedding/leveldb_kv.h b/tensorflow/core/framework/embedding/leveldb_kv.h index 01ec7faa13c..47c8a39dfbd 100644 --- a/tensorflow/core/framework/embedding/leveldb_kv.h +++ b/tensorflow/core/framework/embedding/leveldb_kv.h @@ -203,7 +203,7 @@ class LevelDBKV : public KVInterface { for (it->SeekToFirst(); it->Valid(); it->Next()) { K key; memcpy((char*)&key, it->key().ToString().data(), sizeof(K)); - if (key % 1000 % partition_nums == partition_id) continue; + if (key % kSavedPartitionNum % partition_nums == partition_id) continue; key_list->emplace_back(key); FeatureDescriptor hbm_feat_desc( 1, 1, ev_allocator()/*useless*/, diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py index cb2b7bb8154..e239c9ba8d5 100644 --- a/tensorflow/python/ops/embedding_ops.py +++ b/tensorflow/python/ops/embedding_ops.py @@ -44,6 +44,7 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util.tf_export import tf_export +SAVED_PARTITIONED_NUM = 1000 def _clip(params, ids, max_norm): """Helper function for _embedding_lookup_and_transform. @@ -216,7 +217,7 @@ def _embedding_lookup_and_transform(params, if isinstance(params[0], kv_variable_ops.EmbeddingVariable): new_ids = flat_ids - p_assignments = flat_ids % 1000 % np + p_assignments = flat_ids % SAVED_PARTITIONED_NUM % np elif partition_strategy == "mod": p_assignments = flat_ids % np new_ids = flat_ids // np