Skip to content

Commit

Permalink
[Embedding] Improve code format.
Browse files Browse the repository at this point in the history
Signed-off-by: JunqiHu <silenceki@hotmail.com>
  • Loading branch information
Mesilenceki committed Oct 19, 2023
1 parent eb68cfd commit 88748e5
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 4 deletions.
3 changes: 2 additions & 1 deletion tensorflow/core/framework/embedding/cpu_hash_map_kv.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,8 @@ class LocklessHashMap : public KVInterface<K, V> {
for (int64 j = 0; j < bucket_count; j++) {
if (hash_map_dump[j].first != LocklessHashMap<K, V>::EMPTY_KEY_
&& hash_map_dump[j].first != LocklessHashMap<K, V>::DELETED_KEY_
&& hash_map_dump[j].first % 1000 % partition_nums != partition_id) {
&& hash_map_dump[j].first % kSavedPartitionNum
% partition_nums != partition_id) {
key_list->emplace_back(hash_map_dump[j].first);
value_ptr_list->emplace_back(hash_map_dump[j].second);
}
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/core/framework/embedding/dense_hash_map_kv.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ class DenseHashMap : public KVInterface<K, V> {
}
for (int i = 0; i< partition_num_; i++) {
for (const auto it : hash_map_dump[i].hash_map) {
if (it.first % 1000 % partition_nums != partition_id) {
if (it.first % kSavedPartitionNum % partition_nums != partition_id) {
key_list->push_back(it.first);
value_ptr_list->push_back(it.second);
}
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/core/framework/embedding/leveldb_kv.h
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ class LevelDBKV : public KVInterface<K, V> {
for (it->SeekToFirst(); it->Valid(); it->Next()) {
K key;
memcpy((char*)&key, it->key().ToString().data(), sizeof(K));
if (key % 1000 % partition_nums == partition_id) continue;
if (key % kSavedPartitionNum % partition_nums == partition_id) continue;
key_list->emplace_back(key);
FeatureDescriptor<V> hbm_feat_desc(
1, 1, ev_allocator()/*useless*/,
Expand Down
3 changes: 2 additions & 1 deletion tensorflow/python/ops/embedding_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import tf_export

SAVED_PARTITIONED_NUM = 1000

def _clip(params, ids, max_norm):
"""Helper function for _embedding_lookup_and_transform.
Expand Down Expand Up @@ -216,7 +217,7 @@ def _embedding_lookup_and_transform(params,

if isinstance(params[0], kv_variable_ops.EmbeddingVariable):
new_ids = flat_ids
p_assignments = flat_ids % 1000 % np
p_assignments = flat_ids % SAVED_PARTITIONED_NUM % np
elif partition_strategy == "mod":
p_assignments = flat_ids % np
new_ids = flat_ids // np
Expand Down

0 comments on commit 88748e5

Please sign in to comment.