diff --git a/configure.py b/configure.py index 6aeaf7d12af..4fb1c78c40b 100644 --- a/configure.py +++ b/configure.py @@ -1434,7 +1434,7 @@ def main(): True, 'star') set_build_var(environ_cp, 'TF_NEED_ELASTIC', 'ELASTIC TRAINING', 'with_elastic_support', - True, 'elastic') + False, 'elastic') set_build_var(environ_cp, 'TF_ENABLE_PMEM', 'PMEM', 'with_pmem_support', False, 'pmem') diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 0531200e7ab..ef1ebcb6dcf 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -128,6 +128,7 @@ load( "tf_additional_numa_deps", "tf_additional_numa_lib_defines", "tf_additional_star_lib_defines", + "tf_additional_elastic_server_lib_defines", "tf_additional_api_compatible_defines", "tf_additional_pmem_lib_defines", "tf_additional_test_deps", @@ -1441,6 +1442,7 @@ tf_cc_test( cc_library( name = "ops", visibility = ["//visibility:public"], + defines = tf_additional_elastic_server_lib_defines(), deps = [ ":array_ops_op_lib", ":parquet_ops_op_lib", @@ -2562,7 +2564,8 @@ LIB_INTERNAL_DEFINES = ( tf_additional_gdr_lib_defines() + tf_additional_numa_lib_defines() + tf_additional_star_lib_defines() + - tf_additional_pmem_lib_defines() + tf_additional_pmem_lib_defines() + + tf_additional_elastic_server_lib_defines() ) cc_library( diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD index 08445403b58..6878c5f8350 100644 --- a/tensorflow/core/kernels/data/BUILD +++ b/tensorflow/core/kernels/data/BUILD @@ -9,6 +9,11 @@ load( "transitive_hdrs", ) +load( + "//tensorflow/core/platform:default/build_config.bzl", + "tf_additional_elastic_server_lib_defines", +) + package( default_visibility = ["//visibility:public"], licenses = ["notice"], # Apache 2.0 @@ -1119,6 +1124,7 @@ tf_kernel_library( name = "iterator_ops", srcs = ["iterator_ops.cc"], hdrs = ["iterator_ops.h"], + defines = tf_additional_elastic_server_lib_defines(), deps = [ ":captured_function", ":dataset_utils", diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc index 68bf172268d..ed6b40a38a0 100644 --- a/tensorflow/core/kernels/data/iterator_ops.cc +++ b/tensorflow/core/kernels/data/iterator_ops.cc @@ -269,7 +269,6 @@ IteratorHandleOp::IteratorHandleOp(OpKernelConstruction* ctx) OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_dtypes_)); OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &name_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("recoverable", &recoverable_)); } // The resource is deleted from the resource manager only when it is private @@ -309,11 +308,11 @@ void IteratorHandleOp::Compute(OpKernelContext* context) LOCKS_EXCLUDED(mu_) { } ResourceMgr* mgr = context->resource_manager(); - if (recoverable_ == false) { - OP_REQUIRES_OK(context, cinfo_.Init(mgr, def(), false)); - } else { - OP_REQUIRES_OK(context, cinfo_.Init(mgr, def(), true)); - } +#ifdef TENSORFLOW_USE_ELASTIC_SERVER + OP_REQUIRES_OK(context, cinfo_.Init(mgr, def(), true)); +#else + OP_REQUIRES_OK(context, cinfo_.Init(mgr, def(), false)); +#endif IteratorResource* resource; OP_REQUIRES_OK( @@ -788,7 +787,11 @@ class OneShotIteratorOp : public AsyncOpKernel { Status TryInit(OpKernelContext* ctx, IteratorResource** iterator, ContainerInfo* cinfo) { - TF_RETURN_IF_ERROR(cinfo->Init(ctx->resource_manager(), def(), true)); +#ifdef TENSORFLOW_USE_ELASTIC_SERVER + TF_RETURN_IF_ERROR(cinfo->Init(ctx->resource_manager(), def(), true)); +#else + TF_RETURN_IF_ERROR(cinfo->Init(ctx->resource_manager(), def(), false)); +#endif FunctionLibraryRuntime* flr; std::unique_ptr flib_def(nullptr); diff --git a/tensorflow/core/kernels/data/iterator_ops.h b/tensorflow/core/kernels/data/iterator_ops.h index 7277a8ac652..07b88d4ccc3 100644 --- a/tensorflow/core/kernels/data/iterator_ops.h +++ b/tensorflow/core/kernels/data/iterator_ops.h @@ -135,7 +135,6 @@ class IteratorHandleOp : public OpKernel { std::vector output_shapes_; const int graph_def_version_; string name_; - bool recoverable_; }; // Like IteratorHandleOp, but creates handles which are never shared, and does diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc index 6fe19a1471d..3ed48c17224 100644 --- a/tensorflow/core/ops/dataset_ops.cc +++ b/tensorflow/core/ops/dataset_ops.cc @@ -555,24 +555,13 @@ REGISTER_OP("Iterator") .Attr("output_shapes: list(shape) >= 1") .SetShapeFn(shape_inference::ScalarShape); -#ifndef TF_API_COMPATIBLE_1150 REGISTER_OP("IteratorV2") .Output("handle: resource") .Attr("shared_name: string") .Attr("container: string") - .Attr("recoverable: bool = false") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") .SetShapeFn(shape_inference::ScalarShape); -#else -REGISTER_OP("IteratorV2") - .Output("handle: resource") - .Attr("shared_name: string") - .Attr("container: string") - .Attr("output_types: list(type) >= 1") - .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(shape_inference::ScalarShape); -#endif REGISTER_OP("AnonymousIterator") .Output("handle: resource")