From 821d5e8d39156d477bacd2ede9f68f76ede0f77d Mon Sep 17 00:00:00 2001 From: lixy9474 Date: Tue, 19 Sep 2023 09:56:20 +0800 Subject: [PATCH] [Embedding] Remove the dependency on private header file in EmbeddingVariable. (#927) Signed-off-by: lixy9474 --- tensorflow/core/BUILD | 5 +- .../framework/embedding/embedding_config.h | 3 + .../core/framework/embedding/embedding_var.h | 1 - .../embedding/embedding_var_ckpt_data.cc | 262 +++++++ .../embedding/embedding_var_ckpt_data.h | 190 +---- .../embedding/embedding_var_dump_iterator.h | 7 +- .../embedding/embedding_var_restore.cc | 647 ++++++++++++++++++ .../embedding/embedding_var_restore.h | 534 +-------------- .../core/framework/embedding/kv_interface.h | 8 +- .../embedding/ssd_record_descriptor.cc | 88 +++ .../embedding/ssd_record_descriptor.h | 49 +- tensorflow/core/framework/embedding/storage.h | 4 +- tensorflow/core/kernels/BUILD | 5 +- 13 files changed, 1041 insertions(+), 762 deletions(-) create mode 100644 tensorflow/core/framework/embedding/embedding_var_ckpt_data.cc create mode 100644 tensorflow/core/framework/embedding/embedding_var_restore.cc create mode 100644 tensorflow/core/framework/embedding/ssd_record_descriptor.cc diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 8ae5b4f156c..95bbbab5624 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -3026,7 +3026,10 @@ tf_cuda_library( "framework/embedding/gpu_hash_table.cu.cc", "framework/embedding/gpu_hash_table.h", "framework/embedding/embedding_var.cu.cc", - "framework/embedding/multi_tier_storage.cu.cc" + "framework/embedding/multi_tier_storage.cu.cc", + "framework/embedding/embedding_var_ckpt_data.cc", + "framework/embedding/embedding_var_restore.cc", + "framework/embedding/ssd_record_descriptor.cc" ], ) + select({ "//tensorflow:windows": [], diff --git a/tensorflow/core/framework/embedding/embedding_config.h b/tensorflow/core/framework/embedding/embedding_config.h index 0a50b492159..d47d07d4205 100644 --- a/tensorflow/core/framework/embedding/embedding_config.h +++ b/tensorflow/core/framework/embedding/embedding_config.h @@ -3,6 +3,9 @@ #include #include "tensorflow/core/framework/embedding/config.pb.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/default/logging.h" namespace tensorflow { struct EmbeddingConfig { diff --git a/tensorflow/core/framework/embedding/embedding_var.h b/tensorflow/core/framework/embedding/embedding_var.h index b29493f2169..28ce5094d87 100644 --- a/tensorflow/core/framework/embedding/embedding_var.h +++ b/tensorflow/core/framework/embedding/embedding_var.h @@ -37,7 +37,6 @@ limitations under the License. #include "tensorflow/core/framework/embedding/storage.h" #include "tensorflow/core/framework/embedding/storage_factory.h" #include "tensorflow/core/framework/typed_allocator.h" -#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h" namespace tensorflow { using CPUDevice = Eigen::ThreadPoolDevice; diff --git a/tensorflow/core/framework/embedding/embedding_var_ckpt_data.cc b/tensorflow/core/framework/embedding/embedding_var_ckpt_data.cc new file mode 100644 index 00000000000..c1b43a608b5 --- /dev/null +++ b/tensorflow/core/framework/embedding/embedding_var_ckpt_data.cc @@ -0,0 +1,262 @@ +/* Copyright 2022 The DeepRec Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================*/ +#include "tensorflow/core/framework/embedding/embedding_var_ckpt_data.h" +#include "tensorflow/core/framework/embedding/embedding_var_dump_iterator.h" +#include "tensorflow/core/kernels/save_restore_tensor.h" +#include "tensorflow/core/framework/register_types.h" + +namespace tensorflow { +namespace embedding { +template +void EmbeddingVarCkptData::Emplace( + K key, ValuePtr* value_ptr, + const EmbeddingConfig& emb_config, + V* default_value, int64 value_offset, + bool is_save_freq, + bool is_save_version, + bool save_unfiltered_features) { + if((int64)value_ptr == ValuePtrStatus::IS_DELETED) + return; + + V* primary_val = value_ptr->GetValue(0, 0); + bool is_not_admit = + primary_val == nullptr + && emb_config.filter_freq != 0; + + if (!is_not_admit) { + key_vec_.emplace_back(key); + + if (primary_val == nullptr) { + value_ptr_vec_.emplace_back(default_value); + } else if ( + (int64)primary_val == ValuePosition::NOT_IN_DRAM) { + value_ptr_vec_.emplace_back((V*)ValuePosition::NOT_IN_DRAM); + } else { + V* val = value_ptr->GetValue(emb_config.emb_index, + value_offset); + value_ptr_vec_.emplace_back(val); + } + + + if(is_save_version) { + int64 dump_version = value_ptr->GetStep(); + version_vec_.emplace_back(dump_version); + } + + if(is_save_freq) { + int64 dump_freq = value_ptr->GetFreq(); + freq_vec_.emplace_back(dump_freq); + } + } else { + if (!save_unfiltered_features) + return; + + key_filter_vec_.emplace_back(key); + + if(is_save_version) { + int64 dump_version = value_ptr->GetStep(); + version_filter_vec_.emplace_back(dump_version); + } + + int64 dump_freq = value_ptr->GetFreq(); + freq_filter_vec_.emplace_back(dump_freq); + } +} +#define REGISTER_KERNELS(ktype, vtype) \ + template void EmbeddingVarCkptData::Emplace( \ + ktype, ValuePtr*, const EmbeddingConfig&, \ + vtype*, int64, bool, bool, bool); +#define REGISTER_KERNELS_ALL_INDEX(type) \ + REGISTER_KERNELS(int32, type) \ + REGISTER_KERNELS(int64, type) +TF_CALL_FLOAT_TYPES(REGISTER_KERNELS_ALL_INDEX) +#undef REGISTER_KERNELS_ALL_INDEX +#undef REGISTER_KERNELS + + +template +void EmbeddingVarCkptData::Emplace(K key, V* value_ptr) { + key_vec_.emplace_back(key); + value_ptr_vec_.emplace_back(value_ptr); +} +#define REGISTER_KERNELS(ktype, vtype) \ + template void EmbeddingVarCkptData::Emplace( \ + ktype, vtype*); +#define REGISTER_KERNELS_ALL_INDEX(type) \ + REGISTER_KERNELS(int32, type) \ + REGISTER_KERNELS(int64, type) +TF_CALL_FLOAT_TYPES(REGISTER_KERNELS_ALL_INDEX) +#undef REGISTER_KERNELS_ALL_INDEX +#undef REGISTER_KERNELS + +template +void EmbeddingVarCkptData::SetWithPartition( + std::vector>& ev_ckpt_data_parts) { + part_offset_.resize(kSavedPartitionNum + 1); + part_filter_offset_.resize(kSavedPartitionNum + 1); + part_offset_[0] = 0; + part_filter_offset_[0] = 0; + for (int i = 0; i < kSavedPartitionNum; i++) { + part_offset_[i + 1] = + part_offset_[i] + ev_ckpt_data_parts[i].key_vec_.size(); + + part_filter_offset_[i + 1] = + part_filter_offset_[i] + + ev_ckpt_data_parts[i].key_filter_vec_.size(); + + for (int64 j = 0; j < ev_ckpt_data_parts[i].key_vec_.size(); j++) { + key_vec_.emplace_back(ev_ckpt_data_parts[i].key_vec_[j]); + } + + for (int64 j = 0; j < ev_ckpt_data_parts[i].value_ptr_vec_.size(); j++) { + value_ptr_vec_.emplace_back(ev_ckpt_data_parts[i].value_ptr_vec_[j]); + } + + for (int64 j = 0; j < ev_ckpt_data_parts[i].version_vec_.size(); j++) { + version_vec_.emplace_back(ev_ckpt_data_parts[i].version_vec_[j]); + } + + for (int64 j = 0; j < ev_ckpt_data_parts[i].freq_vec_.size(); j++) { + freq_vec_.emplace_back(ev_ckpt_data_parts[i].freq_vec_[j]); + } + + for (int64 j = 0; j < ev_ckpt_data_parts[i].key_filter_vec_.size(); j++) { + key_filter_vec_.emplace_back(ev_ckpt_data_parts[i].key_filter_vec_[j]); + } + + for (int64 j = 0; j < ev_ckpt_data_parts[i].version_filter_vec_.size(); j++) { + version_filter_vec_.emplace_back(ev_ckpt_data_parts[i].version_filter_vec_[j]); + } + + for (int64 j = 0; j < ev_ckpt_data_parts[i].freq_filter_vec_.size(); j++) { + freq_filter_vec_.emplace_back(ev_ckpt_data_parts[i].freq_filter_vec_[j]); + } + } +} + +#define REGISTER_KERNELS(ktype, vtype) \ + template void EmbeddingVarCkptData::SetWithPartition( \ + std::vector>&); +#define REGISTER_KERNELS_ALL_INDEX(type) \ + REGISTER_KERNELS(int32, type) \ + REGISTER_KERNELS(int64, type) +TF_CALL_FLOAT_TYPES(REGISTER_KERNELS_ALL_INDEX) +#undef REGISTER_KERNELS_ALL_INDEX +#undef REGISTER_KERNELS + +template +Status EmbeddingVarCkptData::ExportToCkpt( + const string& tensor_name, + BundleWriter* writer, + int64 value_len, + ValueIterator* value_iter) { + size_t bytes_limit = 8 << 20; + std::unique_ptr dump_buffer(new char[bytes_limit]); + + EVVectorDataDumpIterator key_dump_iter(key_vec_); + Status s = SaveTensorWithFixedBuffer( + tensor_name + "-keys", writer, dump_buffer.get(), + bytes_limit, &key_dump_iter, + TensorShape({key_vec_.size()})); + if (!s.ok()) + return s; + + EV2dVectorDataDumpIterator value_dump_iter( + value_ptr_vec_, value_len, value_iter); + s = SaveTensorWithFixedBuffer( + tensor_name + "-values", writer, dump_buffer.get(), + bytes_limit, &value_dump_iter, + TensorShape({value_ptr_vec_.size(), value_len})); + if (!s.ok()) + return s; + + EVVectorDataDumpIterator version_dump_iter(version_vec_); + s = SaveTensorWithFixedBuffer( + tensor_name + "-versions", writer, dump_buffer.get(), + bytes_limit, &version_dump_iter, + TensorShape({version_vec_.size()})); + if (!s.ok()) + return s; + + EVVectorDataDumpIterator freq_dump_iter(freq_vec_); + s = SaveTensorWithFixedBuffer( + tensor_name + "-freqs", writer, dump_buffer.get(), + bytes_limit, &freq_dump_iter, + TensorShape({freq_vec_.size()})); + if (!s.ok()) + return s; + + EVVectorDataDumpIterator filtered_key_dump_iter(key_filter_vec_); + s = SaveTensorWithFixedBuffer( + tensor_name + "-keys_filtered", writer, dump_buffer.get(), + bytes_limit, &filtered_key_dump_iter, + TensorShape({key_filter_vec_.size()})); + if (!s.ok()) + return s; + + EVVectorDataDumpIterator + filtered_version_dump_iter(version_filter_vec_); + s = SaveTensorWithFixedBuffer( + tensor_name + "-versions_filtered", + writer, dump_buffer.get(), + bytes_limit, &filtered_version_dump_iter, + TensorShape({version_filter_vec_.size()})); + if (!s.ok()) + return s; + + EVVectorDataDumpIterator + filtered_freq_dump_iter(freq_filter_vec_); + s = SaveTensorWithFixedBuffer( + tensor_name + "-freqs_filtered", + writer, dump_buffer.get(), + bytes_limit, &filtered_freq_dump_iter, + TensorShape({freq_filter_vec_.size()})); + if (!s.ok()) + return s; + + EVVectorDataDumpIterator + part_offset_dump_iter(part_offset_); + s = SaveTensorWithFixedBuffer( + tensor_name + "-partition_offset", + writer, dump_buffer.get(), + bytes_limit, &part_offset_dump_iter, + TensorShape({part_offset_.size()})); + if (!s.ok()) + return s; + + EVVectorDataDumpIterator + part_filter_offset_dump_iter(part_filter_offset_); + s = SaveTensorWithFixedBuffer( + tensor_name + "-partition_filter_offset", + writer, dump_buffer.get(), + bytes_limit, &part_filter_offset_dump_iter, + TensorShape({part_filter_offset_.size()})); + if (!s.ok()) + return s; + + return Status::OK(); +} + +#define REGISTER_KERNELS(ktype, vtype) \ + template Status EmbeddingVarCkptData::ExportToCkpt( \ + const string&, BundleWriter*, int64, ValueIterator*); +#define REGISTER_KERNELS_ALL_INDEX(type) \ + REGISTER_KERNELS(int32, type) \ + REGISTER_KERNELS(int64, type) +TF_CALL_FLOAT_TYPES(REGISTER_KERNELS_ALL_INDEX) +#undef REGISTER_KERNELS_ALL_INDEX +#undef REGISTER_KERNELS +}// namespace embedding +}// namespace tensorflow \ No newline at end of file diff --git a/tensorflow/core/framework/embedding/embedding_var_ckpt_data.h b/tensorflow/core/framework/embedding/embedding_var_ckpt_data.h index aa1a08cbcfd..6d7b09e70b0 100644 --- a/tensorflow/core/framework/embedding/embedding_var_ckpt_data.h +++ b/tensorflow/core/framework/embedding/embedding_var_ckpt_data.h @@ -15,11 +15,11 @@ limitations under the License. #ifndef TENSORFLOW_CORE_FRAMEWORK_EMBEDDING_EMBEDDING_VAR_CKPT_DATA_ #define TENSORFLOW_CORE_FRAMEWORK_EMBEDDING_EMBEDDING_VAR_CKPT_DATA_ #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h" -#include "tensorflow/core/kernels/save_restore_tensor.h" #include "tensorflow/core/framework/embedding/embedding_config.h" #include "tensorflow/core/framework/embedding/embedding_var_dump_iterator.h" namespace tensorflow { +class BundleWriter; + namespace embedding { template @@ -30,195 +30,17 @@ class EmbeddingVarCkptData { V* default_value, int64 value_offset, bool is_save_freq, bool is_save_version, - bool save_unfiltered_features) { - if((int64)value_ptr == ValuePtrStatus::IS_DELETED) - return; - - V* primary_val = value_ptr->GetValue(0, 0); - bool is_not_admit = - primary_val == nullptr - && emb_config.filter_freq != 0; - - if (!is_not_admit) { - key_vec_.emplace_back(key); - - if (primary_val == nullptr) { - value_ptr_vec_.emplace_back(default_value); - } else if ( - (int64)primary_val == ValuePosition::NOT_IN_DRAM) { - value_ptr_vec_.emplace_back((V*)ValuePosition::NOT_IN_DRAM); - } else { - V* val = value_ptr->GetValue(emb_config.emb_index, - value_offset); - value_ptr_vec_.emplace_back(val); - } - - - if(is_save_version) { - int64 dump_version = value_ptr->GetStep(); - version_vec_.emplace_back(dump_version); - } - - if(is_save_freq) { - int64 dump_freq = value_ptr->GetFreq(); - freq_vec_.emplace_back(dump_freq); - } - } else { - if (!save_unfiltered_features) - return; - - key_filter_vec_.emplace_back(key); + bool save_unfiltered_features); - if(is_save_version) { - int64 dump_version = value_ptr->GetStep(); - version_filter_vec_.emplace_back(dump_version); - } - - int64 dump_freq = value_ptr->GetFreq(); - freq_filter_vec_.emplace_back(dump_freq); - } - } - - void Emplace(K key, V* value_ptr) { - key_vec_.emplace_back(key); - value_ptr_vec_.emplace_back(value_ptr); - } + void Emplace(K key, V* value_ptr); void SetWithPartition( - std::vector>& ev_ckpt_data_parts) { - part_offset_.resize(kSavedPartitionNum + 1); - part_filter_offset_.resize(kSavedPartitionNum + 1); - part_offset_[0] = 0; - part_filter_offset_[0] = 0; - for (int i = 0; i < kSavedPartitionNum; i++) { - part_offset_[i + 1] = - part_offset_[i] + ev_ckpt_data_parts[i].key_vec_.size(); - - part_filter_offset_[i + 1] = - part_filter_offset_[i] + - ev_ckpt_data_parts[i].key_filter_vec_.size(); - - for (int64 j = 0; j < ev_ckpt_data_parts[i].key_vec_.size(); j++) { - key_vec_.emplace_back(ev_ckpt_data_parts[i].key_vec_[j]); - } - - for (int64 j = 0; j < ev_ckpt_data_parts[i].value_ptr_vec_.size(); j++) { - value_ptr_vec_.emplace_back(ev_ckpt_data_parts[i].value_ptr_vec_[j]); - } - - for (int64 j = 0; j < ev_ckpt_data_parts[i].version_vec_.size(); j++) { - version_vec_.emplace_back(ev_ckpt_data_parts[i].version_vec_[j]); - } - - for (int64 j = 0; j < ev_ckpt_data_parts[i].freq_vec_.size(); j++) { - freq_vec_.emplace_back(ev_ckpt_data_parts[i].freq_vec_[j]); - } - - for (int64 j = 0; j < ev_ckpt_data_parts[i].key_filter_vec_.size(); j++) { - key_filter_vec_.emplace_back(ev_ckpt_data_parts[i].key_filter_vec_[j]); - } - - for (int64 j = 0; j < ev_ckpt_data_parts[i].version_filter_vec_.size(); j++) { - version_filter_vec_.emplace_back(ev_ckpt_data_parts[i].version_filter_vec_[j]); - } - - for (int64 j = 0; j < ev_ckpt_data_parts[i].freq_filter_vec_.size(); j++) { - freq_filter_vec_.emplace_back(ev_ckpt_data_parts[i].freq_filter_vec_[j]); - } - } - } + std::vector>& ev_ckpt_data_parts); Status ExportToCkpt(const string& tensor_name, BundleWriter* writer, int64 value_len, - ValueIterator* value_iter = nullptr) { - size_t bytes_limit = 8 << 20; - std::unique_ptr dump_buffer(new char[bytes_limit]); - - EVVectorDataDumpIterator key_dump_iter(key_vec_); - Status s = SaveTensorWithFixedBuffer( - tensor_name + "-keys", writer, dump_buffer.get(), - bytes_limit, &key_dump_iter, - TensorShape({key_vec_.size()})); - if (!s.ok()) - return s; - - EV2dVectorDataDumpIterator value_dump_iter( - value_ptr_vec_, value_len, value_iter); - s = SaveTensorWithFixedBuffer( - tensor_name + "-values", writer, dump_buffer.get(), - bytes_limit, &value_dump_iter, - TensorShape({value_ptr_vec_.size(), value_len})); - if (!s.ok()) - return s; - - EVVectorDataDumpIterator version_dump_iter(version_vec_); - s = SaveTensorWithFixedBuffer( - tensor_name + "-versions", writer, dump_buffer.get(), - bytes_limit, &version_dump_iter, - TensorShape({version_vec_.size()})); - if (!s.ok()) - return s; - - EVVectorDataDumpIterator freq_dump_iter(freq_vec_); - s = SaveTensorWithFixedBuffer( - tensor_name + "-freqs", writer, dump_buffer.get(), - bytes_limit, &freq_dump_iter, - TensorShape({freq_vec_.size()})); - if (!s.ok()) - return s; - - EVVectorDataDumpIterator filtered_key_dump_iter(key_filter_vec_); - s = SaveTensorWithFixedBuffer( - tensor_name + "-keys_filtered", writer, dump_buffer.get(), - bytes_limit, &filtered_key_dump_iter, - TensorShape({key_filter_vec_.size()})); - if (!s.ok()) - return s; - - EVVectorDataDumpIterator - filtered_version_dump_iter(version_filter_vec_); - s = SaveTensorWithFixedBuffer( - tensor_name + "-versions_filtered", - writer, dump_buffer.get(), - bytes_limit, &filtered_version_dump_iter, - TensorShape({version_filter_vec_.size()})); - if (!s.ok()) - return s; - - EVVectorDataDumpIterator - filtered_freq_dump_iter(freq_filter_vec_); - s = SaveTensorWithFixedBuffer( - tensor_name + "-freqs_filtered", - writer, dump_buffer.get(), - bytes_limit, &filtered_freq_dump_iter, - TensorShape({freq_filter_vec_.size()})); - if (!s.ok()) - return s; - - EVVectorDataDumpIterator - part_offset_dump_iter(part_offset_); - s = SaveTensorWithFixedBuffer( - tensor_name + "-partition_offset", - writer, dump_buffer.get(), - bytes_limit, &part_offset_dump_iter, - TensorShape({part_offset_.size()})); - if (!s.ok()) - return s; - - EVVectorDataDumpIterator - part_filter_offset_dump_iter(part_filter_offset_); - s = SaveTensorWithFixedBuffer( - tensor_name + "-partition_filter_offset", - writer, dump_buffer.get(), - bytes_limit, &part_filter_offset_dump_iter, - TensorShape({part_filter_offset_.size()})); - if (!s.ok()) - return s; - - return Status::OK(); - } - + ValueIterator* value_iter = nullptr); private: std::vector key_vec_; std::vector value_ptr_vec_; diff --git a/tensorflow/core/framework/embedding/embedding_var_dump_iterator.h b/tensorflow/core/framework/embedding/embedding_var_dump_iterator.h index 71ba054b873..84c823a90dc 100644 --- a/tensorflow/core/framework/embedding/embedding_var_dump_iterator.h +++ b/tensorflow/core/framework/embedding/embedding_var_dump_iterator.h @@ -15,9 +15,12 @@ limitations under the License. #ifndef TENSORFLOW_CORE_FRAMEWORK_EMBEDDING_EMBEDDING_VAR_DUMP_ITERATOR_ #define TENSORFLOW_CORE_FRAMEWORK_EMBEDDING_EMBEDDING_VAR_DUMP_ITERATOR_ #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h" -#include "tensorflow/core/kernels/save_restore_tensor.h" +#include "tensorflow/core/framework/embedding/embedding_config.h" +#include "tensorflow/core/framework/embedding/kv_interface.h" namespace tensorflow { +template +class DumpIterator; + namespace embedding { template class EVVectorDataDumpIterator: public DumpIterator { diff --git a/tensorflow/core/framework/embedding/embedding_var_restore.cc b/tensorflow/core/framework/embedding/embedding_var_restore.cc new file mode 100644 index 00000000000..11c13008995 --- /dev/null +++ b/tensorflow/core/framework/embedding/embedding_var_restore.cc @@ -0,0 +1,647 @@ +/* Copyright 2023 The DeepRec Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================*/ +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/bounds_check.h" +#include "tensorflow/core/framework/embedding/embedding_var_restore.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/kernels/save_restore_tensor.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/random/philox_random.h" +#include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/lib/random/random_distributions.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h" + +namespace tensorflow { +template +int64 ReadRecord(BundleReader* reader, const string& record_key, K** buffer) { + TensorShape shape; + Status st; + st = reader->LookupTensorShape(record_key, &shape); + if (!st.ok()) { + LOG(FATAL) << "Restore record " << record_key << " failed"; + } + st = reader->LookupHeader(record_key, sizeof(K) * shape.dim_size(0)); + if (!st.ok()) { + LOG(FATAL) << "Restore record " << record_key << " failed"; + } + size_t bytes_read = 0; + *buffer = new K[shape.dim_size(0)]; + st = reader->LookupSegment(record_key, sizeof(K) * shape.dim_size(0), + (char*)*buffer, bytes_read); + if (!st.ok()) { + LOG(FATAL) << "Restore record " << record_key << " failed"; + } + return shape.dim_size(0); +} +#define REGISTER_KERNELS(ktype) \ + template int64 ReadRecord(BundleReader*, const string&, ktype**); +REGISTER_KERNELS(int32); +REGISTER_KERNELS(int64); +#undef REGISTER_KERNELS + +template +void CheckpointLoader::RestoreSSD() { + std::string name_string_temp(restore_args_.m_name_string); + std::string new_str = "_"; + int64 pos = name_string_temp.find("/"); + while (pos != std::string::npos) { + name_string_temp.replace(pos, 1, new_str.data(), 1); + pos = name_string_temp.find("/"); + } + std::string ssd_record_file_name = restore_args_.m_file_name_string + "-" + + name_string_temp + "-ssd_record"; + if (Env::Default()->FileExists(ssd_record_file_name + ".index").ok()) { + std::string ssd_emb_file_name = restore_args_.m_file_name_string + "-" + + name_string_temp + "-emb_files"; + BundleReader ssd_record_reader(Env::Default(), ssd_record_file_name); + RestoreSSDBuffer ssd_buffer(&ssd_record_reader); + VLOG(1) << "Loading SSD record... " << ssd_record_file_name; + storage_->RestoreSSD(ev_->GetEmbeddingIndex(), + ev_->GetEmbeddingSlotNum(), ev_->ValueLen(), + ssd_emb_file_name, ev_, ssd_buffer); + } +} +#define REGISTER_KERNELS(ktype, vtype) \ + template void CheckpointLoader::RestoreSSD(); +#define REGISTER_KERNELS_ALL_INDEX(type) \ + REGISTER_KERNELS(int32, type) \ + REGISTER_KERNELS(int64, type) +TF_CALL_FLOAT_TYPES(REGISTER_KERNELS_ALL_INDEX) +#undef REGISTER_KERNELS_ALL_INDEX +#undef REGISTER_KERNELS + +template +void CheckpointLoader::RestoreInternal( + const std::string& name_string, + const EmbeddingConfig& emb_config, + const Eigen::GpuDevice* device, + RestoreBuffer& restore_buff) { + Status s = EVInitTensorNameAndShape(name_string); + if (!s.ok()) { + LOG(ERROR) << "EVInitTensorNameAndShape fail:" << s.ToString(); + return; + } + + Tensor part_offset_tensor; + Tensor part_filter_offset_tensor; + if (!restore_args_.m_is_oldform) { + /****** InitPartOffsetTensor ******/ + TensorShape part_offset_shape, part_filter_offset_shape; + DataType part_offset_type, part_filter_offset_type; + string offset_tensor_name; + if (!restore_args_.m_is_incr) { + offset_tensor_name = name_string + kPartOffsetTensorSuffsix; + } else { + offset_tensor_name = name_string + kIncrPartOffsetTensorSuffsix; + } + + string offset_filter_tensor_name = + name_string + kPartFilterOffsetTensorSuffsix; + Status s = reader_->LookupDtypeAndShape( + offset_tensor_name, &part_offset_type, &part_offset_shape); + if (!s.ok()) { + LOG(ERROR) << "EV restoring fail:" << s.error_message(); + } + s = reader_->LookupDtypeAndShape(offset_filter_tensor_name, + &part_filter_offset_type, + &part_filter_offset_shape); + if (!s.ok()) { + LOG(ERROR) << "EV restoring fail: " << s.error_message(); + } + part_offset_tensor = + Tensor(cpu_allocator(), part_offset_type, part_offset_shape); + part_filter_offset_tensor = Tensor( + cpu_allocator(), part_filter_offset_type, part_filter_offset_shape); + s = reader_->Lookup(offset_tensor_name, &part_offset_tensor); + if (!s.ok()) { + LOG(ERROR) << "EV restoring fail:" << s.error_message(); + } + + s = reader_->Lookup(offset_filter_tensor_name, + &part_filter_offset_tensor); + if (!s.ok()) { + LOG(ERROR) << "EV restoring fail: " << s.error_message(); + } + } + auto part_offset_flat = part_offset_tensor.flat(); + auto part_filter_offset_flat = part_filter_offset_tensor.flat(); + + if (restore_args_.m_is_oldform) { + VLOG(1) << "old form, EV name:" << name_string + << ", partition_id:" << restore_args_.m_partition_id + << ", new partition num:" << restore_args_.m_partition_num; + int64 new_dim = ev_->ValueLen(); + TensorShape key_shape; + Status st = + reader_->LookupTensorShape(restore_args_.m_tensor_key, &key_shape); + if (!st.ok()) { + LOG(ERROR) << "EVRestoreFeaturesOld fail: " << st.error_message(); + } + int tot_key_num = key_shape.dim_size(0); + Status s = EVRestoreFeatures(tot_key_num, 0, 0, 0, 0, restore_buff, + new_dim, emb_config, device); + if (!s.ok()) { + LOG(ERROR) << "EVRestoreFeaturesOld fail: " << s.error_message(); + } + } else { + int64 new_dim = ev_->ValueLen(); + VLOG(1) << "new form checkpoint... :" << name_string + << " , partition_id:" << restore_args_.m_partition_id + << " , partition_num:" << restore_args_.m_partition_num; + for (size_t i = 0; i < restore_args_.m_loaded_parts.size(); i++) { + int subpart_id = restore_args_.m_loaded_parts[i]; + size_t value_unit_bytes = sizeof(V) * restore_args_.m_old_dim; + size_t value_unit_bytes_new = sizeof(V) * new_dim; + int subpart_offset = part_offset_flat(subpart_id); + int tot_key_num = part_offset_flat(subpart_id + 1) - subpart_offset; + int64 key_part_offset = subpart_offset * sizeof(K); + int64 value_part_offset = + subpart_offset * sizeof(V) * restore_args_.m_old_dim; + int64 version_part_offset = subpart_offset * sizeof(int64); + int64 freq_part_offset = subpart_offset * sizeof(int64); + VLOG(1) << "dynamically load ev : " << name_string + << ", subpartid:" << subpart_id; + + EVRestoreFeatures(tot_key_num, key_part_offset, value_part_offset, + version_part_offset, freq_part_offset, restore_buff, + new_dim, emb_config, device); + + if (restore_args_.m_has_filter) { + Status s = EVRestoreFilteredFeatures( + subpart_id, new_dim, restore_buff, part_filter_offset_flat, + emb_config, device); + if (!s.ok()) { + LOG(ERROR) << "EVRestoreFilteredFeatures fail: " << s.error_message(); + } + } + } + } +} +#define REGISTER_KERNELS(ktype, vtype) \ + template void CheckpointLoader::RestoreInternal( \ + const std::string&, const EmbeddingConfig&, \ + const Eigen::GpuDevice*, RestoreBuffer&); +#define REGISTER_KERNELS_ALL_INDEX(type) \ + REGISTER_KERNELS(int32, type) \ + REGISTER_KERNELS(int64, type) +TF_CALL_FLOAT_TYPES(REGISTER_KERNELS_ALL_INDEX) +#undef REGISTER_KERNELS_ALL_INDEX +#undef REGISTER_KERNELS + +template +bool CheckpointLoader::IsOldCheckpoint( + const std::string& curr_partid_str, + const std::string& kPartOffsetTensorSuffsix) { + if (restore_args_.m_name_string.find(kPartStr) == std::string::npos) { + string tensor_name = restore_args_.m_name_string; + TensorShape part_offset_shape; + DataType part_offset_type; + Status st = + reader_->LookupDtypeAndShape(tensor_name + kPartOffsetTensorSuffsix, + &part_offset_type, &part_offset_shape); + if (st.ok()) return false; + + string part_id = std::to_string(0); + tensor_name = restore_args_.m_name_string + "/" + kPartStr + part_id; + + Status form_st = + reader_->LookupDtypeAndShape(tensor_name + kPartOffsetTensorSuffsix, + &part_offset_type, &part_offset_shape); + if (form_st.ok()) return false; + } else { + string part_id = std::to_string(0); + size_t part_pos = restore_args_.m_name_string.find(kPartStr); + size_t part_size = strlen(kPartStr); + size_t cur_part_size = curr_partid_str.size(); + + string pre_subname = restore_args_.m_name_string.substr(0, part_pos); + string post_subname = restore_args_.m_name_string.substr( + part_pos + part_size + cur_part_size); + string tensor_name = pre_subname + kPartStr + part_id + post_subname; + + TensorShape part_offset_shape; + DataType part_offset_type; + Status form_st = + reader_->LookupDtypeAndShape(tensor_name + kPartOffsetTensorSuffsix, + &part_offset_type, &part_offset_shape); + if (form_st.ok()) return false; + pre_subname = + restore_args_.m_name_string.substr(0, part_pos - 1); /* var1*/ + post_subname = restore_args_.m_name_string.substr(part_pos + part_size + + cur_part_size); + tensor_name = pre_subname + post_subname; + + Status st = + reader_->LookupDtypeAndShape(tensor_name + kPartOffsetTensorSuffsix, + &part_offset_type, &part_offset_shape); + if (st.ok()) return false; + } + + return true; +} +#define REGISTER_KERNELS(ktype, vtype) \ + template bool CheckpointLoader::IsOldCheckpoint( \ + const std::string&, const std::string&); +#define REGISTER_KERNELS_ALL_INDEX(type) \ + REGISTER_KERNELS(int32, type) \ + REGISTER_KERNELS(int64, type) +TF_CALL_FLOAT_TYPES(REGISTER_KERNELS_ALL_INDEX) +#undef REGISTER_KERNELS_ALL_INDEX +#undef REGISTER_KERNELS + + +template +void CheckpointLoader::InitPartNumAndLoadedParts( + std::vector& tensor_name_vec) { + std::string tmp_key_suffix; + std::string tmp_kPartOffsetTensorSuffsix; + if (!restore_args_.m_is_incr) { + tmp_key_suffix = kKeySuffix; + tmp_kPartOffsetTensorSuffsix = kPartOffsetTensorSuffsix; + } else { + tmp_key_suffix = kIncrKeySuffix; + tmp_kPartOffsetTensorSuffsix = kIncrPartOffsetTensorSuffsix; + } + + restore_args_.m_loaded_parts.reserve(kSavedPartitionNum); + int orig_partnum = 0; + const string& curr_partid_str = std::to_string(restore_args_.m_partition_id); + size_t part_pos = restore_args_.m_name_string.find(kPartStr); + + if (IsOldCheckpoint(curr_partid_str, tmp_kPartOffsetTensorSuffsix)) { + restore_args_.m_is_oldform = true; + } + + if (part_pos == std::string::npos) { + for (;; orig_partnum++) { + string part_id = std::to_string(orig_partnum); + string tensor_name = + restore_args_.m_name_string + "/" + kPartStr + part_id; + string tensor_key = tensor_name + tmp_key_suffix; + TensorShape key_shape; + Status st = reader_->LookupTensorShape(tensor_key, &key_shape); + if (!st.ok()) { + break; + } + tensor_name_vec.emplace_back(tensor_name); + } + if (orig_partnum == 0) { + tensor_name_vec.emplace_back(restore_args_.m_name_string); + } + for (int i = 0; i < kSavedPartitionNum; ++i) { + restore_args_.m_loaded_parts.push_back(i); + } + } else { + for (;; orig_partnum++) { + string part_id = std::to_string(orig_partnum); + string pre_subname = restore_args_.m_name_string.substr(0, part_pos); + string post_subname = restore_args_.m_name_string.substr( + part_pos + strlen(kPartStr) + curr_partid_str.size()); + string tensor_name = pre_subname + kPartStr + part_id + post_subname; + string tensor_key = tensor_name + tmp_key_suffix; + TensorShape key_shape; + Status st = reader_->LookupTensorShape(tensor_key, &key_shape); + if (!st.ok()) { + break; + } + tensor_name_vec.emplace_back(tensor_name); + } + if (orig_partnum == 0) { + string pre_subname = restore_args_.m_name_string.substr(0, part_pos - 1); + string post_subname = restore_args_.m_name_string.substr( + part_pos + strlen(kPartStr) + curr_partid_str.size()); + string tmp_name = pre_subname + post_subname; + tensor_name_vec.emplace_back(tmp_name); + } + for (int i = 0; i < kSavedPartitionNum; i++) { + if (i % restore_args_.m_partition_num == restore_args_.m_partition_id) { + restore_args_.m_loaded_parts.push_back(i); + } + } + } + for (auto& tensor_name : tensor_name_vec) { + VLOG(1) << "**** " << restore_args_.m_name_string << " " << tensor_name + << " ****"; + } +} +#define REGISTER_KERNELS(ktype, vtype) \ + template void CheckpointLoader::InitPartNumAndLoadedParts(\ + std::vector&); +#define REGISTER_KERNELS_ALL_INDEX(type) \ + REGISTER_KERNELS(int32, type) \ + REGISTER_KERNELS(int64, type) +TF_CALL_FLOAT_TYPES(REGISTER_KERNELS_ALL_INDEX) +#undef REGISTER_KERNELS_ALL_INDEX +#undef REGISTER_KERNELS + +template +Status CheckpointLoader::EVInitTensorNameAndShape( + const std::string& tensor_name) { + if (!restore_args_.m_is_incr) { + restore_args_.m_tensor_key = tensor_name + kKeySuffix; + restore_args_.m_tensor_value = tensor_name + kValueSuffix; + restore_args_.m_tensor_version = tensor_name + kVersionSuffix; + restore_args_.m_tensor_freq = tensor_name + kFreqSuffix; + } else { + restore_args_.m_tensor_key = tensor_name + kIncrKeySuffix; + restore_args_.m_tensor_value = tensor_name + kIncrValueSuffix; + restore_args_.m_tensor_version = tensor_name + kIncrVersionSuffix; + restore_args_.m_tensor_freq = tensor_name + kIncrFreqSuffix; + } + + TensorShape key_shape, value_shape, version_shape, freq_shape; + + Status st = + reader_->LookupTensorShape(restore_args_.m_tensor_key, &key_shape); + if (!st.ok()) { + return st; + } + st = reader_->LookupTensorShape(restore_args_.m_tensor_value, &value_shape); + if (!st.ok()) { + return st; + } + st = reader_->LookupTensorShape(restore_args_.m_tensor_version, + &version_shape); + if (!st.ok()) { + return st; + } + st = reader_->LookupHeader(restore_args_.m_tensor_key, + sizeof(K) * key_shape.dim_size(0)); + if (!st.ok()) { + return st; + } + st = reader_->LookupHeader(restore_args_.m_tensor_value, + sizeof(V) * value_shape.dim_size(0) * + value_shape.dim_size(1)); + if (!st.ok()) { + return st; + } + st = reader_->LookupHeader(restore_args_.m_tensor_version, + sizeof(int64) * version_shape.dim_size(0)); + if (!st.ok()) { + return st; + } + st = reader_->LookupTensorShape(restore_args_.m_tensor_freq, &freq_shape); + if (!st.ok()) { + if (st.code() == error::NOT_FOUND) { + freq_shape = version_shape; + } else { + return st; + } + } + st = reader_->LookupHeader(restore_args_.m_tensor_freq, + sizeof(int64) * freq_shape.dim_size(0)); + if (!st.ok()) { + if (st.code() == error::NOT_FOUND) { + restore_args_.m_has_freq = false; + } else { + return st; + } + } + restore_args_.m_old_dim = value_shape.dim_size(1); + + if (!restore_args_.m_is_oldform) { + TensorShape key_filter_shape, version_filter_shape, freq_filter_shape; + st = reader_->LookupTensorShape(restore_args_.m_tensor_key + "_filtered", + &key_filter_shape); + if (!st.ok()) { + if (st.code() == error::NOT_FOUND) { + key_filter_shape = key_shape; + restore_args_.m_has_filter = false; + } else { + return st; + } + } + st = reader_->LookupTensorShape( + restore_args_.m_tensor_version + "_filtered", &version_filter_shape); + if ((!st.ok()) && (st.code() != error::NOT_FOUND)) { + return st; + } + st = reader_->LookupHeader(restore_args_.m_tensor_key + "_filtered", + sizeof(K) * key_filter_shape.dim_size(0)); + if (!st.ok()) { + if (st.code() == error::NOT_FOUND) { + restore_args_.m_has_filter = false; + } else { + return st; + } + } + st = reader_->LookupHeader(restore_args_.m_tensor_version + "_filtered", + sizeof(K) * version_filter_shape.dim_size(0)); + if (!st.ok()) { + return st; + } + st = reader_->LookupTensorShape(restore_args_.m_tensor_freq + "_filtered", + &freq_filter_shape); + if (!st.ok()) { + if (st.code() == error::NOT_FOUND) { + freq_filter_shape = freq_shape; + } else { + return st; + } + } + + st = reader_->LookupHeader(restore_args_.m_tensor_freq + "_filtered", + sizeof(K) * freq_filter_shape.dim_size(0)); + if (!st.ok() && st.code() != error::NOT_FOUND) { + return st; + } + } + return st; +} +#define REGISTER_KERNELS(ktype, vtype) \ + template Status CheckpointLoader::EVInitTensorNameAndShape(\ + const std::string&); +#define REGISTER_KERNELS_ALL_INDEX(type) \ + REGISTER_KERNELS(int32, type) \ + REGISTER_KERNELS(int64, type) +TF_CALL_FLOAT_TYPES(REGISTER_KERNELS_ALL_INDEX) +#undef REGISTER_KERNELS_ALL_INDEX +#undef REGISTER_KERNELS + +template +Status CheckpointLoader::EVRestoreFeatures( + int tot_key_num, int64 key_part_offset, + int64 value_part_offset, int64 version_part_offset, + int64 freq_part_offset, RestoreBuffer& restore_buff, + int64 new_dim, const EmbeddingConfig& emb_config, + const Eigen::GpuDevice* device) { + size_t value_unit_bytes = sizeof(V) * restore_args_.m_old_dim; + size_t value_unit_bytes_new = sizeof(V) * new_dim; + int64 tot_key_bytes_read(0); + int64 tot_value_bytes_read(0); + int64 tot_version_bytes_read(0); + int64 tot_freq_bytes_read(0); + size_t key_bytes_read = 0; + size_t value_bytes_read = 0; + size_t version_bytes_read = 0; + size_t freq_bytes_read = 0; + + while (tot_key_num > 0) { + size_t read_key_num = std::min( + std::min(kBufferSize / sizeof(K), kBufferSize / value_unit_bytes), + kBufferSize / sizeof(int64)); + read_key_num = std::min(read_key_num, kBufferSize / value_unit_bytes_new); + read_key_num = std::min((int)read_key_num, tot_key_num); + reader_->LookupSegmentOffset( + restore_args_.m_tensor_key, key_part_offset + tot_key_bytes_read, + read_key_num * sizeof(K), restore_buff.key_buffer, key_bytes_read); + reader_->LookupSegmentOffset( + restore_args_.m_tensor_value, value_part_offset + tot_value_bytes_read, + read_key_num * value_unit_bytes, restore_buff.value_buffer, + value_bytes_read); + if (!restore_args_.m_reset_version) { + reader_->LookupSegmentOffset( + restore_args_.m_tensor_version, + version_part_offset + tot_version_bytes_read, + read_key_num * sizeof(int64), restore_buff.version_buffer, + version_bytes_read); + if (version_bytes_read == 0) { + memset(restore_buff.version_buffer, -1, sizeof(int64) * read_key_num); + } + } else { + int64* version_tmp = (int64*)restore_buff.version_buffer; + memset(version_tmp, 0, read_key_num * sizeof(int64)); + } + + if (restore_args_.m_has_freq) { + reader_->LookupSegmentOffset( + restore_args_.m_tensor_freq, freq_part_offset + tot_freq_bytes_read, + read_key_num * sizeof(int64), restore_buff.freq_buffer, + freq_bytes_read); + if (freq_bytes_read == 0) { + int64* freq_tmp = (int64*)restore_buff.freq_buffer; + for (int64 i = 0; i < read_key_num; i++) { + freq_tmp[i] = (ev_->MinFreq() == 0) ? 1 : ev_->MinFreq(); + } + } + } else { + int64* freq_tmp = (int64*)restore_buff.freq_buffer; + for (int64 i = 0; i < read_key_num; i++) { + freq_tmp[i] = (ev_->MinFreq() == 0) ? 1 : ev_->MinFreq(); + } + } + if (key_bytes_read > 0) { + read_key_num = key_bytes_read / sizeof(K); + Status st = RestoreCustomDim(new_dim, read_key_num, value_unit_bytes, + value_bytes_read, value_unit_bytes_new, + restore_buff); + if (!st.ok()) { + LOG(FATAL) << "EV Restore fail:" << st.ToString(); + } + + st = storage_->RestoreFeatures( + read_key_num, kSavedPartitionNum, restore_args_.m_partition_id, + restore_args_.m_partition_num, new_dim, false, restore_args_.m_is_incr, + emb_config, device, + filter_, restore_buff); + if (!st.ok()) { + LOG(FATAL) << "EV Restore fail:" << st.ToString(); + } + } + + tot_key_num -= read_key_num; + tot_key_bytes_read += key_bytes_read; + tot_value_bytes_read += value_bytes_read; + tot_version_bytes_read += version_bytes_read; + tot_freq_bytes_read += freq_bytes_read; + } + + return Status::OK(); +} +#define REGISTER_KERNELS(ktype, vtype) \ + template Status CheckpointLoader::EVRestoreFeatures( \ + int, int64, int64, int64, int64, RestoreBuffer&, \ + int64, const EmbeddingConfig&, const Eigen::GpuDevice*); +#define REGISTER_KERNELS_ALL_INDEX(type) \ + REGISTER_KERNELS(int32, type) \ + REGISTER_KERNELS(int64, type) +TF_CALL_FLOAT_TYPES(REGISTER_KERNELS_ALL_INDEX) +#undef REGISTER_KERNELS_ALL_INDEX +#undef REGISTER_KERNELS + +template +Status CheckpointLoader::EVRestoreFilteredFeatures( + int64 subpart_id, int64 value_len, RestoreBuffer& restore_buff, + typename TTypes::Flat part_filter_offset_flat, + const EmbeddingConfig& emb_config, const Eigen::GpuDevice* device) { + int subpart_filter_offset = part_filter_offset_flat(subpart_id); + int tot_key_filter_num = + part_filter_offset_flat(subpart_id + 1) - subpart_filter_offset; + int64 key_filter_part_offset = subpart_filter_offset * sizeof(K); + int64 version_filter_part_offset = subpart_filter_offset * sizeof(int64); + int64 freq_filter_part_offset = subpart_filter_offset * sizeof(int64); + + VLOG(1) << "key_filter_num: " << tot_key_filter_num + << ", subpart_filter_offset: " << subpart_filter_offset; + + size_t key_filter_bytes_read = 0; + size_t version_filter_bytes_read = 0; + size_t freq_filter_bytes_read = 0; + + while (tot_key_filter_num > 0) { + size_t read_key_num = + std::min(kBufferSize / sizeof(K), kBufferSize / sizeof(int64)); + read_key_num = std::min((int)read_key_num, tot_key_filter_num); + reader_->LookupSegmentOffset( + restore_args_.m_tensor_key + "_filtered", + key_filter_part_offset + key_filter_bytes_read, + read_key_num * sizeof(K), restore_buff.key_buffer, + key_filter_bytes_read); + if (!restore_args_.m_reset_version) { + reader_->LookupSegmentOffset( + restore_args_.m_tensor_version + "_filtered", + version_filter_part_offset + version_filter_bytes_read, + read_key_num * sizeof(int64), restore_buff.version_buffer, + version_filter_bytes_read); + } else { + int64* version_tmp = (int64*)restore_buff.version_buffer; + memset(version_tmp, 0, read_key_num * sizeof(int64)); + } + reader_->LookupSegmentOffset( + restore_args_.m_tensor_freq + "_filtered", + freq_filter_part_offset + freq_filter_bytes_read, + read_key_num * sizeof(int64), restore_buff.freq_buffer, + freq_filter_bytes_read); + if (key_filter_bytes_read > 0) { + read_key_num = key_filter_bytes_read / sizeof(K); + VLOG(2) << "restore, read_key_num:" << read_key_num; + Status st = storage_->RestoreFeatures( + read_key_num, kSavedPartitionNum, restore_args_.m_partition_id, + restore_args_.m_partition_num, value_len, true, restore_args_.m_is_incr, + emb_config, device, + filter_, restore_buff); + if (!st.ok()) return st; + tot_key_filter_num -= read_key_num; + } + } + return Status::OK(); +} +#define REGISTER_KERNELS(ktype, vtype) \ + template Status CheckpointLoader::EVRestoreFilteredFeatures( \ + int64, int64, RestoreBuffer&, typename TTypes::Flat, \ + const EmbeddingConfig&, const Eigen::GpuDevice*); +#define REGISTER_KERNELS_ALL_INDEX(type) \ + REGISTER_KERNELS(int32, type) \ + REGISTER_KERNELS(int64, type) +TF_CALL_FLOAT_TYPES(REGISTER_KERNELS_ALL_INDEX) +#undef REGISTER_KERNELS_ALL_INDEX +#undef REGISTER_KERNELS + +}// namespace tensorflow \ No newline at end of file diff --git a/tensorflow/core/framework/embedding/embedding_var_restore.h b/tensorflow/core/framework/embedding/embedding_var_restore.h index ec97566fbec..3016ba9eeb8 100644 --- a/tensorflow/core/framework/embedding/embedding_var_restore.h +++ b/tensorflow/core/framework/embedding/embedding_var_restore.h @@ -16,23 +16,11 @@ limitations under the License. #ifndef TENSORFLOW_CORE_FRAMEWORK_EMBEDDING_EMBEDDING_VAR_RESTORE_H_ #define TENSORFLOW_CORE_FRAMEWORK_EMBEDDING_EMBEDDING_VAR_RESTORE_H_ -#include "tensorflow/core/framework/allocator.h" -#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/embedding/embedding_var.h" #include "tensorflow/core/framework/embedding/embedding_config.h" #include "tensorflow/core/framework/embedding/filter_policy.h" #include "tensorflow/core/framework/embedding/storage.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/register_types.h" -#include "tensorflow/core/kernels/save_restore_tensor.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/random/philox_random.h" -#include "tensorflow/core/lib/random/random.h" -#include "tensorflow/core/lib/random/random_distributions.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h" +#include "tensorflow/core/util/env_var.h" namespace tensorflow { using GPUDevice = Eigen::GpuDevice; @@ -60,26 +48,7 @@ namespace { } // namespace template -int64 ReadRecord(BundleReader* reader, const string& record_key, K** buffer) { - TensorShape shape; - Status st; - st = reader->LookupTensorShape(record_key, &shape); - if (!st.ok()) { - LOG(FATAL) << "Restore record " << record_key << " failed"; - } - st = reader->LookupHeader(record_key, sizeof(K) * shape.dim_size(0)); - if (!st.ok()) { - LOG(FATAL) << "Restore record " << record_key << " failed"; - } - size_t bytes_read = 0; - *buffer = new K[shape.dim_size(0)]; - st = reader->LookupSegment(record_key, sizeof(K) * shape.dim_size(0), - (char*)*buffer, bytes_read); - if (!st.ok()) { - LOG(FATAL) << "Restore record " << record_key << " failed"; - } - return shape.dim_size(0); -} +int64 ReadRecord(BundleReader* reader, const string& record_key, K** buffer); template struct RestoreSSDBuffer { @@ -178,513 +147,28 @@ class CheckpointLoader { void RestoreInternal(const std::string& name_string, const EmbeddingConfig& emb_config, const Eigen::GpuDevice* device, - RestoreBuffer& restore_buff) { - Status s = EVInitTensorNameAndShape(name_string); - if (!s.ok()) { - LOG(ERROR) << "EVInitTensorNameAndShape fail:" << s.ToString(); - return; - } - - Tensor part_offset_tensor; - Tensor part_filter_offset_tensor; - if (!restore_args_.m_is_oldform) { - /****** InitPartOffsetTensor ******/ - TensorShape part_offset_shape, part_filter_offset_shape; - DataType part_offset_type, part_filter_offset_type; - string offset_tensor_name; - if (!restore_args_.m_is_incr) { - offset_tensor_name = name_string + kPartOffsetTensorSuffsix; - } else { - offset_tensor_name = name_string + kIncrPartOffsetTensorSuffsix; - } - - string offset_filter_tensor_name = - name_string + kPartFilterOffsetTensorSuffsix; - Status s = reader_->LookupDtypeAndShape( - offset_tensor_name, &part_offset_type, &part_offset_shape); - if (!s.ok()) { - LOG(ERROR) << "EV restoring fail:" << s.error_message(); - } - s = reader_->LookupDtypeAndShape(offset_filter_tensor_name, - &part_filter_offset_type, - &part_filter_offset_shape); - if (!s.ok()) { - LOG(ERROR) << "EV restoring fail: " << s.error_message(); - } - part_offset_tensor = - Tensor(cpu_allocator(), part_offset_type, part_offset_shape); - part_filter_offset_tensor = Tensor( - cpu_allocator(), part_filter_offset_type, part_filter_offset_shape); - s = reader_->Lookup(offset_tensor_name, &part_offset_tensor); - if (!s.ok()) { - LOG(ERROR) << "EV restoring fail:" << s.error_message(); - } - - s = reader_->Lookup(offset_filter_tensor_name, - &part_filter_offset_tensor); - if (!s.ok()) { - LOG(ERROR) << "EV restoring fail: " << s.error_message(); - } - } - auto part_offset_flat = part_offset_tensor.flat(); - auto part_filter_offset_flat = part_filter_offset_tensor.flat(); - - if (restore_args_.m_is_oldform) { - VLOG(1) << "old form, EV name:" << name_string - << ", partition_id:" << restore_args_.m_partition_id - << ", new partition num:" << restore_args_.m_partition_num; - int64 new_dim = ev_->ValueLen(); - TensorShape key_shape; - Status st = - reader_->LookupTensorShape(restore_args_.m_tensor_key, &key_shape); - if (!st.ok()) { - } - int tot_key_num = key_shape.dim_size(0); - Status s = EVRestoreFeatures(tot_key_num, 0, 0, 0, 0, restore_buff, - new_dim, emb_config, device); - if (!s.ok()) { - LOG(ERROR) << "EVRestoreFeaturesOld fail: " << s.error_message(); - } - } else { - int64 new_dim = ev_->ValueLen(); - VLOG(1) << "new form checkpoint... :" << name_string - << " , partition_id:" << restore_args_.m_partition_id - << " , partition_num:" << restore_args_.m_partition_num; - for (size_t i = 0; i < restore_args_.m_loaded_parts.size(); i++) { - int subpart_id = restore_args_.m_loaded_parts[i]; - size_t value_unit_bytes = sizeof(V) * restore_args_.m_old_dim; - size_t value_unit_bytes_new = sizeof(V) * new_dim; - int subpart_offset = part_offset_flat(subpart_id); - int tot_key_num = part_offset_flat(subpart_id + 1) - subpart_offset; - int64 key_part_offset = subpart_offset * sizeof(K); - int64 value_part_offset = - subpart_offset * sizeof(V) * restore_args_.m_old_dim; - int64 version_part_offset = subpart_offset * sizeof(int64); - int64 freq_part_offset = subpart_offset * sizeof(int64); - VLOG(1) << "dynamically load ev : " << name_string - << ", subpartid:" << subpart_id; - - EVRestoreFeatures(tot_key_num, key_part_offset, value_part_offset, - version_part_offset, freq_part_offset, restore_buff, - new_dim, emb_config, device); - - if (restore_args_.m_has_filter) { - Status s = EVRestoreFilteredFeatures( - subpart_id, new_dim, restore_buff, part_filter_offset_flat, - emb_config, device); - if (!s.ok()) { - LOG(ERROR) << "EVRestoreFilteredFeatures fail: " << s.error_message(); - } - } - } - } - } + RestoreBuffer& restore_buff); private: - void RestoreSSD() { - std::string name_string_temp(restore_args_.m_name_string); - std::string new_str = "_"; - int64 pos = name_string_temp.find("/"); - while (pos != std::string::npos) { - name_string_temp.replace(pos, 1, new_str.data(), 1); - pos = name_string_temp.find("/"); - } - std::string ssd_record_file_name = restore_args_.m_file_name_string + "-" + - name_string_temp + "-ssd_record"; - if (Env::Default()->FileExists(ssd_record_file_name + ".index").ok()) { - std::string ssd_emb_file_name = restore_args_.m_file_name_string + "-" + - name_string_temp + "-emb_files"; - BundleReader ssd_record_reader(Env::Default(), ssd_record_file_name); - RestoreSSDBuffer ssd_buffer(&ssd_record_reader); - VLOG(1) << "Loading SSD record... " << ssd_record_file_name; - storage_->RestoreSSD(ev_->GetEmbeddingIndex(), - ev_->GetEmbeddingSlotNum(), ev_->ValueLen(), - ssd_emb_file_name, ev_, ssd_buffer); - } - } + void RestoreSSD(); bool IsOldCheckpoint(const std::string& curr_partid_str, - const std::string& kPartOffsetTensorSuffsix) { - if (restore_args_.m_name_string.find(kPartStr) == std::string::npos) { - string tensor_name = restore_args_.m_name_string; - TensorShape part_offset_shape; - DataType part_offset_type; - Status st = - reader_->LookupDtypeAndShape(tensor_name + kPartOffsetTensorSuffsix, - &part_offset_type, &part_offset_shape); - if (st.ok()) return false; - - string part_id = std::to_string(0); - tensor_name = restore_args_.m_name_string + "/" + kPartStr + part_id; - - Status form_st = - reader_->LookupDtypeAndShape(tensor_name + kPartOffsetTensorSuffsix, - &part_offset_type, &part_offset_shape); - if (form_st.ok()) return false; - } else { - string part_id = std::to_string(0); - size_t part_pos = restore_args_.m_name_string.find(kPartStr); - size_t part_size = strlen(kPartStr); - size_t cur_part_size = curr_partid_str.size(); - - string pre_subname = restore_args_.m_name_string.substr(0, part_pos); - string post_subname = restore_args_.m_name_string.substr( - part_pos + part_size + cur_part_size); - string tensor_name = pre_subname + kPartStr + part_id + post_subname; - - TensorShape part_offset_shape; - DataType part_offset_type; - Status form_st = - reader_->LookupDtypeAndShape(tensor_name + kPartOffsetTensorSuffsix, - &part_offset_type, &part_offset_shape); - if (form_st.ok()) return false; - pre_subname = - restore_args_.m_name_string.substr(0, part_pos - 1); /* var1*/ - post_subname = restore_args_.m_name_string.substr(part_pos + part_size + - cur_part_size); - tensor_name = pre_subname + post_subname; - - Status st = - reader_->LookupDtypeAndShape(tensor_name + kPartOffsetTensorSuffsix, - &part_offset_type, &part_offset_shape); - if (st.ok()) return false; - } - - return true; - } - - void InitPartNumAndLoadedParts(std::vector& tensor_name_vec) { - std::string tmp_key_suffix; - std::string tmp_kPartOffsetTensorSuffsix; - if (!restore_args_.m_is_incr) { - tmp_key_suffix = kKeySuffix; - tmp_kPartOffsetTensorSuffsix = kPartOffsetTensorSuffsix; - } else { - tmp_key_suffix = kIncrKeySuffix; - tmp_kPartOffsetTensorSuffsix = kIncrPartOffsetTensorSuffsix; - } - - restore_args_.m_loaded_parts.reserve(kSavedPartitionNum); - int orig_partnum = 0; - const string& curr_partid_str = std::to_string(restore_args_.m_partition_id); - size_t part_pos = restore_args_.m_name_string.find(kPartStr); - - if (IsOldCheckpoint(curr_partid_str, tmp_kPartOffsetTensorSuffsix)) { - restore_args_.m_is_oldform = true; - } - - if (part_pos == std::string::npos) { - for (;; orig_partnum++) { - string part_id = std::to_string(orig_partnum); - string tensor_name = - restore_args_.m_name_string + "/" + kPartStr + part_id; - string tensor_key = tensor_name + tmp_key_suffix; - TensorShape key_shape; - Status st = reader_->LookupTensorShape(tensor_key, &key_shape); - if (!st.ok()) { - break; - } - tensor_name_vec.emplace_back(tensor_name); - } - if (orig_partnum == 0) { - tensor_name_vec.emplace_back(restore_args_.m_name_string); - } - for (int i = 0; i < kSavedPartitionNum; ++i) { - restore_args_.m_loaded_parts.push_back(i); - } - } else { - for (;; orig_partnum++) { - string part_id = std::to_string(orig_partnum); - string pre_subname = restore_args_.m_name_string.substr(0, part_pos); - string post_subname = restore_args_.m_name_string.substr( - part_pos + strlen(kPartStr) + curr_partid_str.size()); - string tensor_name = pre_subname + kPartStr + part_id + post_subname; - string tensor_key = tensor_name + tmp_key_suffix; - TensorShape key_shape; - Status st = reader_->LookupTensorShape(tensor_key, &key_shape); - if (!st.ok()) { - break; - } - tensor_name_vec.emplace_back(tensor_name); - } - if (orig_partnum == 0) { - string pre_subname = restore_args_.m_name_string.substr(0, part_pos - 1); - string post_subname = restore_args_.m_name_string.substr( - part_pos + strlen(kPartStr) + curr_partid_str.size()); - string tmp_name = pre_subname + post_subname; - tensor_name_vec.emplace_back(tmp_name); - } - for (int i = 0; i < kSavedPartitionNum; i++) { - if (i % restore_args_.m_partition_num == restore_args_.m_partition_id) { - restore_args_.m_loaded_parts.push_back(i); - } - } - } - for (auto& tensor_name : tensor_name_vec) { - VLOG(1) << "**** " << restore_args_.m_name_string << " " << tensor_name - << " ****"; - } - } + const std::string& kPartOffsetTensorSuffsix); - Status EVInitTensorNameAndShape(const std::string& tensor_name) { - if (!restore_args_.m_is_incr) { - restore_args_.m_tensor_key = tensor_name + kKeySuffix; - restore_args_.m_tensor_value = tensor_name + kValueSuffix; - restore_args_.m_tensor_version = tensor_name + kVersionSuffix; - restore_args_.m_tensor_freq = tensor_name + kFreqSuffix; - } else { - restore_args_.m_tensor_key = tensor_name + kIncrKeySuffix; - restore_args_.m_tensor_value = tensor_name + kIncrValueSuffix; - restore_args_.m_tensor_version = tensor_name + kIncrVersionSuffix; - restore_args_.m_tensor_freq = tensor_name + kIncrFreqSuffix; - } + void InitPartNumAndLoadedParts(std::vector& tensor_name_vec); - TensorShape key_shape, value_shape, version_shape, freq_shape; - - Status st = - reader_->LookupTensorShape(restore_args_.m_tensor_key, &key_shape); - if (!st.ok()) { - return st; - } - st = reader_->LookupTensorShape(restore_args_.m_tensor_value, &value_shape); - if (!st.ok()) { - return st; - } - st = reader_->LookupTensorShape(restore_args_.m_tensor_version, - &version_shape); - if (!st.ok()) { - return st; - } - st = reader_->LookupHeader(restore_args_.m_tensor_key, - sizeof(K) * key_shape.dim_size(0)); - if (!st.ok()) { - return st; - } - st = reader_->LookupHeader(restore_args_.m_tensor_value, - sizeof(V) * value_shape.dim_size(0) * - value_shape.dim_size(1)); - if (!st.ok()) { - return st; - } - st = reader_->LookupHeader(restore_args_.m_tensor_version, - sizeof(int64) * version_shape.dim_size(0)); - if (!st.ok()) { - return st; - } - st = reader_->LookupTensorShape(restore_args_.m_tensor_freq, &freq_shape); - if (!st.ok()) { - if (st.code() == error::NOT_FOUND) { - freq_shape = version_shape; - } else { - return st; - } - } - st = reader_->LookupHeader(restore_args_.m_tensor_freq, - sizeof(int64) * freq_shape.dim_size(0)); - if (!st.ok()) { - if (st.code() == error::NOT_FOUND) { - restore_args_.m_has_freq = false; - } else { - return st; - } - } - restore_args_.m_old_dim = value_shape.dim_size(1); - - if (!restore_args_.m_is_oldform) { - TensorShape key_filter_shape, version_filter_shape, freq_filter_shape; - st = reader_->LookupTensorShape(restore_args_.m_tensor_key + "_filtered", - &key_filter_shape); - if (!st.ok()) { - if (st.code() == error::NOT_FOUND) { - key_filter_shape = key_shape; - restore_args_.m_has_filter = false; - } else { - return st; - } - } - st = reader_->LookupTensorShape( - restore_args_.m_tensor_version + "_filtered", &version_filter_shape); - if ((!st.ok()) && (st.code() != error::NOT_FOUND)) { - return st; - } - st = reader_->LookupHeader(restore_args_.m_tensor_key + "_filtered", - sizeof(K) * key_filter_shape.dim_size(0)); - if (!st.ok()) { - if (st.code() == error::NOT_FOUND) { - restore_args_.m_has_filter = false; - } else { - return st; - } - } - st = reader_->LookupHeader(restore_args_.m_tensor_version + "_filtered", - sizeof(K) * version_filter_shape.dim_size(0)); - if (!st.ok()) { - return st; - } - st = reader_->LookupTensorShape(restore_args_.m_tensor_freq + "_filtered", - &freq_filter_shape); - if (!st.ok()) { - if (st.code() == error::NOT_FOUND) { - freq_filter_shape = freq_shape; - } else { - return st; - } - } - - st = reader_->LookupHeader(restore_args_.m_tensor_freq + "_filtered", - sizeof(K) * freq_filter_shape.dim_size(0)); - if (!st.ok() && st.code() != error::NOT_FOUND) { - return st; - } - } - return st; - } + Status EVInitTensorNameAndShape(const std::string& tensor_name); Status EVRestoreFeatures(int tot_key_num, int64 key_part_offset, int64 value_part_offset, int64 version_part_offset, int64 freq_part_offset, RestoreBuffer& restore_buff, int64 new_dim, const EmbeddingConfig& emb_config, - const Eigen::GpuDevice* device) { - size_t value_unit_bytes = sizeof(V) * restore_args_.m_old_dim; - size_t value_unit_bytes_new = sizeof(V) * new_dim; - int64 tot_key_bytes_read(0); - int64 tot_value_bytes_read(0); - int64 tot_version_bytes_read(0); - int64 tot_freq_bytes_read(0); - size_t key_bytes_read = 0; - size_t value_bytes_read = 0; - size_t version_bytes_read = 0; - size_t freq_bytes_read = 0; - - while (tot_key_num > 0) { - size_t read_key_num = std::min( - std::min(kBufferSize / sizeof(K), kBufferSize / value_unit_bytes), - kBufferSize / sizeof(int64)); - read_key_num = std::min(read_key_num, kBufferSize / value_unit_bytes_new); - read_key_num = std::min((int)read_key_num, tot_key_num); - reader_->LookupSegmentOffset( - restore_args_.m_tensor_key, key_part_offset + tot_key_bytes_read, - read_key_num * sizeof(K), restore_buff.key_buffer, key_bytes_read); - reader_->LookupSegmentOffset( - restore_args_.m_tensor_value, value_part_offset + tot_value_bytes_read, - read_key_num * value_unit_bytes, restore_buff.value_buffer, - value_bytes_read); - if (!restore_args_.m_reset_version) { - reader_->LookupSegmentOffset( - restore_args_.m_tensor_version, - version_part_offset + tot_version_bytes_read, - read_key_num * sizeof(int64), restore_buff.version_buffer, - version_bytes_read); - if (version_bytes_read == 0) { - memset(restore_buff.version_buffer, -1, sizeof(int64) * read_key_num); - } - } else { - int64* version_tmp = (int64*)restore_buff.version_buffer; - memset(version_tmp, 0, read_key_num * sizeof(int64)); - } - - if (restore_args_.m_has_freq) { - reader_->LookupSegmentOffset( - restore_args_.m_tensor_freq, freq_part_offset + tot_freq_bytes_read, - read_key_num * sizeof(int64), restore_buff.freq_buffer, - freq_bytes_read); - if (freq_bytes_read == 0) { - int64* freq_tmp = (int64*)restore_buff.freq_buffer; - for (int64 i = 0; i < read_key_num; i++) { - freq_tmp[i] = (ev_->MinFreq() == 0) ? 1 : ev_->MinFreq(); - } - } - } else { - int64* freq_tmp = (int64*)restore_buff.freq_buffer; - for (int64 i = 0; i < read_key_num; i++) { - freq_tmp[i] = (ev_->MinFreq() == 0) ? 1 : ev_->MinFreq(); - } - } - if (key_bytes_read > 0) { - read_key_num = key_bytes_read / sizeof(K); - Status st = RestoreCustomDim(new_dim, read_key_num, value_unit_bytes, - value_bytes_read, value_unit_bytes_new, - restore_buff); - if (!st.ok()) { - LOG(FATAL) << "EV Restore fail:" << st.ToString(); - } - - st = storage_->RestoreFeatures( - read_key_num, kSavedPartitionNum, restore_args_.m_partition_id, - restore_args_.m_partition_num, new_dim, false, restore_args_.m_is_incr, - emb_config, device, - filter_, restore_buff); - if (!st.ok()) { - LOG(FATAL) << "EV Restore fail:" << st.ToString(); - } - } - - tot_key_num -= read_key_num; - tot_key_bytes_read += key_bytes_read; - tot_value_bytes_read += value_bytes_read; - tot_version_bytes_read += version_bytes_read; - tot_freq_bytes_read += freq_bytes_read; - } - - return Status::OK(); - } + const Eigen::GpuDevice* device); Status EVRestoreFilteredFeatures( int64 subpart_id, int64 value_len, RestoreBuffer& restore_buff, typename TTypes::Flat part_filter_offset_flat, - const EmbeddingConfig& emb_config, const Eigen::GpuDevice* device) { - int subpart_filter_offset = part_filter_offset_flat(subpart_id); - int tot_key_filter_num = - part_filter_offset_flat(subpart_id + 1) - subpart_filter_offset; - int64 key_filter_part_offset = subpart_filter_offset * sizeof(K); - int64 version_filter_part_offset = subpart_filter_offset * sizeof(int64); - int64 freq_filter_part_offset = subpart_filter_offset * sizeof(int64); - - VLOG(1) << "key_filter_num: " << tot_key_filter_num - << ", subpart_filter_offset: " << subpart_filter_offset; - - size_t key_filter_bytes_read = 0; - size_t version_filter_bytes_read = 0; - size_t freq_filter_bytes_read = 0; - - while (tot_key_filter_num > 0) { - size_t read_key_num = - std::min(kBufferSize / sizeof(K), kBufferSize / sizeof(int64)); - read_key_num = std::min((int)read_key_num, tot_key_filter_num); - reader_->LookupSegmentOffset( - restore_args_.m_tensor_key + "_filtered", - key_filter_part_offset + key_filter_bytes_read, - read_key_num * sizeof(K), restore_buff.key_buffer, - key_filter_bytes_read); - if (!restore_args_.m_reset_version) { - reader_->LookupSegmentOffset( - restore_args_.m_tensor_version + "_filtered", - version_filter_part_offset + version_filter_bytes_read, - read_key_num * sizeof(int64), restore_buff.version_buffer, - version_filter_bytes_read); - } else { - int64* version_tmp = (int64*)restore_buff.version_buffer; - memset(version_tmp, 0, read_key_num * sizeof(int64)); - } - reader_->LookupSegmentOffset( - restore_args_.m_tensor_freq + "_filtered", - freq_filter_part_offset + freq_filter_bytes_read, - read_key_num * sizeof(int64), restore_buff.freq_buffer, - freq_filter_bytes_read); - if (key_filter_bytes_read > 0) { - read_key_num = key_filter_bytes_read / sizeof(K); - VLOG(2) << "restore, read_key_num:" << read_key_num; - Status st = storage_->RestoreFeatures( - read_key_num, kSavedPartitionNum, restore_args_.m_partition_id, - restore_args_.m_partition_num, value_len, true, restore_args_.m_is_incr, - emb_config, device, - filter_, restore_buff); - if (!st.ok()) return st; - tot_key_filter_num -= read_key_num; - } - } - return Status::OK(); - } + const EmbeddingConfig& emb_config, const Eigen::GpuDevice* device); Status RestoreCustomDim(int new_dim, int read_key_num, size_t value_unit_bytes, size_t value_bytes_read, diff --git a/tensorflow/core/framework/embedding/kv_interface.h b/tensorflow/core/framework/embedding/kv_interface.h index 71667cf0917..5d1f20b581a 100644 --- a/tensorflow/core/framework/embedding/kv_interface.h +++ b/tensorflow/core/framework/embedding/kv_interface.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_FRAMEWORK_EMBEDDING_KV_INTERFACE_H_ #define TENSORFLOW_CORE_FRAMEWORK_EMBEDDING_KV_INTERFACE_H_ +#include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/lib/core/status.h" namespace tensorflow { @@ -29,6 +30,7 @@ class ValuePtr; template class GPUHashTable; +using GPUDevice = Eigen::GpuDevice; namespace embedding { template @@ -90,15 +92,15 @@ class KVInterface { virtual Status BatchLookupOrCreate(const K* keys, V* val, V* default_v, int32 default_v_num, - size_t n, const Eigen::GpuDevice& device) { + size_t n, const GPUDevice& device) { return Status::OK(); } virtual Status BatchLookupOrCreateKeys(const K* keys, size_t n, - int32* item_idxs, const Eigen::GpuDevice& device) { + int32* item_idxs, const GPUDevice& device) { return Status::OK(); } - virtual Status BatchLookup(const Eigen::GpuDevice& device, + virtual Status BatchLookup(const GPUDevice& device, const K* keys, V* val, size_t n, const V* default_v) { return Status(error::Code::UNIMPLEMENTED, "Unimplemented for BatchLookup in KVInterface."); diff --git a/tensorflow/core/framework/embedding/ssd_record_descriptor.cc b/tensorflow/core/framework/embedding/ssd_record_descriptor.cc new file mode 100644 index 00000000000..b224b24e856 --- /dev/null +++ b/tensorflow/core/framework/embedding/ssd_record_descriptor.cc @@ -0,0 +1,88 @@ +/* Copyright 2022 The DeepRec Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +#include "tensorflow/core/framework/embedding/ssd_record_descriptor.h" +#include "tensorflow/core/kernels/save_restore_tensor.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/io/path.h" + +namespace tensorflow { +namespace embedding { +template +template +void SsdRecordDescriptor::DumpSection( + const std::vector& data_vec, + const std::string& section_str, + BundleWriter* writer, + std::vector& dump_buffer) { + EVVectorDataDumpIterator iter(data_vec); + SaveTensorWithFixedBuffer( + section_str, + writer, dump_buffer.data(), + dump_buffer.size(), &iter, + TensorShape({data_vec.size()})); +} +#define REGISTER_KERNELS(ktype, ttype) \ + template void SsdRecordDescriptor::DumpSection( \ + const std::vector&, const std::string&, \ + BundleWriter*, std::vector&); +REGISTER_KERNELS(int32, int32); +REGISTER_KERNELS(int32, int64); +REGISTER_KERNELS(int64, int32); +REGISTER_KERNELS(int64, int64); +#undef REGISTER_KERNELS + +template +void SsdRecordDescriptor::DumpSsdMeta( + const std::string& prefix, + const std::string& var_name) { + std::fstream fs; + std::string var_name_temp(var_name); + std::string new_str = "_"; + int64 pos = var_name_temp.find("/"); + while (pos != std::string::npos) { + var_name_temp.replace(pos, 1, new_str.data(), 1); + pos = var_name_temp.find("/"); + } + + std::string ssd_record_path = + prefix + "-" + var_name_temp + "-ssd_record"; + BundleWriter ssd_record_writer(Env::Default(), + ssd_record_path); + size_t bytes_limit = 8 << 20; + std::vector dump_buffer(bytes_limit); + + DumpSection(key_list, "keys", + &ssd_record_writer, dump_buffer); + DumpSection(key_file_id_list, "keys_file_id", + &ssd_record_writer, dump_buffer); + DumpSection(key_offset_list, "keys_offset", + &ssd_record_writer, dump_buffer); + DumpSection(file_list, "files", + &ssd_record_writer, dump_buffer); + DumpSection(invalid_record_count_list, "invalid_record_count", + &ssd_record_writer, dump_buffer); + DumpSection(record_count_list, "record_count", + &ssd_record_writer, dump_buffer); + + ssd_record_writer.Finish(); +} +#define REGISTER_KERNELS(ktype) \ + template void SsdRecordDescriptor::DumpSsdMeta( \ + const std::string&, const std::string&); +REGISTER_KERNELS(int32); +REGISTER_KERNELS(int64); +#undef REGISTER_KERNELS +}//namespace embedding +}//namespace tensorflow diff --git a/tensorflow/core/framework/embedding/ssd_record_descriptor.h b/tensorflow/core/framework/embedding/ssd_record_descriptor.h index 9d015236934..aeb8d324759 100644 --- a/tensorflow/core/framework/embedding/ssd_record_descriptor.h +++ b/tensorflow/core/framework/embedding/ssd_record_descriptor.h @@ -20,14 +20,13 @@ limitations under the License. #include #include #include - +#include #include "tensorflow/core/framework/embedding/embedding_var_dump_iterator.h" #include "tensorflow/core/framework/embedding/kv_interface.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/util/env_var.h" +#include "tensorflow/core/platform/env.h" namespace tensorflow { +class BundleWriter; namespace embedding { template @@ -59,48 +58,10 @@ class SsdRecordDescriptor { void DumpSection(const std::vector& data_vec, const std::string& section_str, BundleWriter* writer, - std::vector& dump_buffer) { - EVVectorDataDumpIterator iter(data_vec); - SaveTensorWithFixedBuffer( - section_str, - writer, dump_buffer.data(), - dump_buffer.size(), &iter, - TensorShape({data_vec.size()})); - } + std::vector& dump_buffer); void DumpSsdMeta(const std::string& prefix, - const std::string& var_name) { - std::fstream fs; - std::string var_name_temp(var_name); - std::string new_str = "_"; - int64 pos = var_name_temp.find("/"); - while (pos != std::string::npos) { - var_name_temp.replace(pos, 1, new_str.data(), 1); - pos = var_name_temp.find("/"); - } - - std::string ssd_record_path = - prefix + "-" + var_name_temp + "-ssd_record"; - BundleWriter ssd_record_writer(Env::Default(), - ssd_record_path); - size_t bytes_limit = 8 << 20; - std::vector dump_buffer(bytes_limit); - - DumpSection(key_list, "keys", - &ssd_record_writer, dump_buffer); - DumpSection(key_file_id_list, "keys_file_id", - &ssd_record_writer, dump_buffer); - DumpSection(key_offset_list, "keys_offset", - &ssd_record_writer, dump_buffer); - DumpSection(file_list, "files", - &ssd_record_writer, dump_buffer); - DumpSection(invalid_record_count_list, "invalid_record_count", - &ssd_record_writer, dump_buffer); - DumpSection(record_count_list, "record_count", - &ssd_record_writer, dump_buffer); - - ssd_record_writer.Finish(); - } + const std::string& var_name); void CopyEmbeddingFilesToCkptDir( const std::string& prefix, diff --git a/tensorflow/core/framework/embedding/storage.h b/tensorflow/core/framework/embedding/storage.h index d212e5b9c77..bb949183492 100644 --- a/tensorflow/core/framework/embedding/storage.h +++ b/tensorflow/core/framework/embedding/storage.h @@ -26,7 +26,6 @@ limitations under the License. #include "tensorflow/core/framework/embedding/storage_config.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h" #include "tensorflow/core/util/work_sharder.h" #include "tensorflow/core/framework/device_base.h" #if GOOGLE_CUDA @@ -53,6 +52,9 @@ struct SsdRecordDescriptor; template class GPUHashTable; +class BundleWriter; +class BundleReader; + template struct EmbeddingVarContext; namespace { diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index fc1b2cd9c67..115e3c4bae6 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -2907,7 +2907,10 @@ tf_kernel_library( hdrs = ["kv_variable_ops.h"], srcs = ["kv_variable_ops.cc", "kv_variable_lookup_ops.cc", - "kv_variable_restore_ops.cc"], + "kv_variable_restore_ops.cc", + "//tensorflow/core:framework/embedding/embedding_var_ckpt_data.cc", + "//tensorflow/core:framework/embedding/embedding_var_restore.cc", + "//tensorflow/core:framework/embedding/ssd_record_descriptor.cc"], copts = tf_copts() + ["-g"], deps = [ ":bounds_check",