diff --git a/tensorflow/contrib/elastic_grpc_server/elastic_grpc_server_lib.cc b/tensorflow/contrib/elastic_grpc_server/elastic_grpc_server_lib.cc index d45d70d6c8c..66e237956e5 100644 --- a/tensorflow/contrib/elastic_grpc_server/elastic_grpc_server_lib.cc +++ b/tensorflow/contrib/elastic_grpc_server/elastic_grpc_server_lib.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include "include/json/json.h" #include "grpc/support/alloc.h" #include "grpcpp/grpcpp.h" @@ -89,7 +90,7 @@ Status ElasticGrpcServer::UpdateServerDef(const string& cluster_def_str, int& be return errors::Internal("PARSE TF_CONFIG/cluster ERROR"); } - std::unordered_set ps_addrs_vec; + std::set ps_addrs_vec; //ordered after_part_num = cluster_json["cluster"]["ps"].size(); for (auto& value: cluster_json["cluster"]["ps"]) { ps_addrs_vec.emplace(value.asString()); @@ -111,21 +112,25 @@ Status ElasticGrpcServer::UpdateServerDef(const string& cluster_def_str, int& be } for (auto ps_addr: ps_addrs_vec) { if (target_string_set.find(ps_addr) == target_string_set.end()) { - job->mutable_tasks()->insert({idx, ps_addr}); + job->mutable_tasks()->insert({idx++, ps_addr}); tf_config_json["cluster"]["ps"].append(ps_addr); } } break; } else { LOG(INFO) << "SCALING DOWN, partition_num is: " << after_part_num; + google::protobuf::Map< google::protobuf::int32, std::string > tasks; + Json::Value arr_value(Json::arrayValue); + int idx = 0; for (int i = 0; i < before_part_num; ++i) { string tmp_string = tf_config_json["cluster"]["ps"][i].asString(); - if (ps_addrs_vec.find(tmp_string) == ps_addrs_vec.end()) { - Json::Value ps_addr; - tf_config_json["cluster"]["ps"].removeIndex(i, &ps_addr); - job->mutable_tasks()->erase(i); + if (ps_addrs_vec.find(tmp_string) != ps_addrs_vec.end()) { + arr_value.append(tf_config_json["cluster"]["ps"][i]); + tasks[idx++] = tmp_string; } } + tf_config_json["cluster"]["ps"].swap(arr_value); + job->mutable_tasks()->swap(tasks); } } } diff --git a/tensorflow/contrib/elastic_grpc_server/elastic_service.cc b/tensorflow/contrib/elastic_grpc_server/elastic_service.cc index 61aa6e662ec..59f7fa473bd 100644 --- a/tensorflow/contrib/elastic_grpc_server/elastic_service.cc +++ b/tensorflow/contrib/elastic_grpc_server/elastic_service.cc @@ -24,7 +24,7 @@ limitations under the License. #include #include "grpcpp/server_builder.h" -using namespace des; +using namespace deeprec; using grpc::Server; using grpc::ServerAsyncResponseWriter; diff --git a/tensorflow/core/protobuf/elastic_training.proto b/tensorflow/core/protobuf/elastic_training.proto index ee0d0bd10e0..b6af4b139cf 100644 --- a/tensorflow/core/protobuf/elastic_training.proto +++ b/tensorflow/core/protobuf/elastic_training.proto @@ -1,6 +1,6 @@ syntax = "proto3"; -package des; +package deeprec; enum Code { OK = 0; diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index a740e0916d9..f9cc74743be 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -4747,6 +4747,7 @@ py_library( ":platform", ":protos_all_py", ":session_run_hook", + "//tensorflow/core:elastic_service_pb_py", ":training_util", ":util", ],