From 7eec676e2d844a5f68f80d2a905d44e8d8634704 Mon Sep 17 00:00:00 2001 From: luliyucoordinate Date: Sun, 14 Aug 2022 12:24:45 +0800 Subject: [PATCH 001/376] Fix Tensor constructor with DataType --- tensorflow/core/framework/tensor.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc index d57424a9a3a827..ade2cf67ecf48c 100644 --- a/tensorflow/core/framework/tensor.cc +++ b/tensorflow/core/framework/tensor.cc @@ -696,7 +696,9 @@ void UnrefIfNonNull(core::RefCounted* buf) { Tensor::Tensor() : Tensor(DT_FLOAT) {} -Tensor::Tensor(DataType type) : shape_(type), buf_(nullptr) {} +Tensor::Tensor(DataType type) : shape_(TensorShape({})), buf_(nullptr) { + set_dtype(type); +} Tensor::Tensor(DataType type, const TensorShape& shape, TensorBuffer* buf) : shape_(shape), buf_(buf) { From 6a2abbbb336d2b7a63ab42627137b5cb4737510d Mon Sep 17 00:00:00 2001 From: Kun-Lu Date: Tue, 14 Feb 2023 17:29:28 -0500 Subject: [PATCH 002/376] Enable secure grpc++/grpc and ssl features on s390x Signed-off-by: Kun-Lu --- tensorflow/BUILD | 2 -- tensorflow/tsl/BUILD | 1 - tensorflow/tsl/platform/default/build_config.bzl | 2 -- 3 files changed, 5 deletions(-) diff --git a/tensorflow/BUILD b/tensorflow/BUILD index d7414c63067132..c4e64e46f0c93f 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -1086,7 +1086,6 @@ cc_library( name = "grpc", visibility = ["//visibility:public"], deps = select({ - ":linux_s390x": ["@com_github_grpc_grpc//:grpc_unsecure"], "//conditions:default": ["@com_github_grpc_grpc//:grpc"], }), ) @@ -1095,7 +1094,6 @@ cc_library( name = "grpc++", visibility = ["//visibility:public"], deps = select({ - ":linux_s390x": ["@com_github_grpc_grpc//:grpc++_unsecure"], "//conditions:default": ["@com_github_grpc_grpc//:grpc++"], }), ) diff --git a/tensorflow/tsl/BUILD b/tensorflow/tsl/BUILD index 111bd65188337b..076836bc5418d7 100644 --- a/tensorflow/tsl/BUILD +++ b/tensorflow/tsl/BUILD @@ -467,7 +467,6 @@ cc_library( name = "grpc++", visibility = ["//visibility:public"], deps = select({ - ":linux_s390x": ["@com_github_grpc_grpc//:grpc++_unsecure"], "//conditions:default": ["@com_github_grpc_grpc//:grpc++"], }), ) diff --git a/tensorflow/tsl/platform/default/build_config.bzl b/tensorflow/tsl/platform/default/build_config.bzl index c6b49ba7ca4edd..256597daf36e31 100644 --- a/tensorflow/tsl/platform/default/build_config.bzl +++ b/tensorflow/tsl/platform/default/build_config.bzl @@ -243,7 +243,6 @@ def cc_proto_library( if use_grpc_plugin: cc_libs += select({ - clean_dep("//tensorflow/tsl:linux_s390x"): ["//external:grpc_lib_unsecure"], "//conditions:default": ["//external:grpc_lib"], }) @@ -326,7 +325,6 @@ def cc_grpc_library( proto_targets += srcs extra_deps += select({ - clean_dep("//tensorflow/tsl:linux_s390x"): ["//external:grpc_lib_unsecure"], "//conditions:default": ["//external:grpc_lib"], }) From 6c23541ce7100c6b824a0531090ebf5fd337ff14 Mon Sep 17 00:00:00 2001 From: Kun-Lu Date: Thu, 16 Feb 2023 09:21:58 -0500 Subject: [PATCH 003/376] Set system_lib as the default option for ssl on s390x Signed-off-by: Kun-Lu --- configure.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/configure.py b/configure.py index 6abde63a28aa0d..91e8d5d380d7f8 100644 --- a/configure.py +++ b/configure.py @@ -83,6 +83,8 @@ def is_macos(): def is_ppc64le(): return platform.machine() == 'ppc64le' +def is_s390x(): + return platform.machine() == 's390x' def is_cygwin(): return platform.system().startswith('CYGWIN_NT') @@ -1007,6 +1009,10 @@ def system_specific_test_config(environ_cp): def set_system_libs_flag(environ_cp): syslibs = environ_cp.get('TF_SYSTEM_LIBS', '') + + if is_s390x() and "boringssl" not in syslibs: + syslibs = "boringssl" + (", " + syslibs if syslibs != "" else "") + if syslibs: if ',' in syslibs: syslibs = ','.join(sorted(syslibs.split(','))) From 20fee9c498f0ad2cb4abb3b60f160164b501a984 Mon Sep 17 00:00:00 2001 From: Mahmoud Abuzaina Date: Wed, 19 Apr 2023 11:35:10 -0700 Subject: [PATCH 004/376] Enabled oneDNNv3.1 in INT8 Conv --- .../core/common_runtime/mkl_layout_pass.cc | 6 + tensorflow/core/kernels/mkl/mkl_conv_ops.cc | 511 +++++++++++++++--- .../mkl/mkl_quantized_conv_ops_test.cc | 4 +- tensorflow/core/util/mkl_util.h | 13 +- 4 files changed, 448 insertions(+), 86 deletions(-) diff --git a/tensorflow/core/common_runtime/mkl_layout_pass.cc b/tensorflow/core/common_runtime/mkl_layout_pass.cc index e188147f637a1c..a98221a617fcc4 100644 --- a/tensorflow/core/common_runtime/mkl_layout_pass.cc +++ b/tensorflow/core/common_runtime/mkl_layout_pass.cc @@ -572,6 +572,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { rinfo_.push_back({csinfo_.quantized_concatv2, mkl_op_registry::GetMklOpName(csinfo_.quantized_concatv2), CopyAttrsAll, ConcatV2Rewrite, kRewriteForOpNameChange}); +#endif // !ENABLE_ONEDNN_V3 rinfo_.push_back({csinfo_.quantized_conv2d, mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d), CopyAttrsQuantizedConv2D, AlwaysRewrite, @@ -613,9 +614,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass { mkl_op_registry::GetMklOpName( csinfo_.quantized_conv2d_with_bias_and_relu_and_requantize), CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange}); +#ifndef ENABLE_ONEDNN_V3 rinfo_.push_back({csinfo_.quantized_max_pool, mkl_op_registry::GetMklOpName(csinfo_.quantized_max_pool), CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange}); +#endif rinfo_.push_back({csinfo_.quantized_conv2d_with_bias_sum_and_relu, mkl_op_registry::GetMklOpName( csinfo_.quantized_conv2d_with_bias_sum_and_relu), @@ -631,6 +634,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { mkl_op_registry::GetMklOpName( csinfo_.quant_conv2d_with_bias_signed_sum_and_relu_and_requantize), CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange}); +#ifndef ENABLE_ONEDNN_V3 rinfo_.push_back( {csinfo_.quantized_matmul_with_bias, mkl_op_registry::GetMklOpName(csinfo_.quantized_matmul_with_bias), @@ -657,6 +661,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { csinfo_.quantized_matmul_with_bias_and_dequantize), CopyAttrsQuantizedMatMulWithBiasAndDequantize, AlwaysRewrite, kRewriteForOpNameChange}); +#endif // !ENABLE_ONEDNN_V3 rinfo_.push_back( {csinfo_.quantized_depthwise_conv2d, mkl_op_registry::GetMklOpName(csinfo_.quantized_depthwise_conv2d), @@ -677,6 +682,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { csinfo_ .quantized_depthwise_conv2d_with_bias_and_relu_and_requantize), CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange}); +#ifndef ENABLE_ONEDNN_V3 rinfo_.push_back({csinfo_.quantize_v2, mkl_op_registry::GetMklOpName(csinfo_.quantize_v2), CopyAttrsAll, QuantizeOpRewrite, diff --git a/tensorflow/core/kernels/mkl/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl/mkl_conv_ops.cc index 9ef3577d20d6d5..e5677a8a1c1705 100644 --- a/tensorflow/core/kernels/mkl/mkl_conv_ops.cc +++ b/tensorflow/core/kernels/mkl/mkl_conv_ops.cc @@ -49,6 +49,12 @@ namespace tensorflow { #define SET_FUSE_ACTIVATION_FOR_RELU6 \ set_fuse_activation(true, dnnl::algorithm::eltwise_bounded_relu, 6.0) #define SET_MKL_LAYOUT(md) SetMklLayout(&md) +#define TSCALED_BIAS Tbias +#define SCALE scales +#define SUMMAND_SCALE_U8(summand_range, output_range) \ + summand_range / output_range +#define SUMMAND_SCALE_S8(summand_range, output_range) \ + 255.0f * summand_range / (output_range * 127.0f) #else #define APPEND_DEPTHWISE(wei_dt, bias_dt, dst_dt, kernel, stride, padding, \ scales_mask, scales) \ @@ -58,6 +64,10 @@ namespace tensorflow { #define SET_FUSE_ACTIVATION_FOR_RELU6 \ set_fuse_activation(true, dnnl::algorithm::eltwise_clip, 0.0, 6.0) #define SET_MKL_LAYOUT(md) SetMklLayout(md) +#define TSCALED_BIAS float +#define SCALE wei_scale +#define SUMMAND_SCALE_U8(summand_range, output_range) summand_range / 255.0f +#define SUMMAND_SCALE_S8(summand_range, output_range) summand_range / 127.0f #endif // !ENABLE_ONEDNN_V3 // TODO(intel-tf) Remove this once old API of quantized ops is abandoned @@ -87,12 +97,14 @@ struct MklConvFwdParams { memory::dims fuse_bn_dims; MklTensorFormat tf_fmt; bool native_format; + bool is_depthwise; string dtypes = string(""); struct PostOpParam { string name; dnnl::algorithm alg; std::vector param; std::string partial_key; + DataType dtype = DT_INVALID; }; std::vector post_op_params; @@ -101,7 +113,7 @@ struct MklConvFwdParams { memory::dims strides, memory::dims dilations, memory::dims padding_left, memory::dims padding_right, memory::dims fuse_bn_dims, MklTensorFormat tf_fmt, - bool native_format) + bool native_format, bool is_depthwise) : src_dims(src_dims), filter_dims(filter_dims), bias_dims(bias_dims), @@ -112,7 +124,8 @@ struct MklConvFwdParams { padding_right(padding_right), fuse_bn_dims(fuse_bn_dims), tf_fmt(tf_fmt), - native_format(native_format) {} + native_format(native_format), + is_depthwise(is_depthwise) {} }; // With quantization, input, filter, and output can have different types @@ -139,16 +152,18 @@ class MklConvFwdPrimitive : public MklPrimitive { // bias_data: input data buffer of bias // dst_data: output data buffer of dst void Execute(const Tinput* src_data, const Tfilter* filter_data, - const Tbias* bias_data, const Toutput* dst_data, + const void* bias_data, const Toutput* dst_data, + const MklConvFwdParams& convFwdDims, std::shared_ptr fwd_stream, void* sp_data = nullptr) { Execute(src_data, filter_data, bias_data, dst_data, nullptr, nullptr, - nullptr, nullptr, fwd_stream, sp_data); + nullptr, nullptr, convFwdDims, fwd_stream, sp_data); } void Execute(const Tinput* src_data, const Tfilter* filter_data, - const Tbias* bias_data, const Toutput* dst_data, + const void* bias_data, const Toutput* dst_data, const Tinput* bn_scale_data, const Tinput* bn_mean_data, const Tinput* bn_offset_data, const Tinput* bn_rsqrt_data, + const MklConvFwdParams& convFwdDims, std::shared_ptr fwd_stream, void* sp_data) { #ifdef DNNL_AARCH64_USE_ACL // When we are using single global cache then in this case we can have @@ -162,8 +177,29 @@ class MklConvFwdPrimitive : public MklPrimitive { context_.filter_mem->set_data_handle( static_cast(const_cast(filter_data)), *fwd_stream); if (bias_data != nullptr) { - context_.bias_mem->set_data_handle( - static_cast(const_cast(bias_data)), *fwd_stream); + context_.bias_mem->set_data_handle(const_cast(bias_data), + *fwd_stream); + } + auto const& post_op_params = convFwdDims.post_op_params; + if (!post_op_params.empty()) { + for (auto const& post_op_param : post_op_params) { + if (post_op_param.name == "src_scale") { + context_.src_scale_mem->set_data_handle( + static_cast( + const_cast(post_op_param.param.data())), + *fwd_stream); + } else if (post_op_param.name == "wei_scale") { + context_.wei_scale_mem->set_data_handle( + static_cast( + const_cast(post_op_param.param.data())), + *fwd_stream); + } else if (post_op_param.name == "dst_scale") { + context_.dst_scale_mem->set_data_handle( + static_cast( + const_cast(post_op_param.param.data())), + *fwd_stream); + } + } } if (bn_scale_data != nullptr) { context_.bn_scale_mem->set_data_handle( @@ -187,8 +223,22 @@ class MklConvFwdPrimitive : public MklPrimitive { context_.filter_mem->set_data_handle( static_cast(const_cast(filter_data))); if (bias_data != nullptr) { - context_.bias_mem->set_data_handle( - static_cast(const_cast(bias_data))); + context_.bias_mem->set_data_handle(const_cast(bias_data)); + } + auto const& post_op_params = convFwdDims.post_op_params; + if (!post_op_params.empty()) { + for (auto const& post_op_param : post_op_params) { + if (post_op_param.name == "src_scale") { + context_.src_scale_mem->set_data_handle(static_cast( + const_cast(post_op_param.param.data()))); + } else if (post_op_param.name == "wei_scale") { + context_.wei_scale_mem->set_data_handle(static_cast( + const_cast(post_op_param.param.data()))); + } else if (post_op_param.name == "dst_scale") { + context_.dst_scale_mem->set_data_handle(static_cast( + const_cast(post_op_param.param.data()))); + } + } } if (bn_scale_data != nullptr) { context_.bn_scale_mem->set_data_handle( @@ -235,10 +285,10 @@ class MklConvFwdPrimitive : public MklPrimitive { // filter_data: input data buffer of filter (weights) // dst_data: output data buffer of dst void Execute(const Tinput* src_data, const Tfilter* filter_data, - const Toutput* dst_data, std::shared_ptr fwd_stream, - void* sp_data) { + const Toutput* dst_data, const MklConvFwdParams& convFwdDims, + std::shared_ptr fwd_stream, void* sp_data) { Execute(src_data, filter_data, nullptr, dst_data, nullptr, nullptr, nullptr, - nullptr, fwd_stream, sp_data); + nullptr, convFwdDims, fwd_stream, sp_data); } std::shared_ptr GetPrimitiveDesc() const { @@ -261,6 +311,11 @@ class MklConvFwdPrimitive : public MklPrimitive { std::shared_ptr bn_rsqrt_mem; std::shared_ptr bn_offset_mem; + // Quantization scale related memory + std::shared_ptr src_scale_mem; + std::shared_ptr wei_scale_mem; + std::shared_ptr dst_scale_mem; + // Desc & primitive desc #ifndef ENABLE_ONEDNN_V3 std::shared_ptr fwd_desc; @@ -279,6 +334,11 @@ class MklConvFwdPrimitive : public MklPrimitive { std::shared_ptr bn_rsqrt_md; std::shared_ptr bn_offset_md; + // Quantization scale related memory descriptors + std::shared_ptr src_scale_md; + std::shared_ptr wei_scale_md; + std::shared_ptr dst_scale_md; + // Convolution primitive std::shared_ptr conv_fwd; @@ -295,6 +355,9 @@ class MklConvFwdPrimitive : public MklPrimitive { bn_mean_mem(nullptr), bn_rsqrt_mem(nullptr), bn_offset_mem(nullptr), + src_scale_mem(nullptr), + wei_scale_mem(nullptr), + dst_scale_mem(nullptr), #ifndef ENABLE_ONEDNN_V3 fwd_desc(nullptr), #endif // !ENABLE_ONEDNN_V3 @@ -306,6 +369,9 @@ class MklConvFwdPrimitive : public MklPrimitive { bn_mean_md(nullptr), bn_rsqrt_md(nullptr), bn_offset_md(nullptr), + src_scale_md(nullptr), + wei_scale_md(nullptr), + dst_scale_md(nullptr), fwd_pd(nullptr), conv_fwd(nullptr) { } @@ -330,9 +396,15 @@ class MklConvFwdPrimitive : public MklPrimitive { {convFwdDims.dst_dims}, MklDnnType(), user_data_fmt)); if (!convFwdDims.bias_dims.empty()) { - context_.bias_md.reset(new memory::desc({convFwdDims.bias_dims}, - MklDnnType(), - memory::format_tag::any)); + if (std::is_same::value) { + context_.bias_md.reset(new memory::desc({convFwdDims.bias_dims}, + MklDnnType(), + memory::format_tag::any)); + } else { + context_.bias_md.reset(new memory::desc({convFwdDims.bias_dims}, + MklDnnType(), + memory::format_tag::any)); + } #ifndef ENABLE_ONEDNN_V3 // Create a convolution descriptor context_.fwd_desc.reset(new convolution_forward::desc( @@ -370,6 +442,7 @@ class MklConvFwdPrimitive : public MklPrimitive { dnnl::primitive_attr post_ops_attr; dnnl::post_ops post_ops; post_ops_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); + std::unordered_map is_scale_set; if (!post_op_params.empty()) { for (auto const& post_op_param : post_op_params) { if (post_op_param.name == "activation") { @@ -383,21 +456,53 @@ class MklConvFwdPrimitive : public MklPrimitive { } else if (post_op_param.name == "sum") { DCHECK_EQ(post_op_param.param.size(), 1); float op_scale = post_op_param.param[0]; +#ifndef ENABLE_ONEDNN_V3 post_ops.append_sum(op_scale); - } else if (post_op_param.name == "output_scale") { +#else + if (post_op_param.dtype != DT_INVALID) { + if (post_op_param.dtype == DT_FLOAT) { + post_ops.append_sum(op_scale, /*zero_point=*/0, + MklDnnType()); + } else { + TF_CHECK_OK(Status(absl::StatusCode::kFailedPrecondition, + "Summand data type is expected to be float")); + } + } else { + post_ops.append_sum(op_scale); + } +#endif //! ENABLE_ONEDNN_V3 #ifndef ENABLE_ONEDNN_V3 + } else if (post_op_param.name == "output_scale") { if (post_op_param.param.size() == 1) { post_ops_attr.set_output_scales(0, post_op_param.param); } else { post_ops_attr.set_output_scales(2, post_op_param.param); } #else - // TODO(intel-tf): Enable this for int8 when using oneDNN v3.x - // and return a status instead of using DCHECK_EQ - DCHECK_EQ(post_op_param.param.size(), 1); + } else if (post_op_param.name == "src_scale") { + is_scale_set.insert({"src", true}); post_ops_attr.set_scales_mask(DNNL_ARG_SRC, 0); - post_ops_attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0); + context_.src_scale_md.reset(new memory::desc({1}, MklDnnType(), + memory::format_tag::x)); + context_.src_scale_mem.reset( + new memory(*context_.src_scale_md, cpu_engine_, DummyData)); + } else if (post_op_param.name == "wei_scale") { + is_scale_set.insert({"wei", true}); + const int scale_size = post_op_param.param.size(); + const int mask = + scale_size == 1 ? 0 : convFwdDims.is_depthwise ? 3 : 1; + post_ops_attr.set_scales_mask(DNNL_ARG_WEIGHTS, mask); + context_.wei_scale_md.reset(new memory::desc( + {scale_size}, MklDnnType(), memory::format_tag::x)); + context_.wei_scale_mem.reset( + new memory(*context_.wei_scale_md, cpu_engine_, DummyData)); + } else if (post_op_param.name == "dst_scale") { + is_scale_set.insert({"dst", true}); post_ops_attr.set_scales_mask(DNNL_ARG_DST, 0); + context_.dst_scale_md.reset(new memory::desc({1}, MklDnnType(), + memory::format_tag::x)); + context_.dst_scale_mem.reset( + new memory(*context_.dst_scale_md, cpu_engine_, DummyData)); #endif // !ENABLE_ONEDNN_V3 } else if (post_op_param.name == "fuse_bn") { post_ops.append_binary(dnnl::algorithm::binary_sub, @@ -411,7 +516,13 @@ class MklConvFwdPrimitive : public MklPrimitive { } else { DCHECK((post_op_param.name == "activation") || (post_op_param.name == "sum") || +#ifndef ENABLE_ONEDNN_V3 (post_op_param.name == "output_scale") || +#else + (post_op_param.name == "src_scale") || + (post_op_param.name == "wei_scale") || + (post_op_param.name == "dst_scale") || +#endif // !ENABLE_ONEDNN_V3 (post_op_param.name == "fuse_bn")); } } @@ -451,15 +562,30 @@ class MklConvFwdPrimitive : public MklPrimitive { // Create convolution primitive and add it to net if (!convFwdDims.bias_dims.empty()) { - context_.bias_mem.reset(new memory( - {{convFwdDims.bias_dims}, MklDnnType(), memory::format_tag::x}, - cpu_engine_, DummyData)); - context_.fwd_primitives_args.push_back( - {{DNNL_ARG_SRC, *context_.src_mem}, - {DNNL_ARG_WEIGHTS, *context_.filter_mem}, - {DNNL_ARG_BIAS, *context_.bias_mem}, - {DNNL_ARG_SCRATCHPAD, *context_.sp_mem}, - {DNNL_ARG_DST, *context_.dst_mem}}); + context_.bias_mem.reset(new memory(context_.fwd_pd.get()->bias_desc(), + cpu_engine_, DummyData)); + if (is_scale_set["src"] && is_scale_set["wei"] && is_scale_set["dst"]) { + context_.fwd_primitives_args.push_back( + {{DNNL_ARG_SRC, *context_.src_mem}, + {DNNL_ARG_WEIGHTS, *context_.filter_mem}, + {DNNL_ARG_BIAS, *context_.bias_mem}, + {DNNL_ARG_SCRATCHPAD, *context_.sp_mem}, + {DNNL_ARG_DST, *context_.dst_mem}, +#ifdef ENABLE_ONEDNN_V3 + {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, *context_.src_scale_mem}, + {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, *context_.wei_scale_mem}, + { DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, + *context_.dst_scale_mem } +#endif // ENABLE_ONEDNN_V3 + }); + } else { + context_.fwd_primitives_args.push_back( + {{DNNL_ARG_SRC, *context_.src_mem}, + {DNNL_ARG_WEIGHTS, *context_.filter_mem}, + {DNNL_ARG_BIAS, *context_.bias_mem}, + {DNNL_ARG_SCRATCHPAD, *context_.sp_mem}, + {DNNL_ARG_DST, *context_.dst_mem}}); + } } else if (!convFwdDims.fuse_bn_dims.empty()) { context_.bn_scale_mem.reset( new memory(*context_.bn_scale_md, cpu_engine_, DummyData)); @@ -484,11 +610,26 @@ class MklConvFwdPrimitive : public MklPrimitive { {DNNL_ARG_ATTR_MULTIPLE_POST_OP(3) | DNNL_ARG_SRC_1, *context_.bn_offset_mem}}); } else { - context_.fwd_primitives_args.push_back( - {{DNNL_ARG_SRC, *context_.src_mem}, - {DNNL_ARG_WEIGHTS, *context_.filter_mem}, - {DNNL_ARG_SCRATCHPAD, *context_.sp_mem}, - {DNNL_ARG_DST, *context_.dst_mem}}); + if (is_scale_set["src"] && is_scale_set["wei"] && is_scale_set["dst"]) { + context_.fwd_primitives_args.push_back( + {{DNNL_ARG_SRC, *context_.src_mem}, + {DNNL_ARG_WEIGHTS, *context_.filter_mem}, + {DNNL_ARG_SCRATCHPAD, *context_.sp_mem}, + {DNNL_ARG_DST, *context_.dst_mem}, +#ifdef ENABLE_ONEDNN_V3 + {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, *context_.src_scale_mem}, + {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, *context_.wei_scale_mem}, + { DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, + *context_.dst_scale_mem } +#endif // ENABLE_ONEDNN_V3 + }); + } else { + context_.fwd_primitives_args.push_back( + {{DNNL_ARG_SRC, *context_.src_mem}, + {DNNL_ARG_WEIGHTS, *context_.filter_mem}, + {DNNL_ARG_SCRATCHPAD, *context_.sp_mem}, + {DNNL_ARG_DST, *context_.dst_mem}}); + } } context_.fwd_primitives.push_back(*context_.conv_fwd); } @@ -576,8 +717,17 @@ class MklConvFwdPrimitiveFactory : public MklPrimitiveFactory { for (auto& param : post_op_param.param) { key_creator.AddAsKey(param); } +#ifndef ENABLE_ONEDNN_V3 } else if (post_op_param.name == "output_scale") { key_creator.AddAsKey(post_op_param.partial_key); +#else + } else if (post_op_param.name == "src_scale") { + key_creator.AddAsKey(post_op_param.partial_key); + } else if (post_op_param.name == "wei_scale") { + key_creator.AddAsKey(post_op_param.partial_key); + } else if (post_op_param.name == "dst_scale") { + key_creator.AddAsKey(post_op_param.partial_key); +#endif // !ENABLE_ONEDNN_V3 } else if (post_op_param.name == "fuse_bn") { key_creator.AddAsKey(post_op_param.name); key_creator.AddAsKey(convFwdDims.fuse_bn_dims); @@ -871,7 +1021,7 @@ class MklConvOp : public OpKernel { MklConvFwdParams convFwdDims( src_dims, filter_dims, fuse_biasadd_ ? bias_dims : NONE_DIMS, dst_dims_mkl_order, strides, dilations, padding_left, padding_right, - fuse_bn_dims, tf_fmt, native_format); + fuse_bn_dims, tf_fmt, native_format, is_depthwise); // TODO(intel-tf): Extend the basic parameters for data types and fusions this->ExtendConvFwdParams(context, convFwdDims); @@ -952,10 +1102,10 @@ class MklConvOp : public OpKernel { fwd_cpu_stream.reset(CreateStream(&eigen_tp, conv_fwd->GetEngine())); if (fuse_biasadd_) { const Tensor& bias_tensor = MklGetInput(context, kInputIndex_Bias); - Tbias* bias_data = + void* bias_data = this->GetBiasHandle(context, conv_fwd_pd, bias_tensor); conv_fwd->Execute(src_data, filter_data, bias_data, dst_data, - fwd_cpu_stream, scratch_pad.Get()); + convFwdDims, fwd_cpu_stream, scratch_pad.Get()); } else if (fuse_bn_) { const Tensor& bn_scale_tensor = MklGetInput(context, kInputIndex_BN_Scale); @@ -980,10 +1130,11 @@ class MklConvOp : public OpKernel { bn_rsqrt_data); conv_fwd->Execute(src_data, filter_data, nullptr, dst_data, bn_scale_data, bn_mean_data, bn_offset_data, - bn_rsqrt_data, fwd_cpu_stream, scratch_pad.Get()); - } else { - conv_fwd->Execute(src_data, filter_data, dst_data, fwd_cpu_stream, + bn_rsqrt_data, convFwdDims, fwd_cpu_stream, scratch_pad.Get()); + } else { + conv_fwd->Execute(src_data, filter_data, dst_data, convFwdDims, + fwd_cpu_stream, scratch_pad.Get()); } // Delete primitive since it is not cached. @@ -1145,9 +1296,9 @@ class MklConvOp : public OpKernel { } } - virtual Tbias* GetBiasHandle(OpKernelContext* context, - std::shared_ptr& conv2d_fwd_pd, - const Tensor& bias_tensor) { + virtual void* GetBiasHandle(OpKernelContext* context, + std::shared_ptr& conv2d_fwd_pd, + const Tensor& bias_tensor) { if (fuse_biasadd_) { return static_cast( const_cast(bias_tensor.flat().data())); @@ -1162,6 +1313,7 @@ class MklConvOp : public OpKernel { MklDnnShape* output_mkl_shape, Tensor** output_tensor) { DCHECK(output_tensor); +#ifndef ENABLE_ONEDNN_V3 auto dst_md = conv_prim_desc.dst_desc(); if (!std::is_same::value) { @@ -1176,6 +1328,14 @@ class MklConvOp : public OpKernel { MklTensorFormatToMklDnnDataFormat(output_tf_format)); #endif // !ENABLE_ONEDNN_V3 } +#else + auto dst_md = + std::is_same::value + ? conv_prim_desc.dst_desc() + : memory::desc(conv_prim_desc.dst_desc().get_dims(), + MklDnnType(), + MklTensorFormatToMklDnnDataFormat(output_tf_format)); +#endif // !ENABLE_ONEDNN_V3 // Allocate shape of MKL tensor output_mkl_shape->SetMklTensor(true); @@ -1815,10 +1975,19 @@ class MklQuantizedConvOp // If Requantize is fused, we set output_scale as first post op since it is // logically applied before any post op. Then we maintain the order of post // ops according to the order of fused_ops. +#ifndef ENABLE_ONEDNN_V3 int idx = fuse_requantize ? 1 : 0; +#else + post_op_to_idx_["src_scale"] = 0; + post_op_to_idx_["wei_scale"] = 1; + post_op_to_idx_["dst_scale"] = 2; + int idx = 3; +#endif // !ENABLE_ONEDNN_V3 for (int i = 0; i < fused_ops_.size(); ++i) { if (fused_ops_[i] == "Requantize") { +#ifndef ENABLE_ONEDNN_V3 post_op_to_idx_["output_scale"] = 0; +#endif // !ENABLE_ONEDNN_V3 } else if (fused_ops_[i] == "Sum") { post_op_to_idx_["sum"] = idx++; } else if (fused_ops_[i] == "Relu") { @@ -1968,24 +2137,30 @@ class MklQuantizedConvOp /*pad_enabled*/ false, is_depthwise, /*native_format*/ true>::ExtendConvFwdParams(context, params); params.post_op_params.resize(post_op_to_idx_.size()); - // When the output type is quint8, the output data is requantized - // into quint8. A post_op "output_scale" is added to do the conversion. + const float min_input = + context->input(min_input_idx_).template scalar()(); + const float max_input = + context->input(max_input_idx_).template scalar()(); + const Tensor& min_filter_vector = context->input(min_filter_idx_); + const Tensor& max_filter_vector = context->input(max_filter_idx_); + OP_REQUIRES( + context, + ((min_filter_vector.NumElements() > 0) && + (max_filter_vector.NumElements() > 0) && + (min_filter_vector.shape() == max_filter_vector.shape())), + errors::InvalidArgument("`min_ and max_filter` must have same" + "shape and contain at least one element.")); + float int_input_limit = + std::is_same::value ? 255.0f : 127.0f; + size_t depth = min_filter_vector.NumElements(); + const float* min_filter = min_filter_vector.flat().data(); + const float* max_filter = max_filter_vector.flat().data(); + std::vector SCALE(depth); + float float_input_range = + std::max(std::abs(min_input), std::abs(max_input)); + const float src_scale = float_input_range / int_input_limit; if (std::is_same::value || std::is_same::value) { - const float min_input = - context->input(min_input_idx_).template scalar()(); - const float max_input = - context->input(max_input_idx_).template scalar()(); - const Tensor& min_filter_vector = context->input(min_filter_idx_); - const Tensor& max_filter_vector = context->input(max_filter_idx_); - OP_REQUIRES( - context, - ((min_filter_vector.NumElements() > 0) && - (max_filter_vector.NumElements() > 0) && - (min_filter_vector.shape() == max_filter_vector.shape())), - errors::InvalidArgument("`min_ and max_filter` must have same" - "shape and contain at least one element.")); - // min_freezed_output and max_freezed_output are the actual range // for the output. const float min_freezed_output = @@ -1995,12 +2170,6 @@ class MklQuantizedConvOp float int_output_limit = std::is_same::value ? 255.0f : 127.0f; - size_t depth = min_filter_vector.NumElements(); - const float* min_filter = min_filter_vector.flat().data(); - const float* max_filter = max_filter_vector.flat().data(); - std::vector scales(depth); - float float_input_range = - std::max(std::abs(min_input), std::abs(max_input)); float float_output_range = std::max(std::abs(min_freezed_output), std::abs(max_freezed_output)); const float int_const_scale_limit = @@ -2011,13 +2180,18 @@ class MklQuantizedConvOp float float_filter_range = std::max(std::abs(min_filter[i]), std::abs(max_filter[i])); // To understand the scaling, please see mkl_requantize_ops_test. +#ifndef ENABLE_ONEDNN_V3 scales[i] = int_output_limit * float_input_range * float_filter_range / (int_const_scale_limit * float_output_range); +#else + wei_scale[i] = float_filter_range / 127.0; +#endif // !ENABLE_ONEDNN_V3 } // we are creating a partial key here to use with primitive key caching to // improve key creation performance. Instead of using actual values we are // using the pointers for min/max_filter_vector, and this works since the // filter vector here is a constant. +#ifndef ENABLE_ONEDNN_V3 FactoryKeyCreator param_key; param_key.AddAsKey(min_input); param_key.AddAsKey(max_input); @@ -2027,12 +2201,63 @@ class MklQuantizedConvOp param_key.AddAsKey(max_filter); params.post_op_params[post_op_to_idx_["output_scale"]] = { "output_scale", dnnl::algorithm::undef, scales, param_key.GetKey()}; - } - +#else + const float dst_scale = float_output_range / int_output_limit; + FactoryKeyCreator dst_param_key; + dst_param_key.AddAsKey(min_freezed_output); + dst_param_key.AddAsKey(max_freezed_output); + params.post_op_params[post_op_to_idx_["dst_scale"]] = { + "dst_scale", + dnnl::algorithm::undef, + {dst_scale}, + dst_param_key.GetKey()}; +#endif // !ENABLE_ONEDNN_V3 + } else { +#ifdef ENABLE_ONEDNN_V3 + if (!std::is_same::value) + TF_CHECK_OK(Status(absl::StatusCode::kFailedPrecondition, + "Output datatype is expected to be qint32.")); + float min_min_filter = min_filter[0]; + float max_max_filter = max_filter[0]; + for (size_t i = 0; i < depth; ++i) { + float float_filter_range = + std::max(std::abs(min_filter[i]), std::abs(max_filter[i])); + wei_scale[i] = float_filter_range / 127.0; + if (min_filter[i] < min_min_filter) min_min_filter = min_filter[i]; + if (max_filter[i] > max_max_filter) max_max_filter = max_filter[i]; + } + const float single_wei_scale = + std::max(std::abs(min_min_filter), std::abs(max_max_filter)) / 127.0; + const float dst_scale = single_wei_scale * src_scale; + FactoryKeyCreator dst_param_key; + dst_param_key.AddAsKey(dst_scale); + params.post_op_params[post_op_to_idx_["dst_scale"]] = { + "dst_scale", + dnnl::algorithm::undef, + {dst_scale}, + dst_param_key.GetKey()}; +#endif // ENABLE_ONEDNN_V3 + } + +#ifdef ENABLE_ONEDNN_V3 + FactoryKeyCreator src_param_key; + src_param_key.AddAsKey(min_input); + src_param_key.AddAsKey(max_input); + FactoryKeyCreator wei_param_key; + wei_param_key.AddAsKey(min_filter); + wei_param_key.AddAsKey(max_filter); + params.post_op_params[post_op_to_idx_["src_scale"]] = { + "src_scale", + dnnl::algorithm::undef, + {src_scale}, + src_param_key.GetKey()}; + params.post_op_params[post_op_to_idx_["wei_scale"]] = { + "wei_scale", dnnl::algorithm::undef, wei_scale, wei_param_key.GetKey()}; +#endif // ENABLE_ONEDNN_V3 if (this->get_fuse_add()) { // Calculate the scale (beta in oneDNN api term) for sum + DataType summand_dt = this->input_type(this->get_input_add_idx()); if (std::is_same::value) { - DataType summand_dt = this->input_type(this->get_input_add_idx()); bool summand_condition = (summand_dt == DT_QINT8) || (summand_dt == DT_QUINT8); DCHECK((summand_condition)); @@ -2089,18 +2314,24 @@ class MklQuantizedConvOp params.post_op_params[post_op_to_idx_["sum"]] = { "sum", dnnl::algorithm::undef, - {summand_range / output_range}, + {SUMMAND_SCALE_U8(summand_range, output_range)}, ""}; } else { params.post_op_params[post_op_to_idx_["sum"]] = { "sum", dnnl::algorithm::undef, - {255.0f * summand_range / (output_range * 127.0f)}, + {SUMMAND_SCALE_S8(summand_range, output_range)}, ""}; } } else { - params.post_op_params[post_op_to_idx_["sum"]] = { - "sum", dnnl::algorithm::undef, {1.0}, ""}; + params.post_op_params[post_op_to_idx_["sum"]] = {"sum", + dnnl::algorithm::undef, + {1.0}, + "", +#ifdef ENABLE_ONEDNN_V3 + summand_dt +#endif // ENABLE_ONEDNN_V3 + }; } } @@ -2147,6 +2378,7 @@ class MklQuantizedConvOp "Summand cannot be forwarded in the current fusion.")); return; } +#ifndef ENABLE_ONEDNN_V3 MklConvOp< Device, Tinput, /*Tfilter*/ qint8, Tbias, Toutput, Ttemp_output, /*Tpadding*/ int32, @@ -2209,15 +2441,34 @@ class MklQuantizedConvOp conv_prim_desc.dst_desc(), reorder_attr); CreateAndExecuteReorder(reorder_desc, *summand_, *dst_, this->cpu_engine_, context); +#else + // In oneDNN v3.0 summand does not need to be scaled. + int summand_idx = this->get_input_add_idx(); + DataType summand_dt = this->input_type(summand_idx); + if (summand_dt != DT_FLOAT) + TF_CHECK_OK(Status(absl::StatusCode::kFailedPrecondition, + "Summand datatype is expected to be float.")); + Tensor& summand_float = const_cast(context->input(summand_idx)); + OP_REQUIRES_OK(context, + summand_float.BitcastFrom(summand_float, DT_QINT32, + summand_float.shape())); + OP_REQUIRES(context, + context->forward_input_to_output_with_shape( + summand_idx, 0, summand_float.shape(), output_tensor), + errors::InvalidArgument( + "Summand cannot be forwarded in the current fusion.")); + +#endif // !ENABLE_ONEDNN_V3 } } - Tbias* GetBiasHandle(OpKernelContext* context, - std::shared_ptr& conv_fwd_pd, - const Tensor& bias_tensor) override { + void* GetBiasHandle(OpKernelContext* context, + std::shared_ptr& conv_fwd_pd, + const Tensor& bias_tensor) override { if (!this->get_fuse_biasadd()) { return nullptr; } +#ifndef ENABLE_ONEDNN_V3 if (std::is_same::value) { return static_cast( const_cast(bias_tensor.flat().data())); @@ -2236,7 +2487,7 @@ class MklQuantizedConvOp (std::is_same::value) ? 255.0 * 127.0 : 127.0 * 127.0; // Re-scale bias if either of following 2 conditions are met: // 1. Bias is not const; - // 2. Bias is const, but bias cache is empty (first iteration). + // 2. Bias is const, bias has not been cached (first iteration). size_t depth = min_filter_vector.NumElements(); bool scales_are_valid = (depth == scales_.size()); @@ -2300,6 +2551,96 @@ class MklQuantizedConvOp return bias_data; } return GetCachedBias(context); +#else + if (std::is_same::value) { + return static_cast( + const_cast(bias_tensor.flat().data())); + } + // Starting oneDNN v3.0, bias needs to be passed as is (in float datatype). + // However, for backward compatibility we need to handle the case where bias + // is qint32. Since oneDNN v3.0 does not support qint32 bias, we need to + // dequantize to float. + const float min_input = + context->input(min_input_idx_).template scalar()(); + const float max_input = + context->input(max_input_idx_).template scalar()(); + const Tensor& min_filter_vector = context->input(min_filter_idx_); + const Tensor& max_filter_vector = context->input(max_filter_idx_); + if ((min_filter_vector.NumElements() == 0) || + (max_filter_vector.NumElements() == 0) || + (min_filter_vector.shape() != max_filter_vector.shape())) { + TF_CHECK_OK(Status(absl::StatusCode::kFailedPrecondition, + "`min_filter and max_filter` must have same" + "shape and contain at least one element.")); + } + const float* min_filter = min_filter_vector.flat().data(); + const float* max_filter = max_filter_vector.flat().data(); + const float int_const_scale_limit = + (std::is_same::value) ? 255.0 * 127.0 : 127.0 * 127.0; + // Re-scale bias if either of following 2 conditions are met: + // 1. Bias is not const; + // 2. Bias is const, but bias cache is empty (first iteration). + + size_t depth = min_filter_vector.NumElements(); + bool scales_are_valid = (depth == scales_.size()); + scales_.resize(depth); + for (size_t i = 0; i < depth; ++i) { + float tmp_scale = + int_const_scale_limit / + (std::max(std::abs(max_input), std::abs(min_input)) * + std::max(std::abs(max_filter[i]), std::abs(min_filter[i]))); + if (scales_are_valid && std::abs(tmp_scale - scales_[i]) > 1e-6) { + scales_are_valid = false; + } + scales_[i] = tmp_scale; + } + if (!is_bias_const_ || IsBiasCacheEmpty(context) || !scales_are_valid) { + dnnl::primitive_attr reorder_attr; + + if (depth == 1) { + reorder_attr.set_scales_mask(DNNL_ARG_DST, 0); + } else { + reorder_attr.set_scales_mask(DNNL_ARG_DST, 1); + } + + auto bias_md = memory::desc({static_cast(bias_tensor.NumElements())}, + MklDnnType(), memory::format_tag::x); + void* bias_buf = static_cast( + const_cast(bias_tensor.flat().data())); + if (!input_bias_) { + input_bias_ = new memory(bias_md, this->cpu_engine_, bias_buf); + } else { + input_bias_->set_data_handle(bias_buf); + } + + if (!scaled_bias_buf_) + AllocTmpBuffer(context, &scaled_bias_tensor_, + conv_fwd_pd->bias_desc(), &scaled_bias_buf_); + if (!scaled_bias_) { + scaled_bias_ = new memory(conv_fwd_pd->bias_desc(), this->cpu_engine_, + scaled_bias_buf_); + } else { + scaled_bias_->set_data_handle(scaled_bias_buf_); + } + std::unique_ptr scale_mem( + new memory({{depth}, MklDnnType(), memory::format_tag::x}, + this->cpu_engine_, scales_.data())); + auto reorder_desc = + ReorderPd(this->cpu_engine_, input_bias_->get_desc(), + this->cpu_engine_, scaled_bias_->get_desc(), reorder_attr); + CreateAndExecuteReorder(reorder_desc, *input_bias_, *scaled_bias_, + this->cpu_engine_, context, scale_mem.get()); + + float* bias_data = + reinterpret_cast(scaled_bias_->get_data_handle()); + if (is_bias_const_) + CacheBias(context, conv_fwd_pd, bias_data, scaled_bias_); + + return bias_data; + } + return GetCachedBias(context); + +#endif // !ENABLE_ONEDNN_V3 } bool is_bias_const_; @@ -2354,9 +2695,9 @@ class MklQuantizedConvOp DCHECK(bias_tensor); TensorShape bias_tf_shape; bias_tf_shape.AddDim( - (conv_prim_desc.bias_desc().get_size() / sizeof(Tbias))); + (conv_prim_desc.bias_desc().get_size() / sizeof(TSCALED_BIAS))); OP_REQUIRES_OK(context, - context->allocate_temp(DataTypeToEnum::value, + context->allocate_temp(DataTypeToEnum::value, bias_tf_shape, &cached_bias_data_)); *bias_tensor = &cached_bias_data_; } @@ -2374,7 +2715,7 @@ class MklQuantizedConvOp // Only one thread can execute this method at any given time. void CacheBias(OpKernelContext* context, const std::shared_ptr& conv_fwd_pd, - Tbias* bias_data, const memory* scaled_bias) + TSCALED_BIAS* bias_data, const memory* scaled_bias) TF_LOCKS_EXCLUDED(bias_cache_mu_) { mutex_lock lock(bias_cache_mu_); @@ -2387,18 +2728,18 @@ class MklQuantizedConvOp Tensor* bias_tensor_ptr = nullptr; AllocateTensor(context, *conv_fwd_pd, &bias_tensor_ptr); void* cached_bias_data = const_cast( - static_cast(bias_tensor_ptr->flat().data())); + static_cast(bias_tensor_ptr->flat().data())); size_t cached_bias_data_size = scaled_bias->get_desc().get_size(); memcpy(cached_bias_data, bias_data, cached_bias_data_size); } - Tbias* GetCachedBias(OpKernelContext* context) + TSCALED_BIAS* GetCachedBias(OpKernelContext* context) TF_LOCKS_EXCLUDED(bias_cache_mu_) { tf_shared_lock lock(bias_cache_mu_); const Tensor& cached_bias_data = cached_bias_data_; - return static_cast( - const_cast(cached_bias_data.flat().data())); + return static_cast(const_cast( + cached_bias_data.flat().data())); } }; @@ -2904,6 +3245,10 @@ REGISTER_KERNEL_BUILDER( #undef GET_DATA_TYPE #undef SET_FUSE_ACTIVATION_FOR_RELU6 #undef SET_MKL_LAYOUT +#undef TSCALED_BIAS +#undef SCALE +#undef SUMMAND_SCALE_U8 +#undef SUMMAND_SCALE_S8 } // namespace tensorflow #endif // INTEL_MKL diff --git a/tensorflow/core/kernels/mkl/mkl_quantized_conv_ops_test.cc b/tensorflow/core/kernels/mkl/mkl_quantized_conv_ops_test.cc index 00cc02bfcad397..4dc4634775b075 100644 --- a/tensorflow/core/kernels/mkl/mkl_quantized_conv_ops_test.cc +++ b/tensorflow/core/kernels/mkl/mkl_quantized_conv_ops_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#if defined(INTEL_MKL) && !defined(ENABLE_ONEDNN_V3) && defined(ENABLE_MKL) +#if defined(INTEL_MKL) && defined(ENABLE_MKL) #define EIGEN_USE_THREADS #include @@ -1062,4 +1062,4 @@ TEST_F(QuantizedConvTest, BiasAddSumReluFusionFloatSummand) { } } // namespace tensorflow -#endif // INTEL_MKL && !ENABLE_ONEDNN_V3 && ENABLE_MKL +#endif // INTEL_MKL && ENABLE_MKL diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h index 322991376f9924..84f71fa2761388 100644 --- a/tensorflow/core/util/mkl_util.h +++ b/tensorflow/core/util/mkl_util.h @@ -1322,11 +1322,22 @@ inline Status CreateBlockedMemDescHelper(const memory::dims& dim, inline void CreateAndExecuteReorder(const ReorderPd& reorder_desc, const memory& src_mem, const memory& dst_mem, const engine& engine, - OpKernelContext* ctx = nullptr) { + OpKernelContext* ctx = nullptr, + memory* scale_mem = nullptr) { std::vector net; net.push_back(dnnl::reorder(reorder_desc)); std::vector net_args; +#ifndef ENABLE_ONEDNN_V3 net_args.push_back({{DNNL_ARG_FROM, src_mem}, {DNNL_ARG_TO, dst_mem}}); +#else + if (scale_mem != nullptr) { + net_args.push_back({{DNNL_ARG_FROM, src_mem}, + {DNNL_ARG_TO, dst_mem}, + {DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, *scale_mem}}); + } else { + net_args.push_back({{DNNL_ARG_FROM, src_mem}, {DNNL_ARG_TO, dst_mem}}); + } +#endif // !ENABLE_ONEDNN_V3 ExecutePrimitive(net, &net_args, engine, ctx); } From e73158fb1df3c58207f4df57e739c1348529023c Mon Sep 17 00:00:00 2001 From: Sulav Date: Fri, 21 Apr 2023 11:27:05 -0400 Subject: [PATCH 005/376] Added missing const and ensure proper resource management with unique_ptr --- tensorflow/cc/experimental/libexport/load.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/cc/experimental/libexport/load.h b/tensorflow/cc/experimental/libexport/load.h index cd85fb5f2b7efc..c8c5a4ef027c24 100644 --- a/tensorflow/cc/experimental/libexport/load.h +++ b/tensorflow/cc/experimental/libexport/load.h @@ -83,8 +83,8 @@ class TFPackage { // Returns a BundleReader for reading variable values. // // This TFPackage retains ownership of the underlying reader. - tensorflow::BundleReader* GetVariableReader() { - return variable_reader_.get(); + const std::unique_ptr &GetVariableReader() { + return variable_reader_; } // Returns whether or not we found a valid checkpoint when loading the @@ -92,7 +92,7 @@ class TFPackage { bool HasCheckpoint() { return has_checkpoint_; } // Returns the path to the variables file. - const std::string GetVariablesFilepath() { return variables_filepath_; } + const std::string GetVariablesFilepath() const { return variables_filepath_; } private: SavedModel saved_model_proto_; From 9c49d536d842bac8ec3cad100b5354b8442c1481 Mon Sep 17 00:00:00 2001 From: Lu Teng Date: Tue, 25 Apr 2023 09:30:17 +0800 Subject: [PATCH 006/376] Reduce the log times when plugin is enabled. --- .../optimizers/custom_graph_optimizer_registry.cc | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc b/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc index 6de0c5d29d6d34..c4f4fdc1f3ced5 100644 --- a/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc +++ b/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include "absl/base/call_once.h" #include "tensorflow/core/platform/logging.h" namespace tensorflow { @@ -111,8 +112,11 @@ PluginGraphOptimizerRegistry::CreateOptimizers( for (auto it = GetPluginRegistrationMap()->begin(); it != GetPluginRegistrationMap()->end(); ++it) { if (device_types.find(it->first) == device_types.end()) continue; - LOG(INFO) << "Plugin optimizer for device_type " << it->first - << " is enabled."; + static absl::once_flag plugin_optimizer_flag; + absl::call_once(plugin_optimizer_flag, [&]() { + LOG(INFO) << "Plugin optimizer for device_type " << it->first + << " is enabled."; + }); optimizer_list.emplace_back( std::unique_ptr(it->second())); } From 8cd258cb4706509eea3fb07605a27905f6294cce Mon Sep 17 00:00:00 2001 From: Sulav Date: Tue, 25 Apr 2023 01:31:24 -0400 Subject: [PATCH 007/376] Revert change in GetVariableReader function as per review feedback --- tensorflow/cc/experimental/libexport/load.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/cc/experimental/libexport/load.h b/tensorflow/cc/experimental/libexport/load.h index c8c5a4ef027c24..dfa5ee77bad333 100644 --- a/tensorflow/cc/experimental/libexport/load.h +++ b/tensorflow/cc/experimental/libexport/load.h @@ -83,9 +83,9 @@ class TFPackage { // Returns a BundleReader for reading variable values. // // This TFPackage retains ownership of the underlying reader. - const std::unique_ptr &GetVariableReader() { - return variable_reader_; - } +tensorflow::BundleReader* GetVariableReader() { + return variable_reader_.get(); +} // Returns whether or not we found a valid checkpoint when loading the // package. From a46a629cf9810d9a44f7bec0b4feefe6e9e5e4ac Mon Sep 17 00:00:00 2001 From: Mahmoud Abuzaina Date: Wed, 3 May 2023 17:30:12 -0700 Subject: [PATCH 008/376] Cleaned up some #ifdefs --- tensorflow/core/common_runtime/mkl_layout_pass.cc | 6 +----- tensorflow/core/kernels/mkl/mkl_fused_ops_test.cc | 2 -- 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/tensorflow/core/common_runtime/mkl_layout_pass.cc b/tensorflow/core/common_runtime/mkl_layout_pass.cc index 9b93f2ef651695..18bf9d23298d1e 100644 --- a/tensorflow/core/common_runtime/mkl_layout_pass.cc +++ b/tensorflow/core/common_runtime/mkl_layout_pass.cc @@ -572,7 +572,6 @@ class MklLayoutRewritePass : public GraphOptimizationPass { rinfo_.push_back({csinfo_.quantized_concatv2, mkl_op_registry::GetMklOpName(csinfo_.quantized_concatv2), CopyAttrsAll, ConcatV2Rewrite, kRewriteForOpNameChange}); -#endif // !ENABLE_ONEDNN_V3 rinfo_.push_back({csinfo_.quantized_conv2d, mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d), CopyAttrsQuantizedConv2D, AlwaysRewrite, @@ -614,11 +613,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass { mkl_op_registry::GetMklOpName( csinfo_.quantized_conv2d_with_bias_and_relu_and_requantize), CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange}); -#ifndef ENABLE_ONEDNN_V3 rinfo_.push_back({csinfo_.quantized_max_pool, mkl_op_registry::GetMklOpName(csinfo_.quantized_max_pool), CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange}); -#endif rinfo_.push_back({csinfo_.quantized_conv2d_with_bias_sum_and_relu, mkl_op_registry::GetMklOpName( csinfo_.quantized_conv2d_with_bias_sum_and_relu), @@ -634,7 +631,6 @@ class MklLayoutRewritePass : public GraphOptimizationPass { mkl_op_registry::GetMklOpName( csinfo_.quant_conv2d_with_bias_signed_sum_and_relu_and_requantize), CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange}); -#ifndef ENABLE_ONEDNN_V3 rinfo_.push_back( {csinfo_.quantized_matmul_with_bias, mkl_op_registry::GetMklOpName(csinfo_.quantized_matmul_with_bias), @@ -661,7 +657,6 @@ class MklLayoutRewritePass : public GraphOptimizationPass { csinfo_.quantized_matmul_with_bias_and_dequantize), CopyAttrsQuantizedMatMulWithBiasAndDequantize, AlwaysRewrite, kRewriteForOpNameChange}); -#endif // !ENABLE_ONEDNN_V3 rinfo_.push_back( {csinfo_.quantized_depthwise_conv2d, mkl_op_registry::GetMklOpName(csinfo_.quantized_depthwise_conv2d), @@ -682,6 +677,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { csinfo_ .quantized_depthwise_conv2d_with_bias_and_relu_and_requantize), CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange}); +#endif // !ENABLE_ONEDNN_V3 rinfo_.push_back({csinfo_.quantize_v2, mkl_op_registry::GetMklOpName(csinfo_.quantize_v2), CopyAttrsAll, QuantizeOpRewrite, diff --git a/tensorflow/core/kernels/mkl/mkl_fused_ops_test.cc b/tensorflow/core/kernels/mkl/mkl_fused_ops_test.cc index 118b3273ea7702..8e9038b056f755 100644 --- a/tensorflow/core/kernels/mkl/mkl_fused_ops_test.cc +++ b/tensorflow/core/kernels/mkl/mkl_fused_ops_test.cc @@ -1258,7 +1258,6 @@ class BiasCacheTest : public OpsTestBase { } }; -#ifndef ENABLE_ONEDNN_V3 TEST_F(BiasCacheTest, Conv2DBiasCacheTestOldAPI) { TestConv2DBiasCacheTest(true); } @@ -1266,7 +1265,6 @@ TEST_F(BiasCacheTest, Conv2DBiasCacheTestOldAPI) { TEST_F(BiasCacheTest, Conv2DBiasCacheTestNewAPI) { TestConv2DBiasCacheTest(false); } -#endif // !ENABLE_ONEDNN_V3 // Testing fusion of pad and fusedconv2d template From 47eaa828a1dd4bf50ec4203ef4bbb348b3ef0dd0 Mon Sep 17 00:00:00 2001 From: dingyuqing05 Date: Thu, 4 May 2023 09:01:40 +0000 Subject: [PATCH 009/376] Add nullptr check --- tensorflow/c/kernels.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tensorflow/c/kernels.cc b/tensorflow/c/kernels.cc index 59f978000e66f9..47ce55470b80ff 100644 --- a/tensorflow/c/kernels.cc +++ b/tensorflow/c/kernels.cc @@ -357,6 +357,10 @@ void TF_GetInput(TF_OpKernelContext* ctx, int i, TF_Tensor** tensor, return; } const ::tensorflow::Tensor& cc_tensor(cc_ctx->input(i)); + if ((&cc_tensor) == nullptr) { + *tensor = nullptr; + return; + } TF_Tensor* result = ::tensorflow::TF_TensorFromTensor(cc_tensor, &status->status); if (TF_GetCode(status) == TF_OK) { From cb91b15519aeb4d873d3ad2acbc5b23fbcb17dcf Mon Sep 17 00:00:00 2001 From: Yaohui Liu Date: Sun, 16 Apr 2023 17:16:31 +0800 Subject: [PATCH 010/376] Add more c apis which are correponding to python c apis. --- tensorflow/c/c_api.cc | 142 ++++++++++++++++++ tensorflow/c/c_api.h | 43 ++++++ tensorflow/core/framework/BUILD | 20 ++- .../core/framework/cpp_shape_inference.proto | 36 +++++ 4 files changed, 238 insertions(+), 3 deletions(-) create mode 100644 tensorflow/core/framework/cpp_shape_inference.proto diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 11e0bcaeaae12b..599b676bbd69b7 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -55,10 +55,14 @@ limitations under the License. #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.pb.h" // NOLINT +#include "tensorflow/core/framework/cpp_shape_inference.pb.h" +#include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/versions.pb.h" +#include "tensorflow/core/framework/full_type.pb.h" +#include "tensorflow/core/framework/attr_value_util.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/graph/validate.h" @@ -2570,6 +2574,144 @@ void TF_UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst, } } +// Apis that are correponding to python c api. -------------------------- + +void TF_AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input) { + mutex_lock l(graph->mu); + graph->graph.AddControlEdge(&input->node, &op->node); + tensorflow::RecordMutation(graph, *op, "adding control input"); +} + +void TF_SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name, + TF_Buffer* attr_value_proto, TF_Status* status) { + using tensorflow::RecordMutation; + tensorflow::AttrValue attr_val; + if (!attr_val.ParseFromArray(attr_value_proto->data, + attr_value_proto->length)) { + status->status = + tensorflow::errors::InvalidArgument("Invalid AttrValue proto"); + return; + } + + mutex_lock l(graph->mu); + op->node.AddAttr(attr_name, attr_val); + tensorflow::RecordMutation(graph, *op, "setting attribute"); +} + +void TF_ClearAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name, + TF_Status* status) { + mutex_lock l(graph->mu); + op->node.ClearAttr(attr_name); + tensorflow::RecordMutation(graph, *op, "clearing attribute"); +} + +void TF_SetFullType(TF_Graph* graph, TF_Operation* op, + const tensorflow::FullTypeDef& full_type) { + mutex_lock l(graph->mu); + *op->node.mutable_def()->mutable_experimental_type() = full_type; + tensorflow::RecordMutation(graph, *op, "setting fulltype"); +} + +void TF_SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) { + mutex_lock l(graph->mu); + op->node.set_requested_device(device); + tensorflow::RecordMutation(graph, *op, "setting device"); +} + +void TF_UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst, + TF_Status* status) { + TF_UpdateEdge(graph, new_src, dst, status); +} + +void TF_RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op) { + mutex_lock l(graph->mu); + std::vector control_edges; + for (const tensorflow::Edge* edge : op->node.in_edges()) { + if (!edge->IsControlEdge()) continue; + control_edges.push_back(edge); + } + for (const tensorflow::Edge* edge : control_edges) { + graph->graph.RemoveControlEdge(edge); + } +} + +void TF_SetRequireShapeInferenceFns(TF_Graph* graph, bool require) { + mutex_lock l(graph->mu); + graph->refiner.set_require_shape_inference_fns(require); +} + +void TF_ExtendSession(TF_Session* session, TF_Status* status) { + ExtendSessionGraphHelper(session, status); + session->extend_before_run = false; +} + +const char* TF_GetHandleShapeAndType(TF_Graph* graph, TF_Output output) { + Node* node = &output.oper->node; + tensorflow::CppShapeInferenceResult::HandleData handle_data; + handle_data.set_is_set(true); + { + mutex_lock l(graph->mu); + tensorflow::shape_inference::InferenceContext* ic = + graph->refiner.GetContext(node); + CHECK(ic != nullptr); + CHECK_LT(output.index, ic->num_outputs()); + const auto* shapes_and_types = + ic->output_handle_shapes_and_types(output.index); + if (shapes_and_types == nullptr) return ""; + + for (const auto& p : *shapes_and_types) { + auto* out_shape_and_type = handle_data.add_shape_and_type(); + ic->ShapeHandleToProto(p.shape, out_shape_and_type->mutable_shape()); + out_shape_and_type->set_dtype(p.dtype); + *out_shape_and_type->mutable_type() = p.type; + } + } + string result; + handle_data.SerializeToString(&result); + return result.c_str(); +} + +void TF_SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto, + size_t proto_len, TF_Status* status) { + tensorflow::CppShapeInferenceResult::HandleData handle_data; + if (!handle_data.ParseFromArray(proto, proto_len)) { + status->status = tensorflow::errors::InvalidArgument( + "Couldn't deserialize HandleData proto"); + return; + } + DCHECK(handle_data.is_set()); + + tensorflow::mutex_lock l(graph->mu); + tensorflow::shape_inference::InferenceContext* ic = + graph->refiner.GetContext(&output.oper->node); + + std::vector shapes_and_types; + for (const auto& shape_and_type_proto : handle_data.shape_and_type()) { + tensorflow::shape_inference::ShapeHandle shape; + status->status = + ic->MakeShapeFromShapeProto(shape_and_type_proto.shape(), &shape); + if (TF_GetCode(status) != TF_OK) return; + shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype(), + shape_and_type_proto.type()); + } + ic->set_output_handle_shapes_and_types(output.index, shapes_and_types); +} + +void TF_AddWhileInputHack(TF_Graph* graph, TF_Output new_src, TF_Operation* dst, + TF_Status* status) { + mutex_lock l(graph->mu); + status->status = graph->graph.AddWhileInputHack(&new_src.oper->node, + new_src.index, &dst->node); + if (TF_GetCode(status) == TF_OK) { + // This modification only updates the destination node for + // the purposes of running this graph in a session. Thus, we don't + // record the source node as being modified. + tensorflow::RecordMutation(graph, *dst, "adding input tensor"); + } +} + +// ------------------------------------------------------------------- + // TF_Server functions ---------------------------------------------- #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index e4c6499506ec76..5800a3fbc4b571 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_tensor.h" #include "tensorflow/c/tf_tstring.h" +#include "tensorflow/core/framework/full_type.pb.h" // -------------------------------------------------------------------------- // C API for TensorFlow. @@ -1577,6 +1578,48 @@ TF_CAPI_EXPORT extern void TF_RegisterLogListener( TF_CAPI_EXPORT extern void TF_RegisterFilesystemPlugin( const char* plugin_filename, TF_Status* status); +// Apis that are correponding to python c api. -------------------- + +TF_CAPI_EXPORT extern void TF_AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input); + +TF_CAPI_EXPORT extern void TF_SetAttr(TF_Graph* graph, TF_Operation* op, + const char* attr_name, + TF_Buffer* attr_value_proto, + TF_Status* status); + +TF_CAPI_EXPORT extern void TF_ClearAttr(TF_Graph* graph, TF_Operation* op, + const char* attr_name, + TF_Status* status); + +TF_CAPI_EXPORT extern void TF_SetFullType(TF_Graph* graph, TF_Operation* op, + const tensorflow::FullTypeDef& full_type); + +TF_CAPI_EXPORT extern void TF_SetRequestedDevice(TF_Graph* graph, + TF_Operation* op, + const char* device); + +TF_CAPI_EXPORT extern void TF_UpdateEdge(TF_Graph* graph, TF_Output new_src, + TF_Input dst, TF_Status* status); + +TF_CAPI_EXPORT extern void TF_RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op); + +TF_CAPI_EXPORT extern void TF_SetRequireShapeInferenceFns(TF_Graph* graph, bool require); + +TF_CAPI_EXPORT extern void TF_ExtendSession(TF_Session* session, TF_Status* status); + +TF_CAPI_EXPORT extern const char* TF_GetHandleShapeAndType(TF_Graph* graph, TF_Output output); + +TF_CAPI_EXPORT extern void TF_SetHandleShapeAndType(TF_Graph* graph, + TF_Output output, + const void* proto, + size_t proto_len, + TF_Status* status); + +void TFC_AddWhileInputHack(TF_Graph* graph, TF_Output new_src, TF_Operation* dst, + TF_Status* status); + +// ---------------------------------------------------------------- + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/core/framework/BUILD b/tensorflow/core/framework/BUILD index a4d7400ad8d392..90a78a8edbdfdb 100644 --- a/tensorflow/core/framework/BUILD +++ b/tensorflow/core/framework/BUILD @@ -30,7 +30,7 @@ package( default_visibility = [ "//tensorflow/core:__subpackages__", "//tensorflow/security/fuzzing:__subpackages__", - # TODO(pedaveeraiah): to be removed when summary.proto.h deps moves to TSL +# TODO(pedaveeraiah): to be removed when summary.proto.h deps moves to TSL "//tensorflow/tsl/lib:__subpackages__", # copybara:uncomment "//learning/brain/tfrt/aot:__subpackages__", # copybara:uncomment "//platforms/xla/megascale/tensorflow:__subpackages__", @@ -114,6 +114,7 @@ exports_files( srcs = [ "allocation_description.proto", "api_def.proto", + "cpp_shape_inference.proto", "attr_value.proto", "cost_graph.proto", "dataset_metadata.proto", @@ -1389,7 +1390,7 @@ cc_library( # protos from the same package, so we can build the protos here and then # link them from core:protos_all without circular dependencies. -# Generate the C++ sources for some of the protos. +#Generate the C++ sources for some of the protos. tf_generate_proto_text_sources( name = "attr_value_proto_text", srcs = [ @@ -1690,6 +1691,18 @@ tf_proto_library( ], ) +tf_proto_library( + name = "cpp_shape_inference_proto", + srcs = ["cpp_shape_inference.proto"], + cc_api_version = 2, + make_default_target_header_only = True, + protodeps = [ + ":full_type_proto", + ":tensor_shape_proto", + ":types_proto", + ], +) + tf_proto_library( name = "variable_proto", srcs = ["variable.proto"], @@ -1757,7 +1770,7 @@ tf_proto_library( # ":function_proto", # ], # ) -# copybara:uncomment_end +#copybara : uncomment_end tf_proto_library( name = "summary_proto", @@ -1828,6 +1841,7 @@ tf_proto_library( protodeps = [ ":allocation_description_proto", ":api_def_proto", + ":cpp_shape_inference_proto", ":attr_value_proto", ":cost_graph_proto", ":dataset_proto", diff --git a/tensorflow/core/framework/cpp_shape_inference.proto b/tensorflow/core/framework/cpp_shape_inference.proto new file mode 100644 index 00000000000000..d2fd1f29f23b87 --- /dev/null +++ b/tensorflow/core/framework/cpp_shape_inference.proto @@ -0,0 +1,36 @@ +syntax = "proto3"; + +package tensorflow; + +import "tensorflow/core/framework/full_type.proto"; +import "tensorflow/core/framework/tensor_shape.proto"; +import "tensorflow/core/framework/types.proto"; + +option cc_enable_arenas = true; +option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/python/framework/cpp_shape_inference_go_proto"; + +message CppShapeInferenceResult { + message HandleShapeAndType { + reserved 3; + + TensorShapeProto shape = 1; + DataType dtype = 2; + FullTypeDef type = 4; + } + message HandleData { + bool is_set = 1; + + // Only valid if . + repeated HandleShapeAndType shape_and_type = 2; + } + TensorShapeProto shape = 1; + + reserved 2; // was handle_shape + reserved 3; // was handle_dtype + HandleData handle_data = 4; +} + +message CppShapeInferenceInputsNeeded { + repeated int32 input_tensors_needed = 1; + repeated int32 input_tensors_as_shapes_needed = 2; +} From 8f10182866e4365fcf3f88e2aaf9f3d4200ae199 Mon Sep 17 00:00:00 2001 From: Yaohui Liu Date: Mon, 8 May 2023 15:45:12 +0800 Subject: [PATCH 011/376] Revise some typos. --- tensorflow/core/framework/BUILD | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/core/framework/BUILD b/tensorflow/core/framework/BUILD index 90a78a8edbdfdb..1f0b52fd65f072 100644 --- a/tensorflow/core/framework/BUILD +++ b/tensorflow/core/framework/BUILD @@ -30,7 +30,7 @@ package( default_visibility = [ "//tensorflow/core:__subpackages__", "//tensorflow/security/fuzzing:__subpackages__", -# TODO(pedaveeraiah): to be removed when summary.proto.h deps moves to TSL + # TODO(pedaveeraiah): to be removed when summary.proto.h deps moves to TSL "//tensorflow/tsl/lib:__subpackages__", # copybara:uncomment "//learning/brain/tfrt/aot:__subpackages__", # copybara:uncomment "//platforms/xla/megascale/tensorflow:__subpackages__", @@ -1390,7 +1390,7 @@ cc_library( # protos from the same package, so we can build the protos here and then # link them from core:protos_all without circular dependencies. -#Generate the C++ sources for some of the protos. +# Generate the C++ sources for some of the protos. tf_generate_proto_text_sources( name = "attr_value_proto_text", srcs = [ @@ -1770,7 +1770,7 @@ tf_proto_library( # ":function_proto", # ], # ) -#copybara : uncomment_end +# copybara:uncomment_end tf_proto_library( name = "summary_proto", From 62c682cc46dc5a84f44af78e5e09ec985d075d6e Mon Sep 17 00:00:00 2001 From: Yaohui Liu Date: Mon, 15 May 2023 06:09:49 +0800 Subject: [PATCH 012/376] Fix the PR run error. --- tensorflow/c/c_api.cc | 2 +- tensorflow/c/c_api.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 599b676bbd69b7..86604ddbcba67f 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -2576,7 +2576,7 @@ void TF_UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst, // Apis that are correponding to python c api. -------------------------- -void TF_AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input) { +void TF_AddOperationControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input) { mutex_lock l(graph->mu); graph->graph.AddControlEdge(&input->node, &op->node); tensorflow::RecordMutation(graph, *op, "adding control input"); diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index 5800a3fbc4b571..78494827e03e83 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -1580,7 +1580,7 @@ TF_CAPI_EXPORT extern void TF_RegisterFilesystemPlugin( // Apis that are correponding to python c api. -------------------- -TF_CAPI_EXPORT extern void TF_AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input); +TF_CAPI_EXPORT extern void TF_AddOperationControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input); TF_CAPI_EXPORT extern void TF_SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name, From 0bef47506a880ddd134ff3b52e9352d0fb152c17 Mon Sep 17 00:00:00 2001 From: Yaohui Liu Date: Mon, 15 May 2023 11:20:01 +0800 Subject: [PATCH 013/376] Fix the PR run error. --- tensorflow/c/c_api.cc | 5 ----- tensorflow/c/c_api.h | 30 ++++++++++++++++++++++++++---- 2 files changed, 26 insertions(+), 9 deletions(-) diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 86604ddbcba67f..c7f8befc2c045a 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -2618,11 +2618,6 @@ void TF_SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device tensorflow::RecordMutation(graph, *op, "setting device"); } -void TF_UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst, - TF_Status* status) { - TF_UpdateEdge(graph, new_src, dst, status); -} - void TF_RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op) { mutex_lock l(graph->mu); std::vector control_edges; diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index 78494827e03e83..d3b489374adab0 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -1580,17 +1580,23 @@ TF_CAPI_EXPORT extern void TF_RegisterFilesystemPlugin( // Apis that are correponding to python c api. -------------------- +// Add control input to `op`. TF_CAPI_EXPORT extern void TF_AddOperationControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input); +// Changes an attr value in the node_def Protocol Buffer and sets a status upon +// completion. TF_CAPI_EXPORT extern void TF_SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name, TF_Buffer* attr_value_proto, TF_Status* status); +// Clears the attr in the node_def Protocol Buffer and sets a status upon +// completion. TF_CAPI_EXPORT extern void TF_ClearAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name, TF_Status* status); +// Sets the experimental_type` field in the node_def Protocol Buffer. TF_CAPI_EXPORT extern void TF_SetFullType(TF_Graph* graph, TF_Operation* op, const tensorflow::FullTypeDef& full_type); @@ -1598,24 +1604,40 @@ TF_CAPI_EXPORT extern void TF_SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device); -TF_CAPI_EXPORT extern void TF_UpdateEdge(TF_Graph* graph, TF_Output new_src, - TF_Input dst, TF_Status* status); - TF_CAPI_EXPORT extern void TF_RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op); TF_CAPI_EXPORT extern void TF_SetRequireShapeInferenceFns(TF_Graph* graph, bool require); +// Extends `session` with any new operations added to its associated graph. +// Usually this happens automatically in TF_SessionRun. After this is called, +// TF_SessionRun will no longer extend the session on every call. +// +// We expose this here to allow fine-grained synchronization in multi-threaded +// workloads, which is required since the Python implementation depends on the +// above mutation methods. This allows us to prevent modifications to nodes in +// the graph after the session has been made aware of them. TF_CAPI_EXPORT extern void TF_ExtendSession(TF_Session* session, TF_Status* status); +// Returns the serialized CppShapeInferenceResult::HandleData proto for +// `output` if its a resource or variant tensor, or otherwise returns the empty +// string. TF_CAPI_EXPORT extern const char* TF_GetHandleShapeAndType(TF_Graph* graph, TF_Output output); +// Sets `output` based on `proto`, which should be a serialized +// CppShapeInferenceResult::HandleData proto. `output` should be a resource +// or variant tensor. +// NOTE(skyewm): `proto` is passed a void*/size_t pair instead of a std::string +// because I couldn't get SWIG to work otherwise. TF_CAPI_EXPORT extern void TF_SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto, size_t proto_len, TF_Status* status); -void TFC_AddWhileInputHack(TF_Graph* graph, TF_Output new_src, TF_Operation* dst, +// This method is used to add a new input edge to 'dst', which must be a While +// op. The While op's "T" attribute must have already been updated to include +// the new edge. This is used to construct tf.while_loop gradients. +TF_CAPI_EXPORT extern void TF_AddWhileInputHack(TF_Graph* graph, TF_Output new_src, TF_Operation* dst, TF_Status* status); // ---------------------------------------------------------------- From cdfdc3120c5280ea0c82249500420f1e68f78285 Mon Sep 17 00:00:00 2001 From: Yaohui Liu Date: Mon, 15 May 2023 14:58:51 +0800 Subject: [PATCH 014/376] Remoce unneccesary dependencies. --- tensorflow/c/BUILD | 1 - tensorflow/c/python_api.cc | 88 ++++---------------------------------- 2 files changed, 9 insertions(+), 80 deletions(-) diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 0e70244453f1a9..2be76fbcc15438 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -1092,7 +1092,6 @@ tf_cuda_library( ":c_api_internal", "//tensorflow/core:protos_all_cc", # TODO(b/74620627): remove when _USE_C_SHAPES is removed - "//tensorflow/python/framework:cpp_shape_inference_proto_cc", ], alwayslink = 1, ) diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc index faf93475541da3..35f782b5f8629b 100644 --- a/tensorflow/c/python_api.cc +++ b/tensorflow/c/python_api.cc @@ -22,44 +22,26 @@ limitations under the License. namespace tensorflow { void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input) { - mutex_lock l(graph->mu); - graph->graph.AddControlEdge(&input->node, &op->node); - RecordMutation(graph, *op, "adding control input"); + TF_AddOperationControlInput(graph, op, input); } void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name, TF_Buffer* attr_value_proto, TF_Status* status) { - AttrValue attr_val; - if (!attr_val.ParseFromArray(attr_value_proto->data, - attr_value_proto->length)) { - status->status = - tensorflow::errors::InvalidArgument("Invalid AttrValue proto"); - return; - } - - mutex_lock l(graph->mu); - op->node.AddAttr(attr_name, attr_val); - RecordMutation(graph, *op, "setting attribute"); + TF_SetAttr(graph, op, attr_name, attr_value_proto, status); } void ClearAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name, TF_Status* status) { - mutex_lock l(graph->mu); - op->node.ClearAttr(attr_name); - RecordMutation(graph, *op, "clearing attribute"); + TF_ClearAttr(graph, op, attr_name, status); } void SetFullType(TF_Graph* graph, TF_Operation* op, const FullTypeDef& full_type) { - mutex_lock l(graph->mu); - *op->node.mutable_def()->mutable_experimental_type() = full_type; - RecordMutation(graph, *op, "setting fulltype"); + TF_SetFullType(graph, op, full_type); } void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) { - mutex_lock l(graph->mu); - op->node.set_requested_device(device); - RecordMutation(graph, *op, "setting device"); + TF_SetRequestedDevice(graph, op, device); } void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst, @@ -68,73 +50,21 @@ void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst, } void ExtendSession(TF_Session* session, TF_Status* status) { - ExtendSessionGraphHelper(session, status); - session->extend_before_run = false; + TF_ExtendSession(session, status) } std::string GetHandleShapeAndType(TF_Graph* graph, TF_Output output) { - Node* node = &output.oper->node; - CppShapeInferenceResult::HandleData handle_data; - handle_data.set_is_set(true); - { - mutex_lock l(graph->mu); - tensorflow::shape_inference::InferenceContext* ic = - graph->refiner.GetContext(node); - CHECK(ic != nullptr); - CHECK_LT(output.index, ic->num_outputs()); - const auto* shapes_and_types = - ic->output_handle_shapes_and_types(output.index); - if (shapes_and_types == nullptr) return ""; - - for (const auto& p : *shapes_and_types) { - auto* out_shape_and_type = handle_data.add_shape_and_type(); - ic->ShapeHandleToProto(p.shape, out_shape_and_type->mutable_shape()); - out_shape_and_type->set_dtype(p.dtype); - *out_shape_and_type->mutable_type() = p.type; - } - } - string result; - handle_data.SerializeToString(&result); - return result; + return TF_GetHandleShapeAndType(graph, output); } void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto, size_t proto_len, TF_Status* status) { - tensorflow::CppShapeInferenceResult::HandleData handle_data; - if (!handle_data.ParseFromArray(proto, proto_len)) { - status->status = tensorflow::errors::InvalidArgument( - "Couldn't deserialize HandleData proto"); - return; - } - DCHECK(handle_data.is_set()); - - tensorflow::mutex_lock l(graph->mu); - tensorflow::shape_inference::InferenceContext* ic = - graph->refiner.GetContext(&output.oper->node); - - std::vector shapes_and_types; - for (const auto& shape_and_type_proto : handle_data.shape_and_type()) { - tensorflow::shape_inference::ShapeHandle shape; - status->status = - ic->MakeShapeFromShapeProto(shape_and_type_proto.shape(), &shape); - if (TF_GetCode(status) != TF_OK) return; - shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype(), - shape_and_type_proto.type()); - } - ic->set_output_handle_shapes_and_types(output.index, shapes_and_types); + TF_SetHandleShapeAndType(graph, output, proto, proto_len, status); } void AddWhileInputHack(TF_Graph* graph, TF_Output new_src, TF_Operation* dst, TF_Status* status) { - mutex_lock l(graph->mu); - status->status = graph->graph.AddWhileInputHack(&new_src.oper->node, - new_src.index, &dst->node); - if (TF_GetCode(status) == TF_OK) { - // This modification only updates the destination node for - // the purposes of running this graph in a session. Thus, we don't - // record the source node as being modified. - RecordMutation(graph, *dst, "adding input tensor"); - } + TF_AddWhileInputHack(graph, new_src, dst, status); } } // namespace tensorflow From f7256ea82aa4fd6e58836f19b087a0cbac53a61d Mon Sep 17 00:00:00 2001 From: Yaohui Liu Date: Mon, 15 May 2023 16:25:37 +0800 Subject: [PATCH 015/376] Fix ci error. --- tensorflow/c/python_api.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc index 35f782b5f8629b..550ef5480e212b 100644 --- a/tensorflow/c/python_api.cc +++ b/tensorflow/c/python_api.cc @@ -17,7 +17,7 @@ limitations under the License. #include "tensorflow/c/c_api_internal.h" #include "tensorflow/core/framework/full_type.pb.h" -#include "tensorflow/python/framework/cpp_shape_inference.pb.h" +#include "tensorflow/core/framework/cpp_shape_inference.pb.h" namespace tensorflow { From e5d012a9f484134ed48706b777e3ef850816f7e2 Mon Sep 17 00:00:00 2001 From: Yimei Sun Date: Tue, 16 May 2023 17:01:08 -0700 Subject: [PATCH 016/376] [oneDNN] Update fused instance norm kernel to support oneDNN v3.x library --- .../kernels/mkl/mkl_fused_instance_norm_op.cc | 88 ++++++++++++++++--- 1 file changed, 74 insertions(+), 14 deletions(-) diff --git a/tensorflow/core/kernels/mkl/mkl_fused_instance_norm_op.cc b/tensorflow/core/kernels/mkl/mkl_fused_instance_norm_op.cc index 6373bf09539fe4..7e2b57c7a08485 100644 --- a/tensorflow/core/kernels/mkl/mkl_fused_instance_norm_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_fused_instance_norm_op.cc @@ -102,18 +102,26 @@ class MklFusedInstanceNormOp : public OpKernel { void* src_buf = static_cast(const_cast(src_tensor.flat().data())); +#ifndef ENABLE_ONEDNN_V3 +#define NUM_DUPLICATE 2 +#else +#define NUM_DUPLICATE 1 +#endif // !ENABLE_ONEDNN_V3 memory::dims scale_shift_dims = { - 2, static_cast(num_elements_scale)}; + static_cast(NUM_DUPLICATE * num_elements_scale)}; auto scale_shift_md = memory::desc(scale_shift_dims, MklDnnType(), - memory::format_tag::nc); - Tensor scale_shift_tensor; + memory::format_tag::x); int64_t tensor_shape = scale_shift_md.get_size() / sizeof(float); +#undef NUM_DUPLICATE + +#ifndef ENABLE_ONEDNN_V3 + Tensor scale_shift_tensor; OP_REQUIRES_OK( ctx, ctx->allocate_temp(DataTypeToEnum::v(), {tensor_shape}, &scale_shift_tensor)); void* scale_shift_buf = static_cast(scale_shift_tensor.flat().data()); - SetupScaleShiftBuffer(scale_tensor, shift_tensor, engine_stream_ptr, + SetupScaleShiftBuffer(ctx, scale_tensor, shift_tensor, engine_stream_ptr, num_elements_scale, scale_shift_buf); auto scale_shift_mem = memory(scale_shift_md, cpu_engine_, scale_shift_buf); @@ -122,18 +130,53 @@ class MklFusedInstanceNormOp : public OpKernel { auto bnorm_desc = batch_normalization_forward::desc( prop_kind::forward_inference, src_md, epsilon_, normalization_flags::use_scale_shift); +#else + Tensor scale_fp32_tensor; + Tensor shift_fp32_tensor; + OP_REQUIRES_OK( + ctx, ctx->allocate_temp(DataTypeToEnum::v(), {tensor_shape}, + &scale_fp32_tensor)); + OP_REQUIRES_OK( + ctx, ctx->allocate_temp(DataTypeToEnum::v(), {tensor_shape}, + &shift_fp32_tensor)); + void* scale_fp32_buf = + static_cast(scale_fp32_tensor.flat().data()); + void* shift_fp32_buf = + static_cast(shift_fp32_tensor.flat().data()); + + SetupScaleShiftBuffer(ctx, scale_tensor, shift_tensor, engine_stream_ptr, + num_elements_scale, scale_fp32_buf, shift_fp32_buf); + auto scale_mem = memory(scale_shift_md, cpu_engine_, scale_fp32_buf); + auto shift_mem = memory(scale_shift_md, cpu_engine_, shift_fp32_buf); +#endif // !ENABLE_ONEDNN_V3 batch_normalization_forward::primitive_desc bnorm_pd; if (fuse_activation_) { dnnl::post_ops post_ops; dnnl::primitive_attr post_ops_attr; +#ifndef ENABLE_ONEDNN_V3 post_ops.append_eltwise(1.0, dnnl::algorithm::eltwise_relu, leakyrelu_alpha_, 0.0); post_ops_attr.set_post_ops(post_ops); bnorm_pd = batch_normalization_forward::primitive_desc( bnorm_desc, post_ops_attr, cpu_engine_); +#else + post_ops.append_eltwise(dnnl::algorithm::eltwise_relu, leakyrelu_alpha_, + 0.0); + post_ops_attr.set_post_ops(post_ops); + bnorm_pd = batch_normalization_forward::primitive_desc( + cpu_engine_, prop_kind::forward_inference, src_md, src_md, epsilon_, + normalization_flags::use_scale | normalization_flags::use_shift, + post_ops_attr); +#endif // !ENABLE_ONEDNN_V3 } else { +#ifndef ENABLE_ONEDNN_V3 bnorm_pd = batch_normalization_forward::primitive_desc(bnorm_desc, cpu_engine_); +#else + bnorm_pd = batch_normalization_forward::primitive_desc( + cpu_engine_, prop_kind::forward_inference, src_md, src_md, epsilon_, + normalization_flags::use_scale | normalization_flags::use_shift); +#endif // !ENABLE_ONEDNN_V3 } auto bnorm_prim = batch_normalization_forward(bnorm_pd); @@ -154,8 +197,13 @@ class MklFusedInstanceNormOp : public OpKernel { std::unordered_map bnorm_args; bnorm_args.insert({DNNL_ARG_SRC, *src_mem_ptr}); - bnorm_args.insert({DNNL_ARG_SCALE_SHIFT, scale_shift_mem}); bnorm_args.insert({DNNL_ARG_DST, *dst_mem_ptr}); +#ifndef ENABLE_ONEDNN_V3 + bnorm_args.insert({DNNL_ARG_SCALE_SHIFT, scale_shift_mem}); +#else + bnorm_args.insert({DNNL_ARG_SCALE, scale_mem}); + bnorm_args.insert({DNNL_ARG_SHIFT, shift_mem}); +#endif // !ENABLE_ONEDNN_V3 // Perform batchnorm computation for each batch in input for (int i = 0; i < batch_size; i++) { @@ -221,23 +269,35 @@ class MklFusedInstanceNormOp : public OpKernel { return valid; } - // Helper function to add scale and shift data into same buffer in float - // type as requested by oneDNN - void SetupScaleShiftBuffer(const Tensor& scale_tensor, + // Helper function to prepare scale and shift data in float type as + // required by oneDNN library. Prior to oneDNN 3.x version, the library + // requires the final scale and shift data to be passed in the same buffer + // wherase the 3.x version requires separate buffers for scale and shift + // data. + void SetupScaleShiftBuffer(OpKernelContext* ctx, const Tensor& scale_tensor, const Tensor& shift_tensor, std::shared_ptr engine_stream_ptr, - int num_elements, void* scale_shift_buf) { + int num_elements, void* fp32_scale_or_combine_buf, + void* fp32_shift_buf = nullptr) { void* scale_buf_src = static_cast(const_cast(scale_tensor.flat().data())); void* shift_buf_src = static_cast(const_cast(shift_tensor.flat().data())); - auto scale_offset = sizeof(float) * num_elements; - void* scale_buf_dst = scale_shift_buf; - void* shift_buf_dst = static_cast(scale_shift_buf) + scale_offset; + auto data_size = sizeof(float) * num_elements; + void* scale_buf_dst = fp32_scale_or_combine_buf; + void* shift_buf_dst = nullptr; +#ifndef ENABLE_ONEDNN_V3 + shift_buf_dst = static_cast(fp32_scale_or_combine_buf) + data_size; + (void)fp32_shift_buf; +#else + OP_REQUIRES(ctx, (fp32_shift_buf != nullptr), + errors::InvalidArgument("Invalid shift buffer")); + shift_buf_dst = fp32_shift_buf; +#endif // !ENABLE_ONEDNN_V3 if (std::is_same::value) { - memcpy(scale_buf_dst, scale_buf_src, scale_offset); - memcpy(shift_buf_dst, shift_buf_src, scale_offset); + memcpy(scale_buf_dst, scale_buf_src, data_size); + memcpy(shift_buf_dst, shift_buf_src, data_size); } else { // oneDNN requires float type for scale_shift, need to convert to float // type From 63a9b86f6c5ee5911492089271769e1658f261d7 Mon Sep 17 00:00:00 2001 From: Yaohui Liu Date: Thu, 1 Jun 2023 19:47:45 +0800 Subject: [PATCH 017/376] fix ci error. --- tensorflow/c/python_api.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc index 550ef5480e212b..cd0e45aabcc2b1 100644 --- a/tensorflow/c/python_api.cc +++ b/tensorflow/c/python_api.cc @@ -50,7 +50,7 @@ void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst, } void ExtendSession(TF_Session* session, TF_Status* status) { - TF_ExtendSession(session, status) + TF_ExtendSession(session, status); } std::string GetHandleShapeAndType(TF_Graph* graph, TF_Output output) { From af0c31875f8ac92a3047f991d8b4956a70384e08 Mon Sep 17 00:00:00 2001 From: Yaohui Liu Date: Sun, 4 Jun 2023 11:18:35 +0800 Subject: [PATCH 018/376] Add more docstrings. --- tensorflow/c/c_api.h | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index d3b489374adab0..40fff3999acadf 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -1600,12 +1600,15 @@ TF_CAPI_EXPORT extern void TF_ClearAttr(TF_Graph* graph, TF_Operation* op, TF_CAPI_EXPORT extern void TF_SetFullType(TF_Graph* graph, TF_Operation* op, const tensorflow::FullTypeDef& full_type); +// Set the requested device for `graph`. TF_CAPI_EXPORT extern void TF_SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device); +// Remove all the control inputs from `op` in `graph`. TF_CAPI_EXPORT extern void TF_RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op); +// Set if `graph` requires shape inference functions. TF_CAPI_EXPORT extern void TF_SetRequireShapeInferenceFns(TF_Graph* graph, bool require); // Extends `session` with any new operations added to its associated graph. From 90a352b9059ba53ccdf647a1655bf95796eafa13 Mon Sep 17 00:00:00 2001 From: Mahmoud Abuzaina Date: Thu, 8 Jun 2023 18:24:06 -0700 Subject: [PATCH 019/376] Refactored some code based on review of similar PR --- tensorflow/core/kernels/mkl/mkl_conv_ops.cc | 194 +++++++------------- 1 file changed, 70 insertions(+), 124 deletions(-) diff --git a/tensorflow/core/kernels/mkl/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl/mkl_conv_ops.cc index e5677a8a1c1705..ffd72971770ef7 100644 --- a/tensorflow/core/kernels/mkl/mkl_conv_ops.cc +++ b/tensorflow/core/kernels/mkl/mkl_conv_ops.cc @@ -70,6 +70,12 @@ namespace tensorflow { #define SUMMAND_SCALE_S8(summand_range, output_range) summand_range / 127.0f #endif // !ENABLE_ONEDNN_V3 +#if !defined(ENABLE_ONEDNN_OPENMP) && !defined(ENABLE_ONEDNN_V3) +#define FWD_STREAM , *fwd_stream +#else +#define FWD_STREAM +#endif // !ENABLE_ONEDNN_OPENMP && !ENABLE_ONEDNN_V3 + // TODO(intel-tf) Remove this once old API of quantized ops is abandoned namespace quantized_fusions { string none[] = {""}; @@ -171,89 +177,44 @@ class MklConvFwdPrimitive : public MklPrimitive { // should happen under the lock. mutex_lock lock(primitive_execution_mu_); #endif -#if !defined(ENABLE_ONEDNN_OPENMP) && !defined(ENABLE_ONEDNN_V3) - context_.src_mem->set_data_handle( - static_cast(const_cast(src_data)), *fwd_stream); - context_.filter_mem->set_data_handle( - static_cast(const_cast(filter_data)), *fwd_stream); - if (bias_data != nullptr) { - context_.bias_mem->set_data_handle(const_cast(bias_data), - *fwd_stream); - } - auto const& post_op_params = convFwdDims.post_op_params; - if (!post_op_params.empty()) { - for (auto const& post_op_param : post_op_params) { - if (post_op_param.name == "src_scale") { - context_.src_scale_mem->set_data_handle( - static_cast( - const_cast(post_op_param.param.data())), - *fwd_stream); - } else if (post_op_param.name == "wei_scale") { - context_.wei_scale_mem->set_data_handle( - static_cast( - const_cast(post_op_param.param.data())), - *fwd_stream); - } else if (post_op_param.name == "dst_scale") { - context_.dst_scale_mem->set_data_handle( - static_cast( - const_cast(post_op_param.param.data())), - *fwd_stream); - } - } - } - if (bn_scale_data != nullptr) { - context_.bn_scale_mem->set_data_handle( - static_cast(const_cast(bn_scale_data)), *fwd_stream); - context_.bn_mean_mem->set_data_handle( - static_cast(const_cast(bn_mean_data)), *fwd_stream); - context_.bn_rsqrt_mem->set_data_handle( - static_cast(const_cast(bn_rsqrt_data)), *fwd_stream); - context_.bn_offset_mem->set_data_handle( - static_cast(const_cast(bn_offset_data)), *fwd_stream); - } - context_.dst_mem->set_data_handle( - static_cast(const_cast(dst_data)), *fwd_stream); - if (sp_data) { - context_.sp_mem->set_data_handle(static_cast(sp_data), - *fwd_stream); - } -#else context_.src_mem->set_data_handle( - static_cast(const_cast(src_data))); + static_cast(const_cast(src_data)) FWD_STREAM); context_.filter_mem->set_data_handle( - static_cast(const_cast(filter_data))); + static_cast(const_cast(filter_data)) FWD_STREAM); if (bias_data != nullptr) { - context_.bias_mem->set_data_handle(const_cast(bias_data)); + context_.bias_mem->set_data_handle(const_cast(bias_data) + FWD_STREAM); } auto const& post_op_params = convFwdDims.post_op_params; if (!post_op_params.empty()) { for (auto const& post_op_param : post_op_params) { if (post_op_param.name == "src_scale") { context_.src_scale_mem->set_data_handle(static_cast( - const_cast(post_op_param.param.data()))); + const_cast(post_op_param.param.data())) FWD_STREAM); } else if (post_op_param.name == "wei_scale") { context_.wei_scale_mem->set_data_handle(static_cast( - const_cast(post_op_param.param.data()))); + const_cast(post_op_param.param.data())) FWD_STREAM); } else if (post_op_param.name == "dst_scale") { context_.dst_scale_mem->set_data_handle(static_cast( - const_cast(post_op_param.param.data()))); + const_cast(post_op_param.param.data())) FWD_STREAM); } } } if (bn_scale_data != nullptr) { context_.bn_scale_mem->set_data_handle( - static_cast(const_cast(bn_scale_data))); + static_cast(const_cast(bn_scale_data)) FWD_STREAM); context_.bn_mean_mem->set_data_handle( - static_cast(const_cast(bn_mean_data))); + static_cast(const_cast(bn_mean_data)) FWD_STREAM); context_.bn_rsqrt_mem->set_data_handle( - static_cast(const_cast(bn_rsqrt_data))); + static_cast(const_cast(bn_rsqrt_data)) FWD_STREAM); context_.bn_offset_mem->set_data_handle( - static_cast(const_cast(bn_offset_data))); + static_cast(const_cast(bn_offset_data)) FWD_STREAM); } context_.dst_mem->set_data_handle( - static_cast(const_cast(dst_data))); - if (sp_data) context_.sp_mem->set_data_handle(static_cast(sp_data)); -#endif // !ENABLE_ONEDNN_OPENMP && !ENABLE_ONEDNN_V3 + static_cast(const_cast(dst_data)) FWD_STREAM); + if (sp_data) { + context_.sp_mem->set_data_handle(static_cast(sp_data) FWD_STREAM); + } DCHECK_EQ(context_.fwd_primitives.size(), context_.fwd_primitives_args.size()); @@ -561,31 +522,25 @@ class MklConvFwdPrimitive : public MklPrimitive { new dnnl::memory(scratchpad_md, cpu_engine_, DummyData)); // Create convolution primitive and add it to net + std::unordered_map net_args; if (!convFwdDims.bias_dims.empty()) { context_.bias_mem.reset(new memory(context_.fwd_pd.get()->bias_desc(), cpu_engine_, DummyData)); - if (is_scale_set["src"] && is_scale_set["wei"] && is_scale_set["dst"]) { - context_.fwd_primitives_args.push_back( - {{DNNL_ARG_SRC, *context_.src_mem}, - {DNNL_ARG_WEIGHTS, *context_.filter_mem}, - {DNNL_ARG_BIAS, *context_.bias_mem}, - {DNNL_ARG_SCRATCHPAD, *context_.sp_mem}, - {DNNL_ARG_DST, *context_.dst_mem}, + net_args = {{DNNL_ARG_SRC, *context_.src_mem}, + {DNNL_ARG_WEIGHTS, *context_.filter_mem}, + {DNNL_ARG_BIAS, *context_.bias_mem}, + {DNNL_ARG_SCRATCHPAD, *context_.sp_mem}, + {DNNL_ARG_DST, *context_.dst_mem}}; #ifdef ENABLE_ONEDNN_V3 - {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, *context_.src_scale_mem}, - {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, *context_.wei_scale_mem}, - { DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, - *context_.dst_scale_mem } -#endif // ENABLE_ONEDNN_V3 - }); - } else { - context_.fwd_primitives_args.push_back( - {{DNNL_ARG_SRC, *context_.src_mem}, - {DNNL_ARG_WEIGHTS, *context_.filter_mem}, - {DNNL_ARG_BIAS, *context_.bias_mem}, - {DNNL_ARG_SCRATCHPAD, *context_.sp_mem}, - {DNNL_ARG_DST, *context_.dst_mem}}); + if (is_scale_set["src"] && is_scale_set["wei"] && is_scale_set["dst"]) { + net_args.insert( + {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, *context_.src_scale_mem}); + net_args.insert( + {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, *context_.wei_scale_mem}); + net_args.insert( + {DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, *context_.dst_scale_mem}); } +#endif // ENABLE_ONEDNN_V3 } else if (!convFwdDims.fuse_bn_dims.empty()) { context_.bn_scale_mem.reset( new memory(*context_.bn_scale_md, cpu_engine_, DummyData)); @@ -596,41 +551,35 @@ class MklConvFwdPrimitive : public MklPrimitive { context_.bn_rsqrt_mem.reset( new memory(*context_.bn_rsqrt_md, cpu_engine_, DummyData)); - context_.fwd_primitives_args.push_back( - {{DNNL_ARG_SRC, *context_.src_mem}, - {DNNL_ARG_WEIGHTS, *context_.filter_mem}, - {DNNL_ARG_DST, *context_.dst_mem}, - {DNNL_ARG_SCRATCHPAD, *context_.sp_mem}, - {DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1, - *context_.bn_mean_mem}, - {DNNL_ARG_ATTR_MULTIPLE_POST_OP(1) | DNNL_ARG_SRC_1, - *context_.bn_rsqrt_mem}, - {DNNL_ARG_ATTR_MULTIPLE_POST_OP(2) | DNNL_ARG_SRC_1, - *context_.bn_scale_mem}, - {DNNL_ARG_ATTR_MULTIPLE_POST_OP(3) | DNNL_ARG_SRC_1, - *context_.bn_offset_mem}}); + net_args = {{DNNL_ARG_SRC, *context_.src_mem}, + {DNNL_ARG_WEIGHTS, *context_.filter_mem}, + {DNNL_ARG_DST, *context_.dst_mem}, + {DNNL_ARG_SCRATCHPAD, *context_.sp_mem}, + {DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1, + *context_.bn_mean_mem}, + {DNNL_ARG_ATTR_MULTIPLE_POST_OP(1) | DNNL_ARG_SRC_1, + *context_.bn_rsqrt_mem}, + {DNNL_ARG_ATTR_MULTIPLE_POST_OP(2) | DNNL_ARG_SRC_1, + *context_.bn_scale_mem}, + {DNNL_ARG_ATTR_MULTIPLE_POST_OP(3) | DNNL_ARG_SRC_1, + *context_.bn_offset_mem}}; } else { - if (is_scale_set["src"] && is_scale_set["wei"] && is_scale_set["dst"]) { - context_.fwd_primitives_args.push_back( - {{DNNL_ARG_SRC, *context_.src_mem}, - {DNNL_ARG_WEIGHTS, *context_.filter_mem}, - {DNNL_ARG_SCRATCHPAD, *context_.sp_mem}, - {DNNL_ARG_DST, *context_.dst_mem}, + net_args = {{DNNL_ARG_SRC, *context_.src_mem}, + {DNNL_ARG_WEIGHTS, *context_.filter_mem}, + {DNNL_ARG_SCRATCHPAD, *context_.sp_mem}, + {DNNL_ARG_DST, *context_.dst_mem}}; #ifdef ENABLE_ONEDNN_V3 - {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, *context_.src_scale_mem}, - {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, *context_.wei_scale_mem}, - { DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, - *context_.dst_scale_mem } -#endif // ENABLE_ONEDNN_V3 - }); - } else { - context_.fwd_primitives_args.push_back( - {{DNNL_ARG_SRC, *context_.src_mem}, - {DNNL_ARG_WEIGHTS, *context_.filter_mem}, - {DNNL_ARG_SCRATCHPAD, *context_.sp_mem}, - {DNNL_ARG_DST, *context_.dst_mem}}); + if (is_scale_set["src"] && is_scale_set["wei"] && is_scale_set["dst"]) { + net_args.insert( + {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, *context_.src_scale_mem}); + net_args.insert( + {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, *context_.wei_scale_mem}); + net_args.insert( + {DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, *context_.dst_scale_mem}); } +#endif // ENABLE_ONEDNN_V3 } + context_.fwd_primitives_args.push_back(net_args); context_.fwd_primitives.push_back(*context_.conv_fwd); } @@ -719,15 +668,12 @@ class MklConvFwdPrimitiveFactory : public MklPrimitiveFactory { } #ifndef ENABLE_ONEDNN_V3 } else if (post_op_param.name == "output_scale") { - key_creator.AddAsKey(post_op_param.partial_key); #else - } else if (post_op_param.name == "src_scale") { - key_creator.AddAsKey(post_op_param.partial_key); - } else if (post_op_param.name == "wei_scale") { - key_creator.AddAsKey(post_op_param.partial_key); - } else if (post_op_param.name == "dst_scale") { - key_creator.AddAsKey(post_op_param.partial_key); + } else if (post_op_param.name == "src_scale" || + post_op_param.name == "wei_scale" || + post_op_param.name == "dst_scale") { #endif // !ENABLE_ONEDNN_V3 + key_creator.AddAsKey(post_op_param.partial_key); } else if (post_op_param.name == "fuse_bn") { key_creator.AddAsKey(post_op_param.name); key_creator.AddAsKey(convFwdDims.fuse_bn_dims); @@ -2487,7 +2433,7 @@ class MklQuantizedConvOp (std::is_same::value) ? 255.0 * 127.0 : 127.0 * 127.0; // Re-scale bias if either of following 2 conditions are met: // 1. Bias is not const; - // 2. Bias is const, bias has not been cached (first iteration). + // 2. Bias is const, but bias cache is empty (first iteration). size_t depth = min_filter_vector.NumElements(); bool scales_are_valid = (depth == scales_.size()); @@ -2556,10 +2502,10 @@ class MklQuantizedConvOp return static_cast( const_cast(bias_tensor.flat().data())); } - // Starting oneDNN v3.0, bias needs to be passed as is (in float datatype). - // However, for backward compatibility we need to handle the case where bias - // is qint32. Since oneDNN v3.0 does not support qint32 bias, we need to - // dequantize to float. + // Starting with oneDNN v3.0, bias needs to be passed as is (in float + // datatype). However, for backward compatibility we need to handle the case + // where bias is qint32. Since oneDNN v3.0 does not support qint32 bias, we + // need to dequantize to float. const float min_input = context->input(min_input_idx_).template scalar()(); const float max_input = @@ -2579,7 +2525,7 @@ class MklQuantizedConvOp (std::is_same::value) ? 255.0 * 127.0 : 127.0 * 127.0; // Re-scale bias if either of following 2 conditions are met: // 1. Bias is not const; - // 2. Bias is const, but bias cache is empty (first iteration). + // 2. Bias is const, bias has not been cached (first iteration). size_t depth = min_filter_vector.NumElements(); bool scales_are_valid = (depth == scales_.size()); From 63f7e421fd5f7d0b5a1f26d068db50fe159cdff0 Mon Sep 17 00:00:00 2001 From: Mahmoud Abuzaina Date: Tue, 13 Jun 2023 14:22:30 -0700 Subject: [PATCH 020/376] Addressed reveiw comments --- tensorflow/core/kernels/mkl/mkl_conv_ops.cc | 187 ++++++++++---------- 1 file changed, 98 insertions(+), 89 deletions(-) diff --git a/tensorflow/core/kernels/mkl/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl/mkl_conv_ops.cc index ffd72971770ef7..0aebeb10c2a0ba 100644 --- a/tensorflow/core/kernels/mkl/mkl_conv_ops.cc +++ b/tensorflow/core/kernels/mkl/mkl_conv_ops.cc @@ -715,7 +715,7 @@ class MklConvOp : public OpKernel { context, !(context->HasAttr("padding_list") && context->HasAttr("explicit_paddings")), - errors::InvalidArgument("Can only have 1 `padding` list at most")); + absl::InvalidArgumentError("Can only have 1 `padding` list at most")); if (context->HasAttr("padding_list")) { OP_REQUIRES_OK(context, context->GetAttr("padding_list", &padding_list_)); } @@ -727,17 +727,17 @@ class MklConvOp : public OpKernel { OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_)); OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str_)); OP_REQUIRES(context, FormatFromString(data_format_str_, &data_format_), - errors::InvalidArgument("Invalid data format")); + absl::InvalidArgumentError("Invalid data format")); OP_REQUIRES(context, (strides_.size() == 4 || strides_.size() == 5), - errors::InvalidArgument("Sliding window strides field must " - "specify 4 or 5 dimensions")); + absl::InvalidArgumentError("Sliding window strides field must " + "specify 4 or 5 dimensions")); const int64 stride_n = GetTensorDim(strides_, data_format_, 'N'); const int64 stride_c = GetTensorDim(strides_, data_format_, 'C'); OP_REQUIRES( context, stride_n == 1 && stride_c == 1, - errors::Unimplemented("Current implementation does not yet support " - "strides in the batch and depth dimensions.")); + absl::UnimplementedError("Current implementation does not yet support " + "strides in the batch and depth dimensions.")); OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); is_filter_const_ = false; @@ -749,28 +749,29 @@ class MklConvOp : public OpKernel { } if (strides_.size() == 4) { - OP_REQUIRES(context, dilations_.size() == 4, - errors::InvalidArgument("Sliding window dilations field must " - "specify 4 dimensions")); + OP_REQUIRES( + context, dilations_.size() == 4, + absl::InvalidArgumentError("Sliding window dilations field must " + "specify 4 dimensions")); const int64 dilation_n = GetTensorDim(dilations_, data_format_, 'N'); const int64 dilation_c = GetTensorDim(dilations_, data_format_, 'C'); const int64 dilation_h = GetTensorDim(dilations_, data_format_, 'H'); const int64 dilation_w = GetTensorDim(dilations_, data_format_, 'W'); OP_REQUIRES(context, dilation_n == 1 && dilation_c == 1, - errors::InvalidArgument( + absl::InvalidArgumentError( "Current implementation does not yet support " "dilations in the batch and depth dimensions.")); OP_REQUIRES( context, dilation_h > 0 && dilation_w > 0, - errors::InvalidArgument("Dilated rates should be larger than 0.")); + absl::InvalidArgumentError("Dilated rates should be larger than 0.")); } else if (strides_.size() == 5) { OP_REQUIRES(context, dilations_.size() == 5, - errors::InvalidArgument("Dilation rates field must " - "specify 5 dimensions")); + absl::InvalidArgumentError("Dilation rates field must " + "specify 5 dimensions")); OP_REQUIRES(context, (GetTensorDim(dilations_, data_format_, 'N') == 1 && GetTensorDim(dilations_, data_format_, 'C') == 1), - errors::InvalidArgument( + absl::InvalidArgumentError( "Current implementation does not yet support " "dilations rates in the batch and depth dimensions.")); OP_REQUIRES( @@ -778,7 +779,7 @@ class MklConvOp : public OpKernel { (GetTensorDim(dilations_, data_format_, '0') > 0 && GetTensorDim(dilations_, data_format_, '1') > 0 && GetTensorDim(dilations_, data_format_, '2') > 0), - errors::InvalidArgument("Dilated rates should be larger than 0.")); + absl::InvalidArgumentError("Dilated rates should be larger than 0.")); } } @@ -789,8 +790,8 @@ class MklConvOp : public OpKernel { const Tensor& filter_tensor = MklGetInput(context, kInputIndex_Filter); OP_REQUIRES( context, filter_tensor.NumElements() > 0, - errors::InvalidArgument("filter must not have zero elements " - "(i.e. all dimensions must be non-zero)")); + absl::InvalidArgumentError("filter must not have zero elements " + "(i.e. all dimensions must be non-zero)")); if (std::is_same::value) { (void)SetFPMathMode(); @@ -802,8 +803,8 @@ class MklConvOp : public OpKernel { native_format); OP_REQUIRES(context, !filter_mkl_shape.IsMklTensor(), - errors::InvalidArgument("Filter should not be in " - "Mkl Layout")); + absl::InvalidArgumentError("Filter should not be in " + "Mkl Layout")); MklDnnData src(&cpu_engine_); MklDnnData filter(&cpu_engine_); @@ -875,18 +876,18 @@ class MklConvOp : public OpKernel { bool is_conv3d = (strides_.size() == 5); if (!is_conv2d && !is_conv3d) { - OP_REQUIRES( - context, !pad_enabled, - errors::InvalidArgument("Pad + Conv fusion only works for 2D/3D")); + OP_REQUIRES(context, !pad_enabled, + absl::InvalidArgumentError( + "Pad + Conv fusion only works for 2D/3D")); OP_REQUIRES( context, !fuse_pad_, - errors::InvalidArgument("Pad+Conv fusion only works for 2D/3D")); + absl::InvalidArgumentError("Pad+Conv fusion only works for 2D/3D")); } // TODO(intel-tf) 3-D support for Depthwise is not there if (is_depthwise) { OP_REQUIRES(context, is_conv2d, - errors::InvalidArgument( + absl::InvalidArgumentError( "Only 2D convolution is supported for depthwise.")); } @@ -899,7 +900,7 @@ class MklConvOp : public OpKernel { auto mkl_fmt_tag = MklTensorFormatToMklDnnDataFormat(tf_fmt); // NOTE: `mkl_fmt_tag` will be `format_tag::undef` for ReLU OP_REQUIRES(context, mkl_fmt_tag != memory::format_tag::undef, - errors::InvalidArgument("Invalid data format")); + absl::InvalidArgumentError("Invalid data format")); // If input is in MKL layout, then simply grab the layout; otherwise, // construct TF layout for input. @@ -957,8 +958,9 @@ class MklConvOp : public OpKernel { // Inputs to FusedBatchNorm have same 1D shape fuse_bn_shape = MklGetInput(context, kInputIndex_BN_Mean).shape(); OP_REQUIRES(context, fuse_bn_shape.dims() == 1, - errors::InvalidArgument("FusedBatchNorm must be 1D, not: ", - fuse_bn_shape.DebugString())); + absl::InvalidArgumentError( + absl::StrCat("FusedBatchNorm must be 1D, not: ", + fuse_bn_shape.DebugString()))); // Note - MKL-DNN expects {1, C, 1, 1} for binary post-op even for NHWC fuse_bn_dims = {1, fuse_bn_shape.dim_size(0), 1, 1}; @@ -1090,9 +1092,9 @@ class MklConvOp : public OpKernel { string error_msg = tensorflow::strings::StrCat( "Status: ", e.status, ", message: ", string(e.message), ", in file ", __FILE__, ":", __LINE__); - OP_REQUIRES_OK( - context, - errors::Aborted("Operation received an exception:", error_msg)); + OP_REQUIRES_OK(context, + absl::AbortedError(absl::StrCat( + "Operation received an exception:", error_msg))); } } @@ -1105,8 +1107,9 @@ class MklConvOp : public OpKernel { } else { const Tensor& paddings_tf = MklGetInput(context, input_index_pad_); OP_REQUIRES(context, paddings_tf.dims() == 2, - errors::InvalidArgument("paddings must be 2-dimensional: ", - paddings_tf.shape().DebugString())); + absl::InvalidArgumentError( + absl::StrCat("paddings must be 2-dimensional: ", + paddings_tf.shape().DebugString()))); // Flatten tensor to get individual paddings. paddings = static_cast( const_cast(paddings_tf.flat().data())); @@ -1201,9 +1204,9 @@ class MklConvOp : public OpKernel { virtual void ComputeBNScale(OpKernelContext* context, float epsilon, int bn_variance_index, Tinput* scale_buf_ptr) { - OP_REQUIRES( - context, false, - errors::Unimplemented("Compute BN scale not expected in base class")); + OP_REQUIRES(context, false, + absl::UnimplementedError( + "Compute BN scale not expected in base class")); return; } @@ -1323,7 +1326,7 @@ class MklConvOp : public OpKernel { auto output_format_tag = MklTensorFormatToMklDnnDataFormat( output_mkl_shape->GetTfDataFormat()); OP_REQUIRES(context, output_format_tag != memory::format_tag::undef, - errors::InvalidArgument( + absl::InvalidArgumentError( "MklConvOp: AddN fusion: Invalid data format")); auto add_md = add_mkl_shape.IsMklTensor() @@ -1563,14 +1566,14 @@ class MklFusedConvOp int num_args; OP_REQUIRES_OK(context, context->GetAttr("num_args", &num_args)); OP_REQUIRES(context, !fused_ops.empty(), - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused Conv2D must have at least one fused op.")); // TODO(intel-tf): Compact the code for activation checking if (fused_ops == std::vector{"BiasAdd"}) { this->set_fuse_biasadd(true); OP_REQUIRES(context, num_args == 1, - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused Conv2D must have one extra argument: bias.")); } else if (fused_ops == std::vector{"Relu"}) { this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu); @@ -1589,26 +1592,26 @@ class MklFusedConvOp OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon)); OP_REQUIRES( context, num_args == 4, - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused Conv2D with batchnorm must have 4 extra argument")); this->set_fuse_bn(true, epsilon); } else if (fused_ops == std::vector{"BiasAdd", "Relu"}) { this->set_fuse_biasadd(true); this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu); OP_REQUIRES(context, num_args == 1, - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused Conv2D must have one extra argument: bias.")); } else if (fused_ops == std::vector{"BiasAdd", "Relu6"}) { this->set_fuse_biasadd(true); this->SET_FUSE_ACTIVATION_FOR_RELU6; OP_REQUIRES(context, num_args == 1, - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused Conv2D must have one extra argument: bias.")); } else if (fused_ops == std::vector{"BiasAdd", "Elu"}) { this->set_fuse_biasadd(true); this->set_fuse_activation(true, dnnl::algorithm::eltwise_elu, 1.0); OP_REQUIRES(context, num_args == 1, - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused Conv2D must have one extra argument: bias.")); } else if (fused_ops == std::vector{"BiasAdd", "LeakyRelu"}) { this->set_fuse_biasadd(true); @@ -1618,21 +1621,21 @@ class MklFusedConvOp this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu, leakyrelu_alpha); OP_REQUIRES(context, num_args == 1, - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused Conv2D must have one extra argument: bias.")); } else if (fused_ops == std::vector{"BiasAdd", "Add"}) { this->set_fuse_biasadd(true); this->set_fuse_add(true); OP_REQUIRES( context, num_args == 2, - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused Conv2D must have two extra arguments: bias and add.")); } else if (fused_ops == std::vector{"FusedBatchNorm", "Relu"}) { float epsilon; OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon)); OP_REQUIRES( context, num_args == 4, - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused Conv2D with batchnorm must have 4 extra argument")); this->set_fuse_bn(true, epsilon); this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu); @@ -1641,7 +1644,7 @@ class MklFusedConvOp OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon)); OP_REQUIRES( context, num_args == 4, - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused Conv2D with batchnorm must have 4 extra argument")); this->set_fuse_bn(true, epsilon); this->SET_FUSE_ACTIVATION_FOR_RELU6; @@ -1650,7 +1653,7 @@ class MklFusedConvOp OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon)); OP_REQUIRES( context, num_args == 4, - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused Conv2D with batchnorm must have 4 extra argument")); this->set_fuse_bn(true, epsilon); this->set_fuse_activation(true, dnnl::algorithm::eltwise_elu, 1.0); @@ -1662,7 +1665,7 @@ class MklFusedConvOp context->GetAttr("leakyrelu_alpha", &leakyrelu_alpha)); OP_REQUIRES( context, num_args == 4, - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused Conv2D with batchnorm must have 4 extra argument")); this->set_fuse_bn(true, epsilon); this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu, @@ -1673,7 +1676,7 @@ class MklFusedConvOp OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon)); OP_REQUIRES( context, num_args == 4, - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused Conv2D with batchnorm must have 4 extra argument")); this->set_fuse_bn(true, epsilon); this->set_fuse_activation(true, dnnl::algorithm::eltwise_swish, 1.0); @@ -1683,7 +1686,7 @@ class MklFusedConvOp this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu); OP_REQUIRES( context, num_args == 2, - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused Conv2D must have two extra arguments: bias and add.")); } else if (fused_ops == std::vector{"BiasAdd", "Add", "Relu6"}) { this->set_fuse_biasadd(true); @@ -1691,7 +1694,7 @@ class MklFusedConvOp this->SET_FUSE_ACTIVATION_FOR_RELU6; OP_REQUIRES( context, num_args == 2, - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused Conv2D must have two extra arguments: bias and add.")); } else if (fused_ops == std::vector{"BiasAdd", "Add", "Elu"}) { this->set_fuse_biasadd(true); @@ -1699,7 +1702,7 @@ class MklFusedConvOp this->set_fuse_activation(true, dnnl::algorithm::eltwise_elu, 1.0); OP_REQUIRES( context, num_args == 2, - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused Conv2D must have two extra arguments: bias and add.")); } else if (fused_ops == std::vector{"BiasAdd", "Add", "LeakyRelu"}) { @@ -1712,18 +1715,19 @@ class MklFusedConvOp leakyrelu_alpha); OP_REQUIRES( context, num_args == 2, - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused Conv2D must have two extra arguments: bias and add.")); } else if (fused_ops == std::vector{"BiasAdd", "_MklSwish"}) { this->set_fuse_biasadd(true); this->set_fuse_activation(true, dnnl::algorithm::eltwise_swish, 1.0); OP_REQUIRES(context, num_args == 1, - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused Conv2D must have one extra argument: bias.")); } else { OP_REQUIRES(context, false, - errors::Unimplemented("Fusion is not implemented: [", - absl::StrJoin(fused_ops, ","), "]")); + absl::UnimplementedError( + absl::StrCat("Fusion is not implemented: [", + absl::StrJoin(fused_ops, ","), "]"))); } if (pad_enabled) { @@ -1770,7 +1774,7 @@ class MklFusedDepthwiseConvOp int num_args; OP_REQUIRES_OK(context, context->GetAttr("num_args", &num_args)); OP_REQUIRES(context, !fused_ops.empty(), - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused DepthwiseConv2D must have at least one fused op.")); if (fused_ops == std::vector{"BiasAdd"}) { @@ -1786,13 +1790,14 @@ class MklFusedDepthwiseConvOp this->set_fuse_activation(true, dnnl::algorithm::eltwise_elu, 1.0); } else { OP_REQUIRES(context, false, - errors::Unimplemented("Fusion is not implemented: [", - absl::StrJoin(fused_ops, ","), "]")); + absl::InvalidArgumentError( + absl::StrCat("Fusion is not implemented: [", + absl::StrJoin(fused_ops, ","), "]"))); } OP_REQUIRES( context, num_args == 1, - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused DepthwiseConv2D must have one extra argument: bias.")); if (pad_enabled) { @@ -1860,7 +1865,7 @@ class MklQuantizedConvOp // TODO(intel-tf): num_fused_ops and legacy_fused_ops should go away once // old API is abandoned. OP_REQUIRES(context, !(fused_ops_attr.size() > 0 && num_fused_ops > 0), - errors::InvalidArgument( + absl::InvalidArgumentError( "QuantizedConv fused ops should be only available through " "either new API or old API, got both.")); @@ -1877,8 +1882,9 @@ class MklQuantizedConvOp std::find(supported_fusions.begin(), supported_fusions.end(), fused_ops_) != supported_fusions.end(); OP_REQUIRES(context, is_fusion_supported, - errors::InvalidArgument("Unsupported QuantizedConv fusion: [", - absl::StrJoin(fused_ops_, ","), "]")); + absl::InvalidArgumentError( + absl::StrCat("Unsupported QuantizedConv fusion: [", + absl::StrJoin(fused_ops_, ","), "]"))); } // Set the flag for every fused op. @@ -1902,9 +1908,10 @@ class MklQuantizedConvOp const bool fuse_requantize = IsFused(oneDNNFusedOps::kRequantize); OP_REQUIRES_OK(context, context->GetAttr("out_type", &out_dt)); if (fuse_requantize) { - OP_REQUIRES(context, out_dt == DT_QINT8 || out_dt == DT_QUINT8, - errors::InvalidArgument("QuantizedConv: unsupported output " - "type when Requantize is fused.")); + OP_REQUIRES( + context, out_dt == DT_QINT8 || out_dt == DT_QUINT8, + absl::InvalidArgumentError("QuantizedConv: unsupported output " + "type when Requantize is fused.")); } if (context->HasAttr("Tsummand")) { @@ -1912,7 +1919,7 @@ class MklQuantizedConvOp if (!this->get_fuse_add()) { OP_REQUIRES( context, summand_dt == out_dt, - errors::InvalidArgument( + absl::InvalidArgumentError( "QuantizedConv: incorrect summand data type. When Sum is not " "fused, Tsummand attribute must have same value as out_type.")); } @@ -1947,7 +1954,7 @@ class MklQuantizedConvOp OP_REQUIRES( context, is_filter_const, - errors::InvalidArgument("QuantizedConv: filter must be a constant")); + absl::InvalidArgumentError("QuantizedConv: filter must be a constant")); if (num_fused_ops == -1) { // If num_fused_ops is -1 then the new API (ops) are being used. @@ -2094,8 +2101,8 @@ class MklQuantizedConvOp ((min_filter_vector.NumElements() > 0) && (max_filter_vector.NumElements() > 0) && (min_filter_vector.shape() == max_filter_vector.shape())), - errors::InvalidArgument("`min_ and max_filter` must have same" - "shape and contain at least one element.")); + absl::InvalidArgumentError("`min_ and max_filter` must have same" + "shape and contain at least one element.")); float int_input_limit = std::is_same::value ? 255.0f : 127.0f; size_t depth = min_filter_vector.NumElements(); @@ -2215,15 +2222,15 @@ class MklQuantizedConvOp OP_REQUIRES( context, TensorShapeUtils::IsScalar(min_freezed_output_tensor.shape()), - errors::InvalidArgument( - "`min_freezed_output` must be rank 0 but is rank ", - min_freezed_output_tensor.dims())); + absl::InvalidArgumentError( + absl::StrCat("`min_freezed_output` must be rank 0 but is rank ", + min_freezed_output_tensor.dims()))); OP_REQUIRES( context, TensorShapeUtils::IsScalar(max_freezed_output_tensor.shape()), - errors::InvalidArgument( - "`max_freezed_output` must be rank 0 but is rank ", - max_freezed_output_tensor.dims())); + absl::InvalidArgumentError( + absl::StrCat("`max_freezed_output` must be rank 0 but is rank ", + max_freezed_output_tensor.dims()))); const Tensor& min_freezed_summand_tensor = context->input(min_summand_idx_); const Tensor& max_freezed_summand_tensor = @@ -2231,15 +2238,15 @@ class MklQuantizedConvOp OP_REQUIRES( context, TensorShapeUtils::IsScalar(min_freezed_summand_tensor.shape()), - errors::InvalidArgument( + absl::InvalidArgumentError(absl::StrCat( "`min_freezed_summand` must be rank 0 but is rank ", - min_freezed_summand_tensor.dims())); + min_freezed_summand_tensor.dims()))); OP_REQUIRES( context, TensorShapeUtils::IsScalar(max_freezed_summand_tensor.shape()), - errors::InvalidArgument( + absl::InvalidArgumentError(absl::StrCat( "`max_freezed_summand` must be rank 0 but is rank ", - max_freezed_summand_tensor.dims())); + max_freezed_summand_tensor.dims()))); const float min_freezed_output = min_freezed_output_tensor.template scalar()(); const float max_freezed_output = @@ -2320,7 +2327,7 @@ class MklQuantizedConvOp OP_REQUIRES(context, context->forward_input_to_output_with_shape( summand_idx, 0, summand.shape(), output_tensor), - errors::InvalidArgument( + absl::InvalidArgumentError( "Summand cannot be forwarded in the current fusion.")); return; } @@ -2401,7 +2408,7 @@ class MklQuantizedConvOp OP_REQUIRES(context, context->forward_input_to_output_with_shape( summand_idx, 0, summand_float.shape(), output_tensor), - errors::InvalidArgument( + absl::InvalidArgumentError( "Summand cannot be forwarded in the current fusion.")); #endif // !ENABLE_ONEDNN_V3 @@ -2711,13 +2718,14 @@ class MklFusedConv3DOp std::vector padding_list; OP_REQUIRES_OK(context, context->GetAttr("padding_list", &padding_list)); if (padding_list.empty()) { - OP_REQUIRES(context, !fused_ops.empty(), - errors::InvalidArgument("Fused Conv3D must have at least one " - "fused op when Pad is not fused.")); + OP_REQUIRES( + context, !fused_ops.empty(), + absl::InvalidArgumentError("Fused Conv3D must have at least one " + "fused op when Pad is not fused.")); if (std::find(fused_ops.begin(), fused_ops.end(), "BiasAdd") == fused_ops.end()) { OP_REQUIRES(context, num_args == 1, - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused Conv3D must have one extra argument: bias.")); } else if (std::find(fused_ops.begin(), fused_ops.end(), "BiasAdd") == fused_ops.end() && @@ -2725,7 +2733,7 @@ class MklFusedConv3DOp fused_ops.end()) { OP_REQUIRES( context, num_args == 2, - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused Conv3D must have two extra arguments: bias and add.")); } } @@ -2775,8 +2783,9 @@ class MklFusedConv3DOp } else { if (padding_list.empty()) { OP_REQUIRES(context, false, - errors::Unimplemented("Fusion is not implemented: [", - absl::StrJoin(fused_ops, ","), "]")); + absl::UnimplementedError( + absl::StrCat("Fusion is not implemented: [", + absl::StrJoin(fused_ops, ","), "]"))); } } } From 8afafd2d325fadb91bb495177193b976c11bf0af Mon Sep 17 00:00:00 2001 From: Mahmoud Abuzaina Date: Wed, 14 Jun 2023 10:04:05 -0700 Subject: [PATCH 021/376] Addressed review comments --- tensorflow/core/kernels/mkl/mkl_conv_ops.cc | 22 ++++++++++---------- tensorflow/core/util/mkl_util.h | 23 ++++++++------------- 2 files changed, 20 insertions(+), 25 deletions(-) diff --git a/tensorflow/core/kernels/mkl/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl/mkl_conv_ops.cc index 0aebeb10c2a0ba..274bd0e378e45a 100644 --- a/tensorflow/core/kernels/mkl/mkl_conv_ops.cc +++ b/tensorflow/core/kernels/mkl/mkl_conv_ops.cc @@ -425,8 +425,8 @@ class MklConvFwdPrimitive : public MklPrimitive { post_ops.append_sum(op_scale, /*zero_point=*/0, MklDnnType()); } else { - TF_CHECK_OK(Status(absl::StatusCode::kFailedPrecondition, - "Summand data type is expected to be float")); + TF_CHECK_OK(absl::FailedPreconditionError( + "Summand data type is expected to be float")); } } else { post_ops.append_sum(op_scale); @@ -2168,8 +2168,8 @@ class MklQuantizedConvOp } else { #ifdef ENABLE_ONEDNN_V3 if (!std::is_same::value) - TF_CHECK_OK(Status(absl::StatusCode::kFailedPrecondition, - "Output datatype is expected to be qint32.")); + TF_CHECK_OK(absl::FailedPreconditionError( + "Output datatype is expected to be qint32.")); float min_min_filter = min_filter[0]; float max_max_filter = max_filter[0]; for (size_t i = 0; i < depth; ++i) { @@ -2343,8 +2343,8 @@ class MklQuantizedConvOp output_tensor); const Tensor& summand = context->input(this->get_input_add_idx()); if (summand.dtype() != DT_FLOAT) - TF_CHECK_OK(Status(absl::StatusCode::kFailedPrecondition, - "Current fusion requires summand to be float")); + TF_CHECK_OK(absl::FailedPreconditionError( + "Current fusion requires summand to be float")); // We need to compute scale for the summand const float min_input = context->input(min_input_idx_).template scalar()(); @@ -2399,8 +2399,8 @@ class MklQuantizedConvOp int summand_idx = this->get_input_add_idx(); DataType summand_dt = this->input_type(summand_idx); if (summand_dt != DT_FLOAT) - TF_CHECK_OK(Status(absl::StatusCode::kFailedPrecondition, - "Summand datatype is expected to be float.")); + TF_CHECK_OK(absl::FailedPreconditionError( + "Summand datatype is expected to be float.")); Tensor& summand_float = const_cast(context->input(summand_idx)); OP_REQUIRES_OK(context, summand_float.BitcastFrom(summand_float, DT_QINT32, @@ -2522,9 +2522,9 @@ class MklQuantizedConvOp if ((min_filter_vector.NumElements() == 0) || (max_filter_vector.NumElements() == 0) || (min_filter_vector.shape() != max_filter_vector.shape())) { - TF_CHECK_OK(Status(absl::StatusCode::kFailedPrecondition, - "`min_filter and max_filter` must have same" - "shape and contain at least one element.")); + TF_CHECK_OK(absl::FailedPreconditionError( + "`min_filter and max_filter` must have same" + "shape and contain at least one element.")); } const float* min_filter = min_filter_vector.flat().data(); const float* max_filter = max_filter_vector.flat().data(); diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h index 84f71fa2761388..8a29458b0a9dcb 100644 --- a/tensorflow/core/util/mkl_util.h +++ b/tensorflow/core/util/mkl_util.h @@ -706,8 +706,8 @@ inline Status ConvertMklToTF(OpKernelContext* context, bool status = input.CheckReorderToOpMem(output_tf_md, output_tf_tensor, net, net_args, cpu_engine); if (!status) { - return Status(absl::StatusCode::kInternal, - "ConvertMklToTF(): Failed to create reorder for input"); + return absl::InternalError( + "ConvertMklToTF(): Failed to create reorder for input"); } ExecutePrimitive(net, &net_args, cpu_engine, context); } else { @@ -715,8 +715,7 @@ inline Status ConvertMklToTF(OpKernelContext* context, bool status = output_tf_tensor->CopyFrom(input_mkl_tensor, output_tf_shape); if (!status) { - return Status( - absl::StatusCode::kInternal, + return absl::InternalError( "ConvertMklToTF(): Failed to forward input tensor to output"); } } @@ -1114,8 +1113,7 @@ inline memory::format_tag MklTensorFormatToMklDnnDataFormat( inline MklTensorFormat TFDataFormatToMklDnn3DDataFormat(TensorFormat format) { if (format == FORMAT_NHWC) return MklTensorFormat::FORMAT_NDHWC; if (format == FORMAT_NCHW) return MklTensorFormat::FORMAT_NCDHW; - TF_CHECK_OK( - Status(absl::StatusCode::kInvalidArgument, "Unsupported data format")); + TF_CHECK_OK(absl::InvalidArgumentError("Unsupported data format")); return MklTensorFormat::FORMAT_INVALID; } @@ -1127,8 +1125,7 @@ inline MklTensorFormat TFDataFormatToMklDnn3DDataFormat(TensorFormat format) { inline MklTensorFormat TFDataFormatToMklDnnDataFormat(TensorFormat format) { if (format == FORMAT_NHWC) return MklTensorFormat::FORMAT_NHWC; if (format == FORMAT_NCHW) return MklTensorFormat::FORMAT_NCHW; - TF_CHECK_OK( - Status(absl::StatusCode::kInvalidArgument, "Unsupported data format")); + TF_CHECK_OK(absl::InvalidArgumentError("Unsupported data format")); return MklTensorFormat::FORMAT_INVALID; } @@ -1144,8 +1141,7 @@ inline TensorFormat MklDnnDataFormatToTFDataFormat(MklTensorFormat format) { if (format == MklTensorFormat::FORMAT_NCHW || format == MklTensorFormat::FORMAT_NCDHW) return FORMAT_NCHW; - TF_CHECK_OK( - Status(absl::StatusCode::kInvalidArgument, "Unsupported data format")); + TF_CHECK_OK(absl::InvalidArgumentError("Unsupported data format")); // Return to prevent compiler warnings, otherwise TF_CHECK_OK will ensure // that we don't come here. @@ -1311,10 +1307,9 @@ inline Status CreateBlockedMemDescHelper(const memory::dims& dim, } catch (dnnl::error& e) { delete[] input_dims; delete[] input_strides; - return Status(absl::StatusCode::kInternal, - tensorflow::strings::StrCat( - "Failed to create blocked memory descriptor.", - "Status: ", e.status, ", message: ", e.message)); + return absl::InternalError( + absl::StrCat("Failed to create blocked memory descriptor.", + "Status: ", e.status, ", message: ", e.message)); } return OkStatus(); } From 538e345b6158e3bf543cc3acb0421a99d6b001f8 Mon Sep 17 00:00:00 2001 From: Yimei Sun Date: Fri, 16 Jun 2023 16:35:20 -0700 Subject: [PATCH 022/376] Use absl for error message --- .../kernels/mkl/mkl_fused_instance_norm_op.cc | 27 ++++++++++--------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/tensorflow/core/kernels/mkl/mkl_fused_instance_norm_op.cc b/tensorflow/core/kernels/mkl/mkl_fused_instance_norm_op.cc index 7e2b57c7a08485..df5a7e80239734 100644 --- a/tensorflow/core/kernels/mkl/mkl_fused_instance_norm_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_fused_instance_norm_op.cc @@ -41,7 +41,7 @@ class MklFusedInstanceNormOp : public OpKernel { OP_REQUIRES_OK(context, context->GetAttr("reduction_axes", &mean_reduction_axes)); OP_REQUIRES(context, InferDataFormat(mean_reduction_axes), - errors::InvalidArgument( + absl::InvalidArgumentError( "Failed to infer data format from reduction axes")); CheckFusedActivation(context); } @@ -57,19 +57,19 @@ class MklFusedInstanceNormOp : public OpKernel { (src_tensor.dims() == 4 && data_format_ == "NCHW") || (src_tensor.dims() == 5 && data_format_ == "NDHWC") || (src_tensor.dims() == 5 && data_format_ == "NCDHW"), - errors::InvalidArgument( + absl::InvalidArgumentError(absl::StrCat( "Unsupported input: ", src_tensor.shape().DebugString(), - ", ", data_format_)); + ", ", data_format_))); size_t num_elements_scale = scale_tensor.NumElements(); size_t num_elements_shift = shift_tensor.NumElements(); - OP_REQUIRES( - ctx, num_elements_scale == num_elements_shift, - errors::InvalidArgument("Number of elements in scale and shift", - "tensors are not same.")); + OP_REQUIRES(ctx, num_elements_scale == num_elements_shift, + absl::InvalidArgumentError( + absl::StrCat("Number of elements in scale and shift", + "tensors are not same."))); TensorFormat tensor_format; OP_REQUIRES(ctx, FormatFromString(data_format_, &tensor_format), - errors::InvalidArgument("Invalid data format")); + absl::InvalidArgumentError("Invalid data format")); MklDnnThreadPool eigen_tp(ctx); std::shared_ptr engine_stream_ptr; @@ -217,8 +217,8 @@ class MklFusedInstanceNormOp : public OpKernel { string error_msg = "Status: " + std::to_string(e.status) + ", message: " + string(e.message) + ", in file " + string(__FILE__) + ":" + std::to_string(__LINE__); - OP_REQUIRES_OK( - ctx, errors::Aborted("Operation received an exception:", error_msg)); + OP_REQUIRES_OK(ctx, absl::AbortedError(absl::StrCat( + "Operation received an exception:", error_msg))); } } @@ -247,8 +247,9 @@ class MklFusedInstanceNormOp : public OpKernel { context->GetAttr("leakyrelu_alpha", &leakyrelu_alpha_)); } else { OP_REQUIRES(context, false, - errors::Unimplemented("Fusion is not implemented: [", - absl::StrJoin(fused_ops, ","), "]")); + absl::UnimplementedError( + absl::StrCat("Fusion is not implemented: [", + absl::StrJoin(fused_ops, ","), "]"))); } } @@ -291,7 +292,7 @@ class MklFusedInstanceNormOp : public OpKernel { (void)fp32_shift_buf; #else OP_REQUIRES(ctx, (fp32_shift_buf != nullptr), - errors::InvalidArgument("Invalid shift buffer")); + absl::InvalidArgumentError("Invalid shift buffer")); shift_buf_dst = fp32_shift_buf; #endif // !ENABLE_ONEDNN_V3 From ada6f7571d84e4fb5856680a44d3442709f14316 Mon Sep 17 00:00:00 2001 From: Yaohui Liu Date: Wed, 21 Jun 2023 02:25:13 +0800 Subject: [PATCH 023/376] fix the error of TF_GetHandleShapeAndType. --- tensorflow/c/c_api.cc | 23 +++++++++++++---------- tensorflow/c/c_api.h | 2 +- tensorflow/c/python_api.cc | 24 +++++++++++++++++++++++- 3 files changed, 37 insertions(+), 12 deletions(-) diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index c7f8befc2c045a..cda738df5a6578 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -2640,30 +2640,33 @@ void TF_ExtendSession(TF_Session* session, TF_Status* status) { session->extend_before_run = false; } -const char* TF_GetHandleShapeAndType(TF_Graph* graph, TF_Output output) { - Node* node = &output.oper->node; +TF_Buffer* TF_GetHandleShapeAndType(TF_Graph* graph, TF_Output output) { + Node *node = &output.oper->node; tensorflow::CppShapeInferenceResult::HandleData handle_data; handle_data.set_is_set(true); { mutex_lock l(graph->mu); - tensorflow::shape_inference::InferenceContext* ic = + tensorflow::shape_inference::InferenceContext *ic = graph->refiner.GetContext(node); CHECK(ic != nullptr); CHECK_LT(output.index, ic->num_outputs()); - const auto* shapes_and_types = + const auto *shapes_and_types = ic->output_handle_shapes_and_types(output.index); - if (shapes_and_types == nullptr) return ""; + if (shapes_and_types == nullptr) + return nullptr; - for (const auto& p : *shapes_and_types) { - auto* out_shape_and_type = handle_data.add_shape_and_type(); + for (const auto &p : *shapes_and_types) { + auto *out_shape_and_type = handle_data.add_shape_and_type(); ic->ShapeHandleToProto(p.shape, out_shape_and_type->mutable_shape()); out_shape_and_type->set_dtype(p.dtype); *out_shape_and_type->mutable_type() = p.type; } } - string result; - handle_data.SerializeToString(&result); - return result.c_str(); + string str_data; + handle_data.SerializeToString(&str_data); + + TF_Buffer *result = TF_NewBufferFromString(str_data.c_str(), str_data.size()); + return result; } void TF_SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto, diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index 40fff3999acadf..fc999a41218c52 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -1624,7 +1624,7 @@ TF_CAPI_EXPORT extern void TF_ExtendSession(TF_Session* session, TF_Status* stat // Returns the serialized CppShapeInferenceResult::HandleData proto for // `output` if its a resource or variant tensor, or otherwise returns the empty // string. -TF_CAPI_EXPORT extern const char* TF_GetHandleShapeAndType(TF_Graph* graph, TF_Output output); +TF_CAPI_EXPORT extern TF_Buffer* TF_GetHandleShapeAndType(TF_Graph* graph, TF_Output output); // Sets `output` based on `proto`, which should be a serialized // CppShapeInferenceResult::HandleData proto. `output` should be a resource diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc index cd0e45aabcc2b1..2de4dde7a202ed 100644 --- a/tensorflow/c/python_api.cc +++ b/tensorflow/c/python_api.cc @@ -54,7 +54,29 @@ void ExtendSession(TF_Session* session, TF_Status* status) { } std::string GetHandleShapeAndType(TF_Graph* graph, TF_Output output) { - return TF_GetHandleShapeAndType(graph, output); + Node* node = &output.oper->node; + tensorflow::CppShapeInferenceResult::HandleData handle_data; + handle_data.set_is_set(true); + { + mutex_lock l(graph->mu); + tensorflow::shape_inference::InferenceContext* ic = + graph->refiner.GetContext(node); + CHECK(ic != nullptr); + CHECK_LT(output.index, ic->num_outputs()); + const auto* shapes_and_types = + ic->output_handle_shapes_and_types(output.index); + if (shapes_and_types == nullptr) return ""; + + for (const auto& p : *shapes_and_types) { + auto* out_shape_and_type = handle_data.add_shape_and_type(); + ic->ShapeHandleToProto(p.shape, out_shape_and_type->mutable_shape()); + out_shape_and_type->set_dtype(p.dtype); + *out_shape_and_type->mutable_type() = p.type; + } + } + string result; + handle_data.SerializeToString(&result); + return result; } void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto, From 4428bb63bb42f1145b8bec9420b62bf59dd1e1d0 Mon Sep 17 00:00:00 2001 From: Mahmoud Abuzaina Date: Fri, 23 Jun 2023 18:34:15 -0700 Subject: [PATCH 024/376] Fixed win failures --- tensorflow/core/kernels/mkl/mkl_conv_ops.cc | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/tensorflow/core/kernels/mkl/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl/mkl_conv_ops.cc index 274bd0e378e45a..f13ecbeb6613ed 100644 --- a/tensorflow/core/kernels/mkl/mkl_conv_ops.cc +++ b/tensorflow/core/kernels/mkl/mkl_conv_ops.cc @@ -49,6 +49,7 @@ namespace tensorflow { #define SET_FUSE_ACTIVATION_FOR_RELU6 \ set_fuse_activation(true, dnnl::algorithm::eltwise_bounded_relu, 6.0) #define SET_MKL_LAYOUT(md) SetMklLayout(&md) +#define OUTPUT_SCALE_DCHECK (post_op_param.name == "output_scale") #define TSCALED_BIAS Tbias #define SCALE scales #define SUMMAND_SCALE_U8(summand_range, output_range) \ @@ -64,6 +65,10 @@ namespace tensorflow { #define SET_FUSE_ACTIVATION_FOR_RELU6 \ set_fuse_activation(true, dnnl::algorithm::eltwise_clip, 0.0, 6.0) #define SET_MKL_LAYOUT(md) SetMklLayout(md) +#define OUTPUT_SCALE_DCHECK \ + (post_op_param.name == "src_scale") || \ + (post_op_param.name == "wei_scale") || \ + (post_op_param.name == "dst_scale") #define TSCALED_BIAS float #define SCALE wei_scale #define SUMMAND_SCALE_U8(summand_range, output_range) summand_range / 255.0f @@ -476,14 +481,7 @@ class MklConvFwdPrimitive : public MklPrimitive { *context_.bn_offset_md); } else { DCHECK((post_op_param.name == "activation") || - (post_op_param.name == "sum") || -#ifndef ENABLE_ONEDNN_V3 - (post_op_param.name == "output_scale") || -#else - (post_op_param.name == "src_scale") || - (post_op_param.name == "wei_scale") || - (post_op_param.name == "dst_scale") || -#endif // !ENABLE_ONEDNN_V3 + (post_op_param.name == "sum") || OUTPUT_SCALE_DCHECK || (post_op_param.name == "fuse_bn")); } } @@ -2576,7 +2574,9 @@ class MklQuantizedConvOp scaled_bias_->set_data_handle(scaled_bias_buf_); } std::unique_ptr scale_mem( - new memory({{depth}, MklDnnType(), memory::format_tag::x}, + new memory({{static_cast(depth)}, + MklDnnType(), + memory::format_tag::x}, this->cpu_engine_, scales_.data())); auto reorder_desc = ReorderPd(this->cpu_engine_, input_bias_->get_desc(), @@ -3200,6 +3200,7 @@ REGISTER_KERNEL_BUILDER( #undef GET_DATA_TYPE #undef SET_FUSE_ACTIVATION_FOR_RELU6 #undef SET_MKL_LAYOUT +#undef OUTPUT_SCALE_DCHECK #undef TSCALED_BIAS #undef SCALE #undef SUMMAND_SCALE_U8 From 7b3c565b745560ace2c7b569e787c01fb1464da4 Mon Sep 17 00:00:00 2001 From: Crefeda Rodrigues Date: Tue, 27 Jun 2023 15:19:24 +0000 Subject: [PATCH 025/376] Update oneDNN reorder Signed-off-by: Crefeda Rodrigues --- tensorflow/workspace2.bzl | 1 + .../mkl_dnn/onednn_reorder_padded.patch | 858 ++++++++++++++++++ 2 files changed, 859 insertions(+) create mode 100644 third_party/mkl_dnn/onednn_reorder_padded.patch diff --git a/tensorflow/workspace2.bzl b/tensorflow/workspace2.bzl index d7708712ef6800..fa9a9dc04af216 100644 --- a/tensorflow/workspace2.bzl +++ b/tensorflow/workspace2.bzl @@ -207,6 +207,7 @@ def _tf_repositories(): "//third_party/mkl_dnn:onednn_acl_fixed_format_kernels.patch", "//third_party/mkl_dnn:onednn_acl_depthwise_convolution.patch", "//third_party/mkl_dnn:onednn_acl_threadpool_scheduler.patch", + "//third_party/mkl_dnn:onednn_reorder_padded.patch", ], sha256 = "a50993aa6265b799b040fe745e0010502f9f7103cc53a9525d59646aef006633", strip_prefix = "oneDNN-2.7.3", diff --git a/third_party/mkl_dnn/onednn_reorder_padded.patch b/third_party/mkl_dnn/onednn_reorder_padded.patch new file mode 100644 index 00000000000000..f290f21ec87e9b --- /dev/null +++ b/third_party/mkl_dnn/onednn_reorder_padded.patch @@ -0,0 +1,858 @@ + ******************************************************************************* + Copyright 2022 Arm Limited and affiliates. + SPDX-License-Identifier: Apache-2.0 + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ******************************************************************************* + +diff --git a/src/cpu/aarch64/jit_uni_reorder.cpp b/src/cpu/aarch64/jit_uni_reorder.cpp +index 24d6220cf..a6cefaa20 100644 +--- a/src/cpu/aarch64/jit_uni_reorder.cpp ++++ b/src/cpu/aarch64/jit_uni_reorder.cpp +@@ -1,6 +1,7 @@ + /******************************************************************************* + * Copyright 2018-2021 Intel Corporation + * Copyright 2020-2021 FUJITSU LIMITED ++* Copyright 2022 Arm Ltd. and affiliates + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. +@@ -54,6 +55,35 @@ namespace aarch64 { + + namespace tr { + ++static bool prb_has_small_strides(const prb_t &prb) { ++ constexpr ptrdiff_t max_stride = (1LL << 31) - 1; ++ for (int d = 0; d < prb.ndims; ++d) { ++ const ptrdiff_t cms = max_stride / prb.nodes[d].n; ++ const bool small_strides = true ++ && prb.nodes[d].is < cms / (int)data_type_size(prb.itype) ++ && prb.nodes[d].os < cms / (int)data_type_size(prb.otype); ++ if (!small_strides) return false; ++ } ++ return true; ++} ++ ++static bool prb_tail_friendly(const prb_t &prb) { ++ /* find optimal ndims to makes it easier to ++ * identify the blk_chunk in the loop*/ ++ int ndims = prb.full_ndims - prb.ndims; ++ ++ int n = prb.nodes[0].is; ++ for (int d = 1; d < prb.ndims; ++d) { ++ if (d != prb.blk_chunk_idx) n *= prb.nodes[d].n; ++ } ++ if (prb.ip_tail > 0 ++ && ((ndims == 0 && n != 1) ++ || (ndims > 0 && prb.ndims > prb.blk_chunk_idx))) ++ return false; ++ ++ return true; ++} ++ + /** Minimal reasonable/desirable kernel size. + * The constant might be used to determine how a problem should be split + * between kernel and threading driver. */ +@@ -121,18 +151,10 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { + && utils::one_of(p.otype, f32, s32, data_type::s8, u8) + && utils::everyone_is(0, p.ioff, p.ooff) /* do we need this? */ + && utils::one_of(p.beta, 0.f, 1.f) /* anything else? */ +- && simple_impl_desc_init(p, nullptr); ++ && simple_impl_desc_init(p, nullptr) && prb_has_small_strides(p) ++ && prb_tail_friendly(p); + if (!ok) return false; + +- const ptrdiff_t max_stride = (1LL << 31) - 1; +- for (int d = 0; d < p.ndims; ++d) { +- const ptrdiff_t cms = max_stride / p.nodes[d].n; +- bool strides_ok = true +- && p.nodes[d].is < cms / (int)data_type_size(p.itype) +- && p.nodes[d].os < cms / (int)data_type_size(p.otype); +- if (!strides_ok) return false; +- } +- + return true; + } + +@@ -153,6 +175,13 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { + return (int)prb_.nodes[d].ss; + } + ++ int blk_cnt() { ++ assert(prb_.blk_chunk_idx < prb_.full_ndims); ++ return (int)prb_.nodes[prb_.blk_chunk_idx].n - 1; ++ } ++ int op_padding() { return prb_.op_tail ? prb_.iblock - prb_.op_tail : 0; } ++ int ip_padding() { return prb_.ip_tail ? prb_.oblock - prb_.ip_tail : 0; } ++ + void step(int off, int prev_i_off, int prev_o_off, int prev_s_off, + int &i_off, int &o_off, int &s_off, int step_size = 1) { + i_off = prev_i_off; +@@ -385,6 +414,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { + prb_.otype, u8, data_type::s8, s32, f32))) + && utils::everyone_is(8, n(0), n(1)) + && utils::everyone_is(1, os(0), is(1)) ++ && utils::everyone_is(0, prb_.ip_tail, prb_.op_tail) + && prb_.scale_type == scale_type_t::NONE && prb_.beta == 0.f; + } + +@@ -405,17 +435,14 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { + bool process_direct_copy(int len) { + using namespace data_type; + +- const int simd_w = cpu_isa_traits::vlen == 16 +- ? cpu_isa_traits::vlen / itype_sz /* use 128-bit VReg */ +- : cpu_isa_traits::vlen / itype_sz +- / 2; /* use lower half of 512-bit ZReg */ +- ++ const int simd_w = cpu_isa_traits::vlen / itype_sz; + bool can_do = true && mayiuse(isa) + && utils::everyone_is(1, os(0), is(0)) + && (false || prb_.itype == prb_.otype + || (prb_.itype == s32 && prb_.otype == f32) + || (prb_.itype == f32 && prb_.otype == s32)) + && len % simd_w == 0 && n(0) % len == 0 ++ && prb_.ip_tail % simd_w == 0 && prb_.op_tail % simd_w == 0 + && prb_.scale_type == scale_type_t::NONE && prb_.beta == 0.f; + if (!can_do) return false; + +@@ -511,7 +538,8 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { + } + + void process_unroll_generic_step(int reg_unroll, const int *i_off, +- const int *o_off, const int *s_off) { ++ const int *o_off, const int *s_off, const int *ip_padding, ++ const bool h_padded) { + using namespace data_type; + + auto cvt2ps +@@ -571,6 +599,8 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { + for (int ur = 1; ur < reg_unroll; ++ur) + if (o_off[ur] != o_off[ur - 1] + 1) can_store_xmm = false; + const int ur_step = can_store_xmm ? 4 : 1; ++ const int load_tail_step ++ = !can_load_xmm && can_store_xmm ? ur_step : load_step; + + const bool interim_f32 = false + || utils::one_of(f32, prb_.itype, prb_.otype) +@@ -579,55 +609,85 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { + const bool need_saturation + = (utils::one_of(prb_.otype, u8, data_type::s8, s32) + && interim_f32); +- +- if (!can_load_xmm && can_store_xmm) { +- assert(ur_step == 4); +- /* load with stride */ +- for (int ur = 0; ur < reg_unroll; ur += ur_step) { +- ++ if (h_padded) { ++ for (int ur = 0; ur < reg_unroll; ur += load_tail_step) { ++ if (itype_sz == 4) ++ movi(VReg4S(ur), 0); ++ else if (itype_sz == 2) ++ movi(VReg8H(ur), 0); ++ else ++ movi(VReg16B(ur), 0); + /* x_tmp_vec = X_TMP_0 - X_TMP_4 + Do not use X_TMP_? as the last arg. */ +- for (int r = 0; r < ur_step; ++r) { +- add_imm(x_tmp_vec[r], x_ptr_in_off, +- i_off[ur + r] * itype_sz, X_DEFAULT_ADDR); ++ for (int r = 0; r < load_tail_step; ++r) { ++ if (ip_padding[ur + r] == 0) { ++ add_imm(x_tmp_vec[r], x_ptr_in_off, ++ i_off[ur + r] * itype_sz, X_DEFAULT_ADDR); ++ } + } + +- for (int r = 0; r < ur_step; ++r) { +- if (itype_sz == 4) +- ld1(VReg4S(ur)[r], ptr(x_tmp_vec[r])); +- else if (itype_sz == 2) +- ld1(VReg8H(ur)[r], ptr(x_tmp_vec[r])); +- else +- ld1(VReg16B(ur)[r], ptr(x_tmp_vec[r])); ++ for (int r = 0; r < load_tail_step; ++r) { ++ if (ip_padding[ur + r] == 0) { ++ if (itype_sz == 4) ++ ld1(VReg4S(ur)[r], ptr(x_tmp_vec[r])); ++ else if (itype_sz == 2) ++ ld1(VReg8H(ur)[r], ptr(x_tmp_vec[r])); ++ else ++ ld1(VReg16B(ur)[r], ptr(x_tmp_vec[r])); ++ } + } + } + } else { +- int ur = 0; +- int tmp_ur = 0; +- while (ur < reg_unroll) { +- int count = 0; ++ if (!can_load_xmm && can_store_xmm) { ++ assert(ur_step == 4); ++ /* load with stride */ ++ for (int ur = 0; ur < reg_unroll; ur += ur_step) { + +- do { +- add_imm(x_tmp_vec[count++], x_ptr_in_off, +- i_off[ur] * itype_sz, X_DEFAULT_ADDR); +- ur += load_step; +- } while (ur < reg_unroll && count < x_tmp_vec_size); ++ /* x_tmp_vec = X_TMP_0 - X_TMP_4 ++ Do not use X_TMP_? as the last arg. */ ++ for (int r = 0; r < ur_step; ++r) { ++ add_imm(x_tmp_vec[r], x_ptr_in_off, ++ i_off[ur + r] * itype_sz, X_DEFAULT_ADDR); ++ } + +- for (int i = 0; i < count; i++) { ++ for (int r = 0; r < ur_step; ++r) { ++ if (itype_sz == 4) ++ ld1(VReg4S(ur)[r], ptr(x_tmp_vec[r])); ++ else if (itype_sz == 2) ++ ld1(VReg8H(ur)[r], ptr(x_tmp_vec[r])); ++ else ++ ld1(VReg16B(ur)[r], ptr(x_tmp_vec[r])); ++ } ++ } ++ } else { ++ int ur = 0; ++ int tmp_ur = 0; ++ while (ur < reg_unroll) { ++ int count = 0; ++ ++ do { ++ add_imm(x_tmp_vec[count++], x_ptr_in_off, ++ i_off[ur] * itype_sz, X_DEFAULT_ADDR); ++ ur += load_step; ++ } while (ur < reg_unroll && count < x_tmp_vec_size); ++ ++ for (int i = 0; i < count; i++) { + +- switch (load_step * itype_sz) { +- case 16: ldr(QReg(tmp_ur), ptr(x_tmp_vec[i])); break; +- case 8: ldr(DReg(tmp_ur), ptr(x_tmp_vec[i])); break; +- case 4: ldr(SReg(tmp_ur), ptr(x_tmp_vec[i])); break; +- case 2: ldr(HReg(tmp_ur), ptr(x_tmp_vec[i])); break; +- case 1: ldr(BReg(tmp_ur), ptr(x_tmp_vec[i])); break; +- default: assert(!"unreachable"); ++ switch (load_step * itype_sz) { ++ case 16: ++ ldr(QReg(tmp_ur), ptr(x_tmp_vec[i])); ++ break; ++ case 8: ldr(DReg(tmp_ur), ptr(x_tmp_vec[i])); break; ++ case 4: ldr(SReg(tmp_ur), ptr(x_tmp_vec[i])); break; ++ case 2: ldr(HReg(tmp_ur), ptr(x_tmp_vec[i])); break; ++ case 1: ldr(BReg(tmp_ur), ptr(x_tmp_vec[i])); break; ++ default: assert(!"unreachable"); ++ } ++ tmp_ur += load_step; + } +- tmp_ur += load_step; + } + } + } +- + /* xmm[:] <-- (f32)xmm[:] */ + if (interim_f32) { + const int cvt_step = nstl::max(load_step, ur_step); +@@ -708,7 +768,8 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { + if (s_off[r] != s_off[r - 1] + 0) + scale_load_type = scale_load_type_t::load; + +- if (scale_load_type == scale_load_type_t::bcast) { ++ if (scale_load_type == scale_load_type_t::bcast ++ && !h_padded) { + VReg4S v(xmm_scale.getIdx()); + VReg4S v_dst(ur); + add_imm(X_TMP_0, x_ptr_scale_off, s_off[ur] * stype_sz, +@@ -724,7 +785,8 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { + if (s_off[r] != s_off[r - 1] + 1) + scale_load_type = scale_load_type_t::gather; + +- if (scale_load_type == scale_load_type_t::load) { ++ if (scale_load_type == scale_load_type_t::load ++ && !h_padded) { + uint32_t idx = xmm_scale.getIdx(); + VReg4S v_dst(ur); + add_imm(X_TMP_0, x_ptr_scale_off, s_off[ur] * stype_sz, +@@ -739,14 +801,18 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { + // so gather the scale factors one by one + /*ur_step is 1 or 4. */ + for (int r = ur; r < ur + ur_step; ++r) { +- /* x_tmp_vec = X_TMP_0 - X_TMP_4 ++ if (ip_padding[r] == 0 || !h_padded) { ++ /* x_tmp_vec = X_TMP_0 - X_TMP_4 + Do not use X_TMP_? as the last arg. */ +- add_imm(x_tmp_vec[r - ur], x_ptr_scale_off, +- s_off[r] * stype_sz, X_DEFAULT_ADDR); ++ add_imm(x_tmp_vec[r - ur], x_ptr_scale_off, ++ s_off[r] * stype_sz, X_DEFAULT_ADDR); ++ } + } + for (int r = ur; r < ur + ur_step; ++r) { +- VReg4S v(xmm_scale.getIdx()); +- ld1(v[r - ur], ptr(x_tmp_vec[r - ur])); ++ if (ip_padding[r] == 0 || !h_padded) { ++ VReg4S v(xmm_scale.getIdx()); ++ ld1(v[r - ur], ptr(x_tmp_vec[r - ur])); ++ } + } + fmul(VReg4S(ur), VReg4S(ur), xmm_scale); + } +@@ -925,7 +991,15 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { + } + } + +- void process_unroll_generic(int len) { ++ void comp_padding_flag(int ndims, int off, int len, int &i_tail) { ++ const int ip_without_padding ++ = ndims == 0 ? len - ip_padding() : prb_.ip_tail; ++ if ((ndims == 0 && off >= ip_without_padding) ++ || (ndims > 0 && (off % prb_.oblock) >= ip_without_padding)) ++ i_tail = 1; ++ } ++ ++ void process_unroll_generic(const int ndims, int len, const bool h_padded) { + const int blk = 8; + + int i_off[2 * blk] = {0}; +@@ -936,22 +1010,37 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { + + for (int off = 0; off < len; off += blk) { + const int reg_unroll = nstl::min(off + blk, len) - off; ++ int ip_padding[blk] = {0}; + +- /* compute offsets */ ++ /* compute offsets and tail*/ + for (int ur = off != 0 ? 0 : 1; ur < reg_unroll; ++ur) { + const int ur_c = curr * blk + ur; + const int ur_p = (ur_c - 1 + 2 * blk) % (2 * blk); // prev ur + step(off + ur, i_off[ur_p], o_off[ur_p], s_off[ur_p], + i_off[ur_c], o_off[ur_c], s_off[ur_c]); ++ if (h_padded) ++ comp_padding_flag(ndims, off + ur, len, ip_padding[ur]); + } +- + process_unroll_generic_step(reg_unroll, i_off + curr * blk, +- o_off + curr * blk, s_off + curr * blk); ++ o_off + curr * blk, s_off + curr * blk, ip_padding, ++ h_padded); + + curr = 1 - curr; + } + } + ++ void compute_ker( ++ const int ndims, const int len_unroll, const bool h_padded) { ++ bool optimized = false; ++ optimized = optimized ++ || (process_direct_copy(len_unroll) && !h_padded); ++ optimized = optimized ++ || (process_direct_copy(len_unroll) && !h_padded); ++ optimized ++ = optimized || (process_unroll_tr8x8(len_unroll) && !h_padded); ++ if (!optimized) process_unroll_generic(ndims, len_unroll, h_padded); ++ } ++ + void loop_begin(Label &l, XReg reg_cnt, int len) { + mov(reg_cnt, len); + L(l); +@@ -985,6 +1074,28 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { + } + } + ++ void compute_blk_ker(const int len_unroll) { ++ int omp_ndims = prb_.full_ndims - prb_.ndims; ++ Label no_last_blk, end_label; ++ ++ if (prb_.ip_tail > 0 && prb_.op_tail == 0) { ++ if (omp_ndims == 0) { ++ cmp(reg_last_loop_cnt, 1); ++ bne(no_last_blk); ++ compute_ker(omp_ndims, len_unroll, true); ++ } else { ++ cmp(reg_blk_chunks, blk_cnt()); ++ bne(no_last_blk); ++ compute_ker(omp_ndims, len_unroll, true); ++ } ++ b(end_label); ++ } ++ ++ L(no_last_blk); ++ compute_ker(omp_ndims, len_unroll, false); ++ L(end_label); ++ } ++ + bool simple_impl() { + simple_impl_desc_t d; + if (!simple_impl_desc_init(prb_, &d)) return false; +@@ -1013,11 +1124,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { + if (n_jit_loops > 0) + loop_begin(l_loop[0], reg_cnt[0], n(nfu + 0) / ldu); + +- bool optimized = false; +- optimized = optimized || process_direct_copy(d.len_unroll); +- optimized = optimized || process_direct_copy(d.len_unroll); +- optimized = optimized || process_unroll_tr8x8(d.len_unroll); +- if (!optimized) process_unroll_generic(d.len_unroll); ++ compute_blk_ker(d.len_unroll); + + if (n_jit_loops > 0) + loop_end(l_loop[0], reg_cnt[0], n(nfu + 0) / ldu, is(nfu + 0) * ldu, +@@ -1236,9 +1343,13 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { + } + add_imm(X_TMP_0, abi_param1, PARAM(in), X_TMP_2); + add_imm(X_TMP_1, abi_param1, PARAM(out), X_TMP_2); ++ add_imm(reg_blk, abi_param1, PARAM(blk_chunks), reg_blk); + ldr(reg_ptr_in, ptr(X_TMP_0)); + ldr(reg_ptr_out, ptr(X_TMP_1)); ++ ldr(reg_blk_chunks, ptr(reg_blk)); ++ + #undef PARAM ++ mov_imm(reg_last_loop_cnt, 1); + + mov(x_ptr_in_off, XReg(reg_ptr_in.getIdx())); + mov(x_ptr_out_off, XReg(reg_ptr_out.getIdx())); +@@ -1282,6 +1393,10 @@ private: + XReg reg_off_out = x9; + XReg reg_off_scale = x10; + ++ XReg reg_blk = x11; ++ XReg reg_blk_chunks = x12; ++ XReg reg_last_loop_cnt = x11; ++ + XReg reg_tmp = x0; + + VReg4S xmm_scale = v15.s; +@@ -1416,10 +1531,16 @@ static void prb_thread_kernel_balance( + for (int d = 0; d < prb.ndims; ++d) + sz_total *= prb.nodes[d].n; + ++ /* The general expression for sz_drv_thr can be written as ++ * sz_drv_min = C0 + FC * (nthr > 1 ? 1 : 0) + VC * (nthr - 1) ++ * where FC and VC are fixed and variable costs respectively. ++ * Though for now, the below heuristic seems to be good enough */ ++ const size_t sz_drv_thr = (nthr > 1) ? 16 * nthr : 1; ++ + /* sz_drv_min is the minimal size for the parallel + * driver required for good parallelization */ + const size_t sz_drv_min +- = nstl::min(16 * nthr, utils::div_up(sz_total, 1024)); ++ = nstl::min(sz_drv_thr, utils::div_up(sz_total, 1024)); + + /* kdims -- # of dimensions processed by a kernel + * sz_ker_cur -- product of the dimension processed by a kernel +@@ -1440,7 +1561,8 @@ static void prb_thread_kernel_balance( + * (less than tr::ker_prb_size_min). In that case try to split the + * innermost driver dimension into two, to increase sz_ker_cur. */ + bool want_borrow_ker_from_drv = true && kdims < prb.ndims +- && sz_ker_cur < tr::ker_prb_size_min && sz_drv_cur > sz_drv_min; ++ && sz_ker_cur < tr::ker_prb_size_min && sz_drv_cur > sz_drv_min ++ && kdims != prb.blk_chunk_idx; + if (want_borrow_ker_from_drv) { + /* sz_want_borrow is the minimal sz, so that: + * o) sz_ker_cur * sz_want_borrow >= tr::ker_prb_size_min +@@ -1464,7 +1586,7 @@ static void prb_thread_kernel_balance( + * try to split the outermost kernel dimension into two, to increase + * sz_drv_cur. */ + bool want_borrow_drv_from_ker = true && sz_ker_cur > tr::ker_prb_size_min +- && sz_drv_cur < sz_drv_min; ++ && sz_drv_cur < sz_drv_min && kdims != prb.blk_chunk_idx; + if (want_borrow_drv_from_ker) { + size_t sz_want_borrow = utils::div_up(sz_drv_min, sz_drv_cur); + for (; prb.nodes[kdims - 1].n % sz_want_borrow; ++sz_want_borrow) +@@ -1518,6 +1640,8 @@ status_t jit_uni_reorder_t::pd_t::create(reorder_pd_t **reorder_pd, + prb_dump(prb); + }); + ++ CHECK(prb_check_blk(prb, *dst_md)); ++ + int ndims_ker_max; + int nthr = dnnl_get_max_threads(); + prb_thread_kernel_balance(prb, ndims_ker_max, nthr); +@@ -1552,7 +1676,7 @@ status_t jit_uni_reorder_t::pd_t::create(reorder_pd_t **reorder_pd, + + void jit_uni_reorder_t::omp_driver_0d( + int off, const char *in, char *out, const float *scale) const { +- tr::call_param_t c {in, out, scale}; ++ tr::call_param_t c {in, out, scale, 0}; + (*kernel_)(&c); + } + +@@ -1564,6 +1688,7 @@ void jit_uni_reorder_t::omp_driver_1d(int ithr, int nthr, int off, + c.in = in + d0 * ns[0].is * data_type_size(pd()->prb_.itype); + c.out = out + d0 * ns[0].os * data_type_size(pd()->prb_.otype); + c.scale = scale + d0 * ns[0].ss; ++ c.blk_chunks = d0; + (*kernel_)(&c); + }); + } +@@ -1571,6 +1696,7 @@ void jit_uni_reorder_t::omp_driver_1d(int ithr, int nthr, int off, + void jit_uni_reorder_t::omp_driver_2d(int ithr, int nthr, int off, + const char *in, char *out, const float *scale) const { + const tr::node_t *ns = pd()->prb_.nodes + off; ++ const int blk_idx_off = pd()->prb_.blk_chunk_idx - off; + for_nd(ithr, nthr, (ptrdiff_t)ns[1].n, (ptrdiff_t)ns[0].n, + [&](ptrdiff_t d1, ptrdiff_t d0) { + auto c = tr::call_param_t(); +@@ -1581,6 +1707,7 @@ void jit_uni_reorder_t::omp_driver_2d(int ithr, int nthr, int off, + + (d0 * ns[0].os + d1 * ns[1].os) + * data_type_size(pd()->prb_.otype); + c.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss; ++ c.blk_chunks = utils::pick(blk_idx_off, d0, d1); + (*kernel_)(&c); + }); + } +@@ -1588,6 +1715,7 @@ void jit_uni_reorder_t::omp_driver_2d(int ithr, int nthr, int off, + void jit_uni_reorder_t::omp_driver_3d(int ithr, int nthr, int off, + const char *in, char *out, const float *scale) const { + const tr::node_t *ns = pd()->prb_.nodes + off; ++ const int blk_idx_off = pd()->prb_.blk_chunk_idx - off; + for_nd(ithr, nthr, (ptrdiff_t)ns[2].n, (ptrdiff_t)ns[1].n, + (ptrdiff_t)ns[0].n, [&](ptrdiff_t d2, ptrdiff_t d1, ptrdiff_t d0) { + auto c = tr::call_param_t(); +@@ -1598,6 +1726,7 @@ void jit_uni_reorder_t::omp_driver_3d(int ithr, int nthr, int off, + + (d0 * ns[0].os + d1 * ns[1].os + d2 * ns[2].os) + * data_type_size(pd()->prb_.otype); + c.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss + d2 * ns[2].ss; ++ c.blk_chunks = utils::pick(blk_idx_off, d0, d1, d2); + (*kernel_)(&c); + }); + } +@@ -1605,6 +1734,7 @@ void jit_uni_reorder_t::omp_driver_3d(int ithr, int nthr, int off, + void jit_uni_reorder_t::omp_driver_4d(int ithr, int nthr, int off, + const char *in, char *out, const float *scale) const { + const tr::node_t *ns = pd()->prb_.nodes + off; ++ const int blk_idx_off = pd()->prb_.blk_chunk_idx - off; + for_nd(ithr, nthr, (ptrdiff_t)ns[3].n, (ptrdiff_t)ns[2].n, + (ptrdiff_t)ns[1].n, (ptrdiff_t)ns[0].n, + [&](ptrdiff_t d3, ptrdiff_t d2, ptrdiff_t d1, ptrdiff_t d0) { +@@ -1619,6 +1749,7 @@ void jit_uni_reorder_t::omp_driver_4d(int ithr, int nthr, int off, + * data_type_size(pd()->prb_.otype); + c.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss + d2 * ns[2].ss + + d3 * ns[3].ss; ++ c.blk_chunks = utils::pick(blk_idx_off, d0, d1, d2, d3); + (*kernel_)(&c); + }); + } +diff --git a/src/cpu/aarch64/jit_uni_reorder.hpp b/src/cpu/aarch64/jit_uni_reorder.hpp +index 88762756c..2fb6f0f89 100644 +--- a/src/cpu/aarch64/jit_uni_reorder.hpp ++++ b/src/cpu/aarch64/jit_uni_reorder.hpp +@@ -1,6 +1,7 @@ + /******************************************************************************* + * Copyright 2018-2020 Intel Corporation + * Copyright 2020 FUJITSU LIMITED ++* Copyright 2022 Arm Ltd. and affiliates + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. +@@ -52,11 +53,19 @@ struct prb_t { + ptrdiff_t ooff; + scale_type_t scale_type; + float beta; ++ int full_ndims; ++ int ip_tail; ++ int op_tail; ++ int iblock; ++ int oblock; ++ int blk_chunk_idx; + }; + + status_t prb_init(prb_t &prb, const memory_desc_t &imd, + const memory_desc_t &omd, const primitive_attr_t *attr); + ++status_t prb_check_blk(prb_t &prb, const memory_desc_t &imd); ++ + /** sorts the problem nodes so that output strides come in ascending order */ + void prb_normalize(prb_t &p); + +@@ -82,6 +91,7 @@ struct call_param_t { + const void *in; + void *out; + const float *scale; ++ size_t blk_chunks; + }; + + struct kernel_t { +diff --git a/src/cpu/aarch64/jit_uni_reorder_utils.cpp b/src/cpu/aarch64/jit_uni_reorder_utils.cpp +index 3d6e424e3..7123811f8 100644 +--- a/src/cpu/aarch64/jit_uni_reorder_utils.cpp ++++ b/src/cpu/aarch64/jit_uni_reorder_utils.cpp +@@ -1,6 +1,7 @@ + /******************************************************************************* + * Copyright 2018-2021 Intel Corporation + * Copyright 2020 FUJITSU LIMITED ++* Copyright 2022 Arm Ltd. and affiliates + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. +@@ -15,7 +16,8 @@ + * limitations under the License. + *******************************************************************************/ + +-#include ++#include ++#include + + #include "common/c_types_map.hpp" + #include "common/dnnl_thread.hpp" +@@ -46,8 +48,65 @@ struct layout_desc_t { + strides_t strides; + }; + +-status_t cvt_mem_desc_to_layout_desc( +- const memory_desc_t &md_, layout_desc_t &ld, const dims_t &blocks) { ++static status_t compute_blk_and_tail( ++ const memory_desc_t &md_, const int idx, int &blk, int &tail) { ++ const auto md = memory_desc_wrapper(md_); ++ const auto &bd = md.blocking_desc(); ++ if (tail == 0) return status::success; ++ ++ const std::set unique_inner_idxs( ++ bd.inner_idxs, bd.inner_idxs + bd.inner_nblks); ++ std::set dims_with_multiple_blks; ++ for (dim_t dim : unique_inner_idxs) { ++ if (std::count(bd.inner_idxs, bd.inner_idxs + bd.inner_nblks, dim) > 1) ++ dims_with_multiple_blks.insert(dim); ++ } ++ ++ // Dims that have a tail and have multiple blocks are not supported by the jit kernel yet. ++ // For example: ++ // src_tag = abcd ++ // dst_tag = ABcd16b16a4b ++ // 16x15x3x3 ++ // In this case, 'b' dim has two blocks and has a tail. It is not a supported case. ++ if (dims_with_multiple_blks.find(idx) != dims_with_multiple_blks.end()) ++ return status::unimplemented; ++ ++ // Only supports inconsistent padding in single and double blocks ++ // and the total block size <= 256 ++ for (int iblk = bd.inner_nblks - 1; iblk > 0; --iblk) { ++ if (bd.inner_idxs[iblk] == idx) break; ++ blk *= bd.inner_blks[iblk]; ++ tail *= bd.inner_blks[iblk]; ++ } ++ if (unique_inner_idxs.size() > 2 || blk > 256) return status::unimplemented; ++ ++ return status::success; ++} ++ ++static status_t compute_chunk_idx(const prb_t &p, const memory_desc_t &imd_, ++ const memory_desc_t &omd_, const int blk_idx, int &chunk_idx) { ++ const auto imd = memory_desc_wrapper(imd_); ++ const auto omd = memory_desc_wrapper(omd_); ++ const auto &ibd = imd.blocking_desc(); ++ const auto &obd = omd.blocking_desc(); ++ if (p.ip_tail == 0 && p.op_tail == 0) return status::success; ++ ++ const ptrdiff_t is ++ = ibd.strides[blk_idx] * obd.inner_blks[obd.inner_idxs[blk_idx]]; ++ const ptrdiff_t os = obd.strides[blk_idx]; ++ ++ for (int i = blk_idx; i < omd.ndims(); ++i) { ++ if (p.nodes[i].os == os && p.nodes[i].is == is) { ++ chunk_idx = i; ++ return status::success; ++ } ++ } ++ ++ return status::invalid_arguments; ++} ++ ++status_t cvt_mem_desc_to_layout_desc(const memory_desc_t &md_, ++ layout_desc_t &ld, const dims_t &blocks, const dims_t &ext_padding) { + const auto md = memory_desc_wrapper(md_); + + bool ok = true && md.is_blocking_desc() && md.extra().flags == 0; +@@ -75,7 +134,7 @@ status_t cvt_mem_desc_to_layout_desc( + stride *= bd.inner_blks[iblk]; + } + } +- P(d, md.padded_dims()[d] / blocks[d], bd.strides[d]); ++ P(d, (md.padded_dims()[d] + ext_padding[d]) / blocks[d], bd.strides[d]); + + // TODO: NOW: revisit, do we need a reverse? + // TODO: NOW: consider using strides instead of block sizes in md +@@ -98,7 +157,8 @@ status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd, + + auto check_post_ops = [](const primitive_attr_t *attr) { + const auto &po = attr->post_ops_; +- return po.len() == 0 || (po.len() == 1 && po.entry_[0].is_sum(false)); ++ return po.len() == 0 ++ || (po.len() == 1 && po.contain(primitive_kind::sum, 0)); + }; + + bool ok = im_d.is_blocking_desc() && om_d.is_blocking_desc() +@@ -110,26 +170,58 @@ status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd, + && check_post_ops(attr); + if (!ok) return unimplemented; + +- dims_t iblocks, oblocks; ++ dims_t iblocks, oblocks, ip_padding, op_padding; + im_d.compute_blocks(iblocks); + om_d.compute_blocks(oblocks); ++ utils::array_set(ip_padding, 0, im_d.ndims()); ++ utils::array_set(op_padding, 0, om_d.ndims()); ++ ++ /* padding_dim consistency check ++ * only supports inconsitent padding for src ++ * TODO: Add inconsistent padding support for dst */ ++ int ip_tail = 0; ++ int op_tail = 0; ++ int iblk_w_tail = 1; ++ int oblk_w_tail = 1; ++ int blk_idx = 0; + +- /* padding_dim consistency check */ + for (int d = 0; d < im_d.ndims(); ++d) { +- const auto pdim = im_d.padded_dims()[d]; +- bool ok = true && pdim == om_d.padded_dims()[d] +- && pdim % iblocks[d] == 0 && pdim % oblocks[d] == 0; +- if (!ok) return unimplemented; ++ const int ip_tmp_dim = im_d.padded_dims()[d]; ++ const int op_tmp_dim = om_d.padded_dims()[d]; ++ const int ip_tmp_tail = ip_tmp_dim % oblocks[d]; ++ const int op_tmp_tail = op_tmp_dim % iblocks[d]; ++ ++ const bool pdim_consistent = ip_tmp_dim == op_tmp_dim ++ && ip_tmp_tail == 0 && op_tmp_tail == 0; ++ const bool pdim_tail = ip_tmp_tail > 0 ++ && (ip_tmp_dim + oblocks[d] - ip_tmp_tail) == op_tmp_dim ++ && op_tmp_tail == 0 && ip_tail == 0; ++ if (!pdim_consistent && !pdim_tail) return status::unimplemented; ++ if (pdim_tail) { ++ blk_idx = d; ++ ip_tail = ip_tmp_tail; ++ op_tail = op_tmp_tail; ++ iblk_w_tail = iblocks[d]; ++ oblk_w_tail = oblocks[d]; ++ ip_padding[d] = oblocks[d] - ip_tmp_tail; ++ op_padding[d] = iblocks[d] - op_tmp_tail; ++ } + } ++ CHECK(compute_blk_and_tail(omd, blk_idx, oblk_w_tail, ip_tail)); + + layout_desc_t ild, old; +- status_t status = cvt_mem_desc_to_layout_desc(imd, ild, iblocks); ++ status_t status ++ = cvt_mem_desc_to_layout_desc(imd, ild, iblocks, ip_padding); + if (status != success) return status; +- status = cvt_mem_desc_to_layout_desc(omd, old, oblocks); ++ status = cvt_mem_desc_to_layout_desc(omd, old, oblocks, op_padding); + if (status != success) return status; + + p.itype = ild.dt; + p.otype = old.dt; ++ p.ip_tail = ip_tail; ++ p.op_tail = op_tail; ++ p.iblock = iblk_w_tail; ++ p.oblock = oblk_w_tail; + + p.scale_type = attr->output_scales_.has_default_values() + ? scale_type_t::NONE +@@ -156,7 +248,6 @@ status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd, + + while (i_pos < ild.ndims && o_pos < old.ndims) { + assert(ild.id[i_pos] == old.id[o_pos]); +- if (ild.id[i_pos] != old.id[o_pos]) return runtime_error; + + assert(ndims < max_ndims); + if (ndims == max_ndims) return runtime_error; +@@ -191,7 +282,12 @@ status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd, + ild.dims[i_pos] = factor; + } + } ++ int blk_chunk_idx = ndims; ++ CHECK(compute_chunk_idx(p, imd, omd, blk_idx, blk_chunk_idx)); ++ + p.ndims = ndims; ++ p.full_ndims = ndims; ++ p.blk_chunk_idx = blk_chunk_idx; + + p.ioff = memory_desc_wrapper(imd).offset0(); + p.ooff = memory_desc_wrapper(omd).offset0(); +@@ -211,8 +307,28 @@ void prb_normalize(prb_t &p) { + && p.nodes[j].n < p.nodes[min_pos].n); + if (new_min) min_pos = j; + } +- if (min_pos != d) nstl::swap(p.nodes[d], p.nodes[min_pos]); ++ if (min_pos != d) { ++ nstl::swap(p.nodes[d], p.nodes[min_pos]); ++ if (p.blk_chunk_idx == min_pos || p.blk_chunk_idx == d) ++ p.blk_chunk_idx = p.blk_chunk_idx == min_pos ? d : min_pos; ++ } ++ } ++} ++ ++status_t prb_check_blk(prb_t &p, const memory_desc_t &md_) { ++ const auto md = memory_desc_wrapper(md_); ++ const auto &bd = md.blocking_desc(); ++ if (p.ip_tail == 0) return status::success; ++ ++ // Check if the inner blocks and p.nodes[blk].n in the firsti nblks ++ // is equivalent in reverse order when has tail in block layout. ++ const int nblk = bd.inner_nblks; ++ for (int iblk = 0; iblk < nblk; ++iblk) { ++ if (bd.inner_blks[nblk - iblk - 1] ++ != static_cast(p.nodes[iblk].n)) ++ return status::unimplemented; + } ++ return status::success; + } + + void prb_simplify(prb_t &p) { +@@ -225,18 +341,29 @@ void prb_simplify(prb_t &p) { + for (int d = 0; d < p.ndims - 1; ++d) { + auto &this_node = p.nodes[d + 0]; + auto &next_node = p.nodes[d + 1]; ++ const bool skip_blk_idx = (p.ip_tail > 0 || p.op_tail > 0) ++ && (p.blk_chunk_idx == d || p.blk_chunk_idx == d + 1); + const bool fold = false +- || next_node.n == (size_t)1 // trivial case, just drop next node ++ || (next_node.n == static_cast(1) ++ && !skip_blk_idx) // trivial case, just drop next node + || (true // or real folding if possible +- && next_node.is == (ptrdiff_t)this_node.n * this_node.is +- && next_node.os == (ptrdiff_t)this_node.n * this_node.os ++ && !skip_blk_idx ++ && next_node.is ++ == static_cast( ++ this_node.n * this_node.is) ++ && next_node.os ++ == static_cast( ++ this_node.n * this_node.os) + && next_node.ss +- == (ptrdiff_t)this_node.n * this_node.ss); ++ == static_cast( ++ this_node.n * this_node.ss)); + if (fold) { + this_node.n *= next_node.n; + for (int j = d + 2; j < p.ndims; ++j) + p.nodes[j - 1] = p.nodes[j]; ++ if (d < p.blk_chunk_idx) --p.blk_chunk_idx; + --p.ndims; ++ --p.full_ndims; + --d; // make another try + } + } +@@ -251,6 +378,8 @@ void prb_node_split(prb_t &p, int dim, size_t n1) { + assert(p.nodes[dim].n % n1 == 0); + + p.ndims += 1; ++ p.full_ndims += 1; ++ if (dim < p.blk_chunk_idx) p.blk_chunk_idx += 1; + + for (int d = p.ndims; d > dim + 1; --d) + p.nodes[d] = p.nodes[d - 1]; From a071af0ca8a3ea1fe242094514507f649f3ab64e Mon Sep 17 00:00:00 2001 From: David Svantesson Date: Wed, 28 Jun 2023 09:47:55 +0000 Subject: [PATCH 026/376] Support for jit-ed block reorder on AArch64 --- tensorflow/workspace2.bzl | 1 + .../mkl_dnn/onednn_acl_reorder_update.patch | 4193 +++++++++++++++++ 2 files changed, 4194 insertions(+) create mode 100644 third_party/mkl_dnn/onednn_acl_reorder_update.patch diff --git a/tensorflow/workspace2.bzl b/tensorflow/workspace2.bzl index fa9a9dc04af216..f740842a64869b 100644 --- a/tensorflow/workspace2.bzl +++ b/tensorflow/workspace2.bzl @@ -208,6 +208,7 @@ def _tf_repositories(): "//third_party/mkl_dnn:onednn_acl_depthwise_convolution.patch", "//third_party/mkl_dnn:onednn_acl_threadpool_scheduler.patch", "//third_party/mkl_dnn:onednn_reorder_padded.patch", + "//third_party/mkl_dnn:onednn_acl_reorder_update.patch", ], sha256 = "a50993aa6265b799b040fe745e0010502f9f7103cc53a9525d59646aef006633", strip_prefix = "oneDNN-2.7.3", diff --git a/third_party/mkl_dnn/onednn_acl_reorder_update.patch b/third_party/mkl_dnn/onednn_acl_reorder_update.patch new file mode 100644 index 00000000000000..40d4cfa9fc2d11 --- /dev/null +++ b/third_party/mkl_dnn/onednn_acl_reorder_update.patch @@ -0,0 +1,4193 @@ +From b84c533dad4db495a92fc6d390a7db5ebd938a88 Mon Sep 17 00:00:00 2001 +From: Kentaro Kawakami +Date: Tue, 1 Nov 2022 09:33:41 +0900 +Subject: [PATCH] cpu: aarch64: reorder: support jit-ed blk_reorder + +--- + src/cpu/aarch64/jit_generator.hpp | 20 + + src/cpu/aarch64/jit_uni_reorder.cpp | 2315 +++++++++++++---- + src/cpu/aarch64/jit_uni_reorder.hpp | 183 +- + src/cpu/aarch64/jit_uni_reorder_utils.cpp | 482 ++-- + .../reorder/cpu_reorder_regular_f32_f32.cpp | 6 + + .../reorder/cpu_reorder_regular_f32_s32.cpp | 2 + + .../reorder/cpu_reorder_regular_f32_s8.cpp | 2 + + .../reorder/cpu_reorder_regular_f32_u8.cpp | 2 + + src/cpu/reorder/cpu_reorder_regular_s32.cpp | 2 + + src/cpu/reorder/cpu_reorder_regular_s8.cpp | 2 + + src/cpu/reorder/cpu_reorder_regular_u8.cpp | 2 + + 11 files changed, 2272 insertions(+), 746 deletions(-) + +diff --git a/src/cpu/aarch64/jit_generator.hpp b/src/cpu/aarch64/jit_generator.hpp +index dd781a622e1..12de9fa8c01 100644 +--- a/src/cpu/aarch64/jit_generator.hpp ++++ b/src/cpu/aarch64/jit_generator.hpp +@@ -435,6 +435,26 @@ class jit_generator : public Xbyak_aarch64::CodeGenerator, public c_compatible { + Xbyak_aarch64::ZRegD(z3.getIdx())); + } + ++ void uni_ld1rw(const Xbyak_aarch64::VReg4S &dst, ++ const Xbyak_aarch64::XReg &base, const int64_t off) { ++ if (off == 0) { ++ ld1r(dst, ptr(base)); ++ } else { ++ add_imm(X_DEFAULT_ADDR, base, off, X_TMP_0); ++ ld1r(dst, ptr(X_DEFAULT_ADDR)); ++ } ++ } ++ ++ void uni_ld1rw(const Xbyak_aarch64::ZRegS &dst, ++ const Xbyak_aarch64::XReg &base, const int64_t off) { ++ if (-32 <= off && off < 32) { ++ ld1rw(dst, P_ALL_ONE / Xbyak_aarch64::T_z, ptr(base, (int)off)); ++ } else { ++ add_imm(X_DEFAULT_ADDR, base, off, X_TMP_0); ++ ld1rw(dst, P_ALL_ONE / Xbyak_aarch64::T_z, ptr(X_DEFAULT_ADDR)); ++ } ++ } ++ + void uni_ldr( + const Xbyak_aarch64::VReg &dst, const Xbyak_aarch64::XReg &addr) { + ldr(Xbyak_aarch64::QReg(dst.getIdx()), ptr(addr)); +diff --git a/src/cpu/aarch64/jit_uni_reorder.cpp b/src/cpu/aarch64/jit_uni_reorder.cpp +index a6cefaa20e8..a708da808c0 100644 +--- a/src/cpu/aarch64/jit_uni_reorder.cpp ++++ b/src/cpu/aarch64/jit_uni_reorder.cpp +@@ -1,6 +1,6 @@ + /******************************************************************************* +-* Copyright 2018-2021 Intel Corporation +-* Copyright 2020-2021 FUJITSU LIMITED ++* Copyright 2018-2022 Intel Corporation ++* Copyright 2020-2022 FUJITSU LIMITED + * Copyright 2022 Arm Ltd. and affiliates + * + * Licensed under the Apache License, Version 2.0 (the "License"); +@@ -19,19 +19,21 @@ + #include + #include + +-#include "dnnl_debug.h" ++#include "oneapi/dnnl/dnnl_debug.h" + + #include "common/c_types_map.hpp" ++#include "common/dnnl_thread.hpp" + #include "common/memory_desc_wrapper.hpp" + #include "common/nstl.hpp" + #include "common/primitive.hpp" + #include "common/type_helpers.hpp" + #include "common/utils.hpp" + +-#include "cpu/aarch64/jit_uni_reorder.hpp" + #include "cpu/cpu_primitive.hpp" + #include "cpu/reorder/cpu_reorder_pd.hpp" + ++#include "cpu/aarch64/jit_uni_reorder.hpp" ++ + #include "cpu/aarch64/jit_generator.hpp" + + // #define TR_DEBUG +@@ -67,23 +69,6 @@ static bool prb_has_small_strides(const prb_t &prb) { + return true; + } + +-static bool prb_tail_friendly(const prb_t &prb) { +- /* find optimal ndims to makes it easier to +- * identify the blk_chunk in the loop*/ +- int ndims = prb.full_ndims - prb.ndims; +- +- int n = prb.nodes[0].is; +- for (int d = 1; d < prb.ndims; ++d) { +- if (d != prb.blk_chunk_idx) n *= prb.nodes[d].n; +- } +- if (prb.ip_tail > 0 +- && ((ndims == 0 && n != 1) +- || (ndims > 0 && prb.ndims > prb.blk_chunk_idx))) +- return false; +- +- return true; +-} +- + /** Minimal reasonable/desirable kernel size. + * The constant might be used to determine how a problem should be split + * between kernel and threading driver. */ +@@ -96,6 +81,9 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { + void operator()(const call_param_t *c) const override { + jit_generator::operator()(c); + } ++ void operator()(const tail_call_param_t *c) const override { ++ jit_generator::operator()(c); ++ } + + status_t create_kernel() override { return jit_generator::create_kernel(); } + +@@ -105,30 +93,53 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { + }; + + struct simple_impl_desc_t { +- int ndims_full_unroll; +- int len_last_dim_unroll; +- int len_unroll; ++ int ndims_full_unroll = 0; ++ int len_last_dim_unroll = 0; ++ int tail_len_unroll = 0; ++ int len_unroll = 0; + }; + ++#define PARAM(x) \ ++ abi_param1, \ ++ prb_.is_tail_present ? offsetof(tail_call_param_t, base_params) \ ++ + offsetof(call_param_t, x) \ ++ : offsetof(call_param_t, x) ++#define TAIL_PARAM(x) abi_param1, offsetof(tail_call_param_t, x) ++ + static bool simple_impl_desc_init( + const prb_t &prb, simple_impl_desc_t *desc) { + const int ndims = prb.ndims; + + int ndims_full_unroll = 0; + int len_last_dim_unroll = 1; ++ int tail_len_unroll = 0; + int len_unroll = 1; + +- for (int d = 0; d < ndims; ++d) { +- auto &node = prb.nodes[d]; +- if (len_unroll * node.n <= len_unroll_max) { +- ndims_full_unroll++; +- len_unroll *= node.n; +- } else { +- len_last_dim_unroll = len_unroll_max / len_unroll; +- while (node.n % len_last_dim_unroll) +- --len_last_dim_unroll; +- len_unroll *= len_last_dim_unroll; +- break; ++ // It is responsible for finding as many values ++ // as kernel can unroll. If tail is present then ++ // kernel will unroll only last node (possible improvement). ++ // If there is no tail kernel can unroll a few nodes without any loops etc. ++ // ndims_full_unroll - how many nodes will be unrolled ++ // len_last_dim_unroll - what piece of last unrolled node will be unrolled ++ if (prb.is_tail_present) { ++ ndims_full_unroll = 1; ++ len_unroll = prb.nodes[0].n; ++ tail_len_unroll = prb.nodes[0].is_zero_pad_needed ++ ? 0 ++ : static_cast(prb.nodes[0].tail_size); ++ } else { ++ for (int d = 0; d < ndims; ++d) { ++ const auto &node = prb.nodes[d]; ++ if (len_unroll * node.n <= len_unroll_max) { ++ ndims_full_unroll++; ++ len_unroll *= node.n; ++ } else { ++ len_last_dim_unroll = len_unroll_max / len_unroll; ++ while (node.n % len_last_dim_unroll) ++ --len_last_dim_unroll; ++ len_unroll *= len_last_dim_unroll; ++ break; ++ } + } + } + +@@ -137,6 +148,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { + if (desc) { + desc->ndims_full_unroll = ndims_full_unroll; + desc->len_last_dim_unroll = len_last_dim_unroll; ++ desc->tail_len_unroll = tail_len_unroll; + desc->len_unroll = len_unroll; + } + +@@ -151,62 +163,69 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { + && utils::one_of(p.otype, f32, s32, data_type::s8, u8) + && utils::everyone_is(0, p.ioff, p.ooff) /* do we need this? */ + && utils::one_of(p.beta, 0.f, 1.f) /* anything else? */ +- && simple_impl_desc_init(p, nullptr) && prb_has_small_strides(p) +- && prb_tail_friendly(p); +- if (!ok) return false; ++ && simple_impl_desc_init(p, nullptr) ++ && prb_has_small_strides(p); + +- return true; ++ return ok; + } + +- int n(int d) { +- assert(d < prb_.ndims); +- return (int)prb_.nodes[d].n; +- } +- int is(int d) { +- assert(d < prb_.ndims); +- return (int)prb_.nodes[d].is; +- } +- int os(int d) { +- assert(d < prb_.ndims); +- return (int)prb_.nodes[d].os; ++ XReg o_addr(int o_off, bool with_type_multiplier = true) { ++ if (o_off) { ++ add_imm(X_DEFAULT_ADDR, x_ptr_out_off, ++ o_off * (with_type_multiplier ? otype_sz_ : 1), X_TMP_0); ++ return X_DEFAULT_ADDR; ++ } ++ ++ return x_ptr_out_off; + } +- int ss(int d) { +- assert(d < prb_.ndims); +- return (int)prb_.nodes[d].ss; ++ ++ XReg c_addr(int c_off) { ++ if (c_off) { ++ add_imm(X_DEFAULT_ADDR, x_ptr_comp_off, c_off, X_TMP_0); ++ return X_DEFAULT_ADDR; ++ } ++ ++ return x_ptr_comp_off; + } + +- int blk_cnt() { +- assert(prb_.blk_chunk_idx < prb_.full_ndims); +- return (int)prb_.nodes[prb_.blk_chunk_idx].n - 1; ++ XReg data_chunk_addr(int node_id) { ++ add_imm(X_DEFAULT_ADDR, abi_param1, ++ offsetof(tail_call_param_t, curr_data_chunks) ++ + sizeof(int64_t) * (node_id), ++ X_TMP_0); ++ return X_DEFAULT_ADDR; + } +- int op_padding() { return prb_.op_tail ? prb_.iblock - prb_.op_tail : 0; } +- int ip_padding() { return prb_.ip_tail ? prb_.oblock - prb_.ip_tail : 0; } + + void step(int off, int prev_i_off, int prev_o_off, int prev_s_off, +- int &i_off, int &o_off, int &s_off, int step_size = 1) { ++ int prev_c_off, int &i_off, int &o_off, int &s_off, int &c_off, ++ int step_size = 1) { + i_off = prev_i_off; + o_off = prev_o_off; + s_off = prev_s_off; ++ c_off = prev_c_off; + + if (off == 0) return; + + int start_dim = 0, dims_prod = 1; + for (; start_dim < prb_.ndims && dims_prod != step_size; ++start_dim) +- dims_prod *= n(start_dim); ++ dims_prod *= prb_.n(start_dim); + assert(start_dim < prb_.ndims); + off /= step_size; + +- for (int d = start_dim; d < prb_.ndims; ++d) { +- i_off += is(d); +- o_off += os(d); +- s_off += ss(d); ++ for (int dim_id = start_dim; dim_id < prb_.ndims; ++dim_id) { ++ i_off += prb_.is(dim_id); ++ o_off += prb_.os(dim_id); ++ s_off += prb_.ss(dim_id); ++ c_off += prb_.cs(dim_id); ++ ++ if (off % prb_.n(dim_id)) break; + +- if (off % n(d)) break; ++ i_off += -prb_.n(dim_id) * prb_.is(dim_id); ++ o_off += -prb_.n(dim_id) * prb_.os(dim_id); ++ s_off += -prb_.n(dim_id) * prb_.ss(dim_id); ++ c_off += -prb_.n(dim_id) * prb_.cs(dim_id); + +- i_off += -n(d) * is(d); +- o_off += -n(d) * os(d); +- s_off += -n(d) * ss(d); +- off /= n(d); ++ off /= prb_.n(dim_id); + + if (off == 0) break; /* FIXME: is it really required? */ + } +@@ -215,8 +234,8 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { + void step(int off, int prev_i_off, int prev_o_off, int &i_off, int &o_off, + int step_size = 1) { + int dummy = 0; +- step(off, prev_i_off, prev_o_off, dummy, i_off, o_off, dummy, +- step_size); ++ step(off, prev_i_off, prev_o_off, dummy, dummy, i_off, o_off, dummy, ++ dummy, step_size); + } + + void tr8x8_sve256(int i_off, int o_off) { +@@ -278,40 +297,36 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { + && interim_f32); + const uint64_t sveLen = get_sve_length(); + +- add_imm(X_TMP_0, XReg(x_ptr_in_off), i_off * itype_sz, X_DEFAULT_ADDR); +- add_imm(X_TMP_1, X_TMP_0, is(0) * itype_sz, X_DEFAULT_ADDR); +- add_imm(X_TMP_2, X_TMP_1, is(0) * itype_sz, X_DEFAULT_ADDR); +- add_imm(X_TMP_3, X_TMP_2, is(0) * itype_sz, X_DEFAULT_ADDR); +- +- if (unroll * itype_sz == 32) +- for (uint32_t i = 0; i < 4; i++) +- ld1w(ZRegS {i}, p_lsb_256 / T_z, ptr(x_tmp_vec[i])); +- else if (unroll * itype_sz == 16) +- for (uint32_t i = 0; i < 4; i++) +- ldr(QReg {i}, ptr(x_tmp_vec[i])); +- else if (unroll * itype_sz == 8) +- for (uint32_t i = 0; i < 4; i++) +- ldr(DReg {i}, ptr(x_tmp_vec[i])); +- +- add_imm(X_TMP_0, X_TMP_3, is(0) * itype_sz, X_DEFAULT_ADDR); +- add_imm(X_TMP_1, X_TMP_0, is(0) * itype_sz, X_DEFAULT_ADDR); +- add_imm(X_TMP_2, X_TMP_1, is(0) * itype_sz, X_DEFAULT_ADDR); +- add_imm(X_TMP_3, X_TMP_2, is(0) * itype_sz, X_DEFAULT_ADDR); +- +- if (unroll * itype_sz == 32) +- for (uint32_t i = 0; i < 4; i++) +- ld1w(ZRegS {4 + i}, p_lsb_256 / T_z, ptr(x_tmp_vec[i])); +- else if (unroll * itype_sz == 16) +- for (uint32_t i = 0; i < 4; i++) +- ldr(QReg {4 + i}, ptr(x_tmp_vec[i])); +- else if (unroll * itype_sz == 8) +- for (uint32_t i = 0; i < 4; i++) +- ldr(DReg {4 + i}, ptr(x_tmp_vec[i])); ++ PReg p_size(DUMMY_IDX); ++ switch (unroll * itype_sz_) { ++ case 32: p_size = p_lsb_256; break; ++ case 16: p_size = p_lsb_128; break; ++ case 8: p_size = p_lsb_64; break; ++ default: assert(!"unreachable"); ++ } ++ ++ const int node_0_input_stride = prb_.is(0); ++ add_imm(X_TMP_0, XReg(x_ptr_in_off), itype_sz_ * i_off, X_DEFAULT_ADDR); ++ for (int i = 1; i < unroll / 2; i++) { ++ add_imm(x_tmp_vec[i], x_tmp_vec[i - 1], ++ itype_sz_ * node_0_input_stride, X_DEFAULT_ADDR); ++ } ++ ++ for (uint32_t i = 0; i < unroll / 2; i++) ++ ld1w(ZRegS {i}, p_size / T_z, ptr(x_tmp_vec[i])); ++ ++ for (int i = 0; i < unroll / 2; i++) { ++ add_imm(x_tmp_vec[i], x_tmp_vec[(i + 3) % 4], ++ itype_sz_ * node_0_input_stride, X_DEFAULT_ADDR); ++ } ++ ++ for (uint32_t i = 0; i < unroll / 2; i++) ++ ld1w(ZRegS {4 + i}, p_size / T_z, ptr(x_tmp_vec[i])); + + if (interim_f32) cvt2ps(0, unroll, prb_.itype); + + #if 0 +- /* Deubg code */ ++ /* Debug code */ + index(z0.s, 0, 1); + mov(z0.s, P_NOT_256/T_m, 0); + mov(z_tmp_vec[0].s, 16); +@@ -348,9 +363,9 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { + for (uint32_t i = 0; i < unroll / 2; i++) { + ZRegB z {unroll / 2 + i}; + ZRegB z_tmp = z_tmp_vec[unroll / 2 + i].b; +- /* Move bit 128-255 to 0-127. */ +- ext(z, z, 16); + /* Move bit 0-127 to 128-255. */ ++ ext(z, z, 16); ++ /* Move bit 128-255 to 0-127. */ + ext(z_tmp, z_tmp, sveLen - 16); + } + +@@ -363,65 +378,64 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { + } + + if (need_saturation) { +- init_saturate_f32(ymm_zero, ymm_saturation_ubound, reg_tmp, ++ init_saturate_f32(ymm_zero_, ymm_saturation_ubound_, reg_tmp_, + interim_f32 ? f32 : prb_.itype, prb_.otype); + for (int i = 0; i < unroll; i++) +- saturate_f32(ZRegS(i), ymm_zero, ymm_saturation_ubound, +- prb_.otype, p_all); ++ saturate_f32(ZRegS(i), ymm_zero_, ymm_saturation_ubound_, ++ prb_.otype, P_ALL_ONE); + } + + if (prb_.otype != f32) + cvt2odt(0, unroll, prb_.otype, interim_f32 ? f32 : prb_.itype); + +- add_imm(X_TMP_0, XReg(x_ptr_out_off), o_off * otype_sz, X_DEFAULT_ADDR); +- add_imm(X_TMP_1, X_TMP_0, os(1) * otype_sz, X_DEFAULT_ADDR); +- add_imm(X_TMP_2, X_TMP_1, os(1) * otype_sz, X_DEFAULT_ADDR); +- add_imm(X_TMP_3, X_TMP_2, os(1) * otype_sz, X_DEFAULT_ADDR); +- +- if (unroll * otype_sz == 32) +- for (uint32_t i = 0; i < 4; i++) +- st1w(ZRegS {i}, p_lsb_256 / T_z, ptr(x_tmp_vec[i])); +- else if (unroll * otype_sz == 16) +- for (uint32_t i = 0; i < 4; i++) +- str(QReg {i}, ptr(x_tmp_vec[i])); +- else if (unroll * otype_sz == 8) +- for (uint32_t i = 0; i < 4; i++) +- str(DReg {i}, ptr(x_tmp_vec[i])); +- +- add_imm(X_TMP_0, X_TMP_3, os(1) * otype_sz, X_DEFAULT_ADDR); +- add_imm(X_TMP_1, X_TMP_0, os(1) * otype_sz, X_DEFAULT_ADDR); +- add_imm(X_TMP_2, X_TMP_1, os(1) * otype_sz, X_DEFAULT_ADDR); +- add_imm(X_TMP_3, X_TMP_2, os(1) * otype_sz, X_DEFAULT_ADDR); +- +- if (unroll * otype_sz == 32) +- for (uint32_t i = 0; i < 4; i++) +- st1w(ZRegS {4 + i}, p_lsb_256 / T_z, ptr(x_tmp_vec[i])); +- else if (unroll * otype_sz == 16) +- for (uint32_t i = 0; i < 4; i++) +- str(QReg {4 + i}, ptr(x_tmp_vec[i])); +- else if (unroll * otype_sz == 8) +- for (uint32_t i = 0; i < 4; i++) +- str(DReg {4 + i}, ptr(x_tmp_vec[i])); ++ const int node_1_output_stride = prb_.os(1); ++ ++ switch (unroll * otype_sz_) { ++ case 32: p_size = p_lsb_256; break; ++ case 16: p_size = p_lsb_128; break; ++ case 8: p_size = p_lsb_64; break; ++ default: assert(!"unreachable"); ++ } ++ ++ add_imm(X_TMP_0, XReg(x_ptr_out_off), otype_sz_ * o_off, ++ X_DEFAULT_ADDR); ++ for (int i = 1; i < unroll / 2; i++) { ++ add_imm(x_tmp_vec[i], x_tmp_vec[i - 1], ++ otype_sz_ * node_1_output_stride, X_DEFAULT_ADDR); ++ } ++ ++ for (uint32_t i = 0; i < 4; i++) ++ st1w(ZRegS {i}, p_size / T_z, ptr(x_tmp_vec[i])); ++ ++ for (int i = 0; i < unroll / 2; i++) { ++ add_imm(x_tmp_vec[i], x_tmp_vec[(i + 3) % 4], ++ otype_sz_ * node_1_output_stride, X_DEFAULT_ADDR); ++ } ++ ++ for (uint32_t i = 0; i < unroll / 2; i++) ++ st1w(ZRegS {4 + i}, p_size / T_z, ptr(x_tmp_vec[i])); + } + + bool can_do_tr8x8() { + using namespace data_type; + +- return get_sve_length() >= Xbyak_aarch64::util::SVE_256 +- && prb_.ndims >= 2 ++ static constexpr int desirable_node_size = 8; ++ static constexpr int desirable_stride = 1; ++ ++ return mayiuse(sve_256) && prb_.ndims >= 2 + && ((utils::one_of(prb_.itype, u8, data_type::s8, s32, f32) + && utils::one_of( + prb_.otype, u8, data_type::s8, s32, f32))) +- && utils::everyone_is(8, n(0), n(1)) +- && utils::everyone_is(1, os(0), is(1)) +- && utils::everyone_is(0, prb_.ip_tail, prb_.op_tail) ++ && utils::everyone_is(desirable_node_size, prb_.n(0), prb_.n(1)) ++ && utils::everyone_is(desirable_stride, prb_.os(0), prb_.is(1)) ++ && !prb_.is_tail_present + && prb_.scale_type == scale_type_t::NONE && prb_.beta == 0.f; + } + +- bool process_unroll_tr8x8(int len) { ++ bool process_unroll_tr8x8(const int ndims, const int len) { + if (!can_do_tr8x8()) return false; + +- const int step_size = n(0) * n(1); ++ const int step_size = prb_.n(0) * prb_.n(1); + int i_off = 0, o_off = 0; + for (int off = 0; off < len; off += step_size) { + step(off, i_off, o_off, i_off, o_off, step_size); +@@ -432,23 +446,56 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { + } + + template +- bool process_direct_copy(int len) { ++ bool process_direct_copy(const int ndims, const int len) { + using namespace data_type; + +- const int simd_w = cpu_isa_traits::vlen / itype_sz; +- bool can_do = true && mayiuse(isa) +- && utils::everyone_is(1, os(0), is(0)) +- && (false || prb_.itype == prb_.otype ++ static constexpr int desirable_stride = 1; ++ using TRegS = ++ typename utils::conditional::type; ++ const int simd_w = cpu_isa_traits::vlen / itype_sz_; ++ ++ // TODO: support tail_processing for direct copy ++ ++ const bool do_src_zp = prb_.req_src_zp; ++ const bool do_dst_zp = prb_.req_dst_zp; ++ const bool zp_applicable = IMPLICATION( ++ (do_src_zp || do_dst_zp), utils::one_of(prb_.itype, s32, f32)); ++ const bool can_do = true && mayiuse(isa) ++ && compensation_needed_ == false ++ && utils::everyone_is(desirable_stride, prb_.os(0), prb_.is(0)) ++ && (false || (prb_.itype == prb_.otype ? zp_applicable : false) + || (prb_.itype == s32 && prb_.otype == f32) + || (prb_.itype == f32 && prb_.otype == s32)) +- && len % simd_w == 0 && n(0) % len == 0 +- && prb_.ip_tail % simd_w == 0 && prb_.op_tail % simd_w == 0 ++ && len % simd_w == 0 && prb_.n(0) % len == 0 ++ && !prb_.is_tail_present + && prb_.scale_type == scale_type_t::NONE && prb_.beta == 0.f; + if (!can_do) return false; + ++ static constexpr int vmm_zp_last_idx = 15; ++ const auto vmm_src_zp ++ = TRegS(do_dst_zp ? vmm_zp_last_idx - 1 : vmm_zp_last_idx); ++ if (do_src_zp) { ++ uni_ld1rw(vmm_src_zp, PARAM(src_zp)); ++ uni_scvtf(vmm_src_zp, vmm_src_zp); ++ } ++ const auto vmm_dst_zp = TRegS(vmm_zp_last_idx); ++ if (do_dst_zp) { ++ uni_ld1rw(vmm_dst_zp, PARAM(dst_zp)); ++ uni_scvtf(vmm_dst_zp, vmm_dst_zp); ++ } ++ ++ const auto apply_zp_ps = [&](const TRegS vmm) { ++ if (do_src_zp) fsub(vmm, vmm, vmm_src_zp); ++ if (do_dst_zp) fadd(vmm, vmm, vmm_dst_zp); ++ }; ++ + for (int off = 0; off < len;) { +- const int unroll ++ // TODO: we need extra reg for proper saturation if otype == s32 ++ int unroll + = nstl::min(16 - (prb_.otype == s32), (len - off) / simd_w); ++ unroll = (do_src_zp || do_dst_zp) ++ ? nstl::min(unroll, 16 - do_src_zp - do_dst_zp) ++ : unroll; + + int ur = 0; + int tmp_ur = 0; +@@ -458,14 +505,11 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { + + do { + add_imm(x_tmp_vec[count++], x_ptr_in_off, +- (off + ur * simd_w) * itype_sz, X_DEFAULT_ADDR); ++ (off + ur * simd_w) * itype_sz_, X_DEFAULT_ADDR); + ur++; + } while (ur < unroll && count < x_tmp_vec_size); + + for (int i = 0; i < count; i++) { +- /* if (vlen == 64) +- ldr(ZReg(tmp_ur + i), ptr(x_tmp_vec[i])); +- else */ + if (vlen == 64 || vlen == 32) + ld1w(ZRegS(tmp_ur + i), p_lsb_256 / T_z, + ptr(x_tmp_vec[i])); +@@ -478,33 +522,28 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { + } + + if (prb_.itype != prb_.otype) { +- const int vlen = cpu_isa_traits::vlen; + for (int ur = 0; ur < unroll; ++ur) { ++ TRegS r(ur); + if (prb_.itype == s32 && prb_.otype == f32) { +- if (vlen == 64 || vlen == 32) { +- ZRegS r(ur); +- /* MSB side 256 bits are ignored. */ +- scvtf(r, p_all / T_m, r); +- } else if (vlen == 16) { +- VReg4S r(ur); +- scvtf(r, r); +- } else +- assert(!"unreachable"); ++ uni_scvtf(r, r); + } else if (prb_.itype == f32 && prb_.otype == s32) { +- /* Out of order can be expected. */ +- if (vlen == 64 || vlen == 32) { +- ZRegS r(ur); +- frinti(r, p_all / T_m, r); +- fcvtzs(r, p_all / T_m, r); +- } else if (vlen == 16) { +- VReg4S r(ur); +- frinti(r, r); +- fcvtzs(r, r); +- } else +- assert(!"unreachable"); ++ uni_frinti(r, r); ++ uni_fcvtzs(r, r); + } else + assert(!"unreachable"); + } ++ } else if (do_src_zp || do_dst_zp) { ++ for (int ur = 0; ur < unroll; ++ur) { ++ const auto vmm = TRegS(ur); ++ if (prb_.otype == f32) { ++ apply_zp_ps(vmm); ++ } else if (prb_.otype == s32) { ++ uni_scvtf(vmm, vmm); ++ apply_zp_ps(vmm); ++ uni_frinti(vmm, vmm); ++ uni_fcvtzs(vmm, vmm); ++ } ++ } + } + + ur = 0; +@@ -515,7 +554,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { + + do { + add_imm(x_tmp_vec[count++], x_ptr_out_off, +- (off + ur * simd_w) * otype_sz, X_DEFAULT_ADDR); ++ (off + ur * simd_w) * otype_sz_, X_DEFAULT_ADDR); + ur++; + } while (ur < unroll && count < x_tmp_vec_size); + +@@ -538,8 +577,8 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { + } + + void process_unroll_generic_step(int reg_unroll, const int *i_off, +- const int *o_off, const int *s_off, const int *ip_padding, +- const bool h_padded) { ++ const int *o_off, const int *s_off, const int *c_off, ++ const int *zero_padding, const bool tail_processing) { + using namespace data_type; + + auto cvt2ps +@@ -588,76 +627,84 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { + } + }; + ++ auto load_bytes_addr = [=](const int ur, const int r) { ++ add_imm(x_tmp_vec[r], x_ptr_in_off, i_off[ur + r] * itype_sz_, ++ X_DEFAULT_ADDR); ++ }; ++ auto load_bytes = [=](const int ur, int size, int r) { ++ switch (size) { ++ case 4: ld1(VReg4S(ur)[r], ptr(x_tmp_vec[r])); break; ++ case 2: ld1(VReg8H(ur)[r], ptr(x_tmp_vec[r])); break; ++ case 1: ld1(VReg16B(ur)[r], ptr(x_tmp_vec[r])); break; ++ default: assert(!"unreachable"); ++ } ++ }; ++ ++ auto store = [=](const XReg &addr, const VReg ymm, int size) { ++ const uint32_t xmm = ymm.getIdx(); ++ switch (size) { ++ case 16: str(QReg(xmm), ptr(addr)); break; ++ case 8: str(DReg(xmm), ptr(addr)); break; ++ case 4: str(SReg(xmm), ptr(addr)); break; ++ case 2: str(HReg(xmm), ptr(addr)); break; ++ case 1: str(BReg(xmm), ptr(addr)); break; ++ default: assert(!"unreachable"); ++ } ++ }; ++ + /* check whether loading 4 values at once is possible */ +- bool can_load_xmm = reg_unroll % 4 == 0; ++ static constexpr int xmm_vlen = 4; ++ bool can_load_xmm = reg_unroll % xmm_vlen == 0; + for (int ur = 1; ur < reg_unroll; ++ur) +- if (i_off[ur] != i_off[ur - 1] + 1) can_load_xmm = false; +- const int load_step = can_load_xmm ? 4 : 1; ++ if (i_off[ur] != i_off[ur - 1] + 1) { ++ can_load_xmm = false; ++ break; ++ } ++ const int load_step = can_load_xmm ? xmm_vlen : 1; + + /* check whether storing 4 values at once is possible */ +- bool can_store_xmm = reg_unroll % 4 == 0; ++ bool can_store_xmm = reg_unroll % xmm_vlen == 0; + for (int ur = 1; ur < reg_unroll; ++ur) +- if (o_off[ur] != o_off[ur - 1] + 1) can_store_xmm = false; ++ if (o_off[ur] != o_off[ur - 1] + 1) { ++ can_store_xmm = false; ++ break; ++ } + const int ur_step = can_store_xmm ? 4 : 1; + const int load_tail_step + = !can_load_xmm && can_store_xmm ? ur_step : load_step; + +- const bool interim_f32 = false +- || utils::one_of(f32, prb_.itype, prb_.otype) +- || prb_.scale_type != scale_type_t::NONE || prb_.beta != 0.f; ++ const bool interim_f32 = interim_f32_needed(); + + const bool need_saturation + = (utils::one_of(prb_.otype, u8, data_type::s8, s32) + && interim_f32); +- if (h_padded) { ++ ++ std::vector store_masks; ++ if (tail_processing) { + for (int ur = 0; ur < reg_unroll; ur += load_tail_step) { +- if (itype_sz == 4) +- movi(VReg4S(ur), 0); +- else if (itype_sz == 2) +- movi(VReg8H(ur), 0); +- else +- movi(VReg16B(ur), 0); +- /* x_tmp_vec = X_TMP_0 - X_TMP_4 +- Do not use X_TMP_? as the last arg. */ +- for (int r = 0; r < load_tail_step; ++r) { +- if (ip_padding[ur + r] == 0) { +- add_imm(x_tmp_vec[r], x_ptr_in_off, +- i_off[ur + r] * itype_sz, X_DEFAULT_ADDR); +- } +- } ++ uni_clear(VReg(ur)); ++ store_masks.push_back(0); + + for (int r = 0; r < load_tail_step; ++r) { +- if (ip_padding[ur + r] == 0) { +- if (itype_sz == 4) +- ld1(VReg4S(ur)[r], ptr(x_tmp_vec[r])); +- else if (itype_sz == 2) +- ld1(VReg8H(ur)[r], ptr(x_tmp_vec[r])); +- else +- ld1(VReg16B(ur)[r], ptr(x_tmp_vec[r])); ++ if (zero_padding[ur + r] == 0) { ++ store_masks.back() += 1 << r; ++ load_bytes_addr(ur, r); + } + } ++ ++ for (int r = 0; r < load_tail_step; ++r) ++ if (zero_padding[ur + r] == 0) load_bytes(ur, itype_sz_, r); + } + } else { + if (!can_load_xmm && can_store_xmm) { +- assert(ur_step == 4); ++ assert(ur_step == xmm_vlen); + /* load with stride */ + for (int ur = 0; ur < reg_unroll; ur += ur_step) { +- +- /* x_tmp_vec = X_TMP_0 - X_TMP_4 +- Do not use X_TMP_? as the last arg. */ + for (int r = 0; r < ur_step; ++r) { +- add_imm(x_tmp_vec[r], x_ptr_in_off, +- i_off[ur + r] * itype_sz, X_DEFAULT_ADDR); +- } +- +- for (int r = 0; r < ur_step; ++r) { +- if (itype_sz == 4) +- ld1(VReg4S(ur)[r], ptr(x_tmp_vec[r])); +- else if (itype_sz == 2) +- ld1(VReg8H(ur)[r], ptr(x_tmp_vec[r])); +- else +- ld1(VReg16B(ur)[r], ptr(x_tmp_vec[r])); ++ load_bytes_addr(ur, r); + } ++ for (int r = 0; r < ur_step; ++r) ++ load_bytes(ur, itype_sz_, r); + } + } else { + int ur = 0; +@@ -667,13 +714,13 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { + + do { + add_imm(x_tmp_vec[count++], x_ptr_in_off, +- i_off[ur] * itype_sz, X_DEFAULT_ADDR); ++ i_off[ur] * itype_sz_, X_DEFAULT_ADDR); + ur += load_step; + } while (ur < reg_unroll && count < x_tmp_vec_size); + + for (int i = 0; i < count; i++) { + +- switch (load_step * itype_sz) { ++ switch (load_step * itype_sz_) { + case 16: + ldr(QReg(tmp_ur), ptr(x_tmp_vec[i])); + break; +@@ -688,6 +735,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { + } + } + } ++ + /* xmm[:] <-- (f32)xmm[:] */ + if (interim_f32) { + const int cvt_step = nstl::max(load_step, ur_step); +@@ -702,30 +750,32 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { + if (fast_return) { + if (prb_.scale_type == scale_type_t::COMMON) + for (int ur = 0; ur < reg_unroll; ur += load_step) +- fmul(VReg4S(ur), VReg4S(ur), xmm_scale); ++ fmul(VReg4S(ur), VReg4S(ur), xmm_scale_); + if (prb_.otype != f32) { +- init_saturate_f32(xmm_zero, xmm_saturation_ubound, reg_tmp, +- interim_f32 ? f32 : prb_.itype, prb_.otype); +- for (int ur = 0; ur < reg_unroll; ur += load_step) ++ init_saturate_f32(xmm_zero_, xmm_saturation_ubound_, ++ reg_tmp_, interim_f32 ? f32 : prb_.itype, ++ prb_.otype); ++ for (int ur = 0; ur < reg_unroll; ur += load_step) { + if (need_saturation) +- saturate_f32(VReg4S(ur), xmm_zero, +- xmm_saturation_ubound, prb_.otype, p_all); ++ saturate_f32(VReg4S(ur), xmm_zero_, ++ xmm_saturation_ubound_, prb_.otype, ++ P_ALL_ONE); ++ } + + for (int ur = 0; ur < reg_unroll; ur += load_step) + cvt2odt(ur, 1, prb_.otype, + interim_f32 ? f32 : prb_.itype); + } +- /* load_step is 1 or 4. */ + for (int ur = 0; ur < reg_unroll; ur += load_step) { + for (int r = 0; r < load_step; ++r) { + add_imm(x_tmp_vec[r], x_ptr_out_off, +- o_off[ur + r] * otype_sz, X_DEFAULT_ADDR); ++ o_off[ur + r] * otype_sz_, X_DEFAULT_ADDR); + } + + for (int r = 0; r < load_step; ++r) { +- if (otype_sz == 4) ++ if (otype_sz_ == 4) + st1(VReg4S(ur)[r], ptr(x_tmp_vec[r])); +- else if (otype_sz == 2) ++ else if (otype_sz_ == 2) + st1(VReg8H(ur)[r], ptr(x_tmp_vec[r])); + else + st1(VReg16B(ur)[r], ptr(x_tmp_vec[r])); +@@ -735,7 +785,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { + } + + /* scatter elements of xmm into 4 xmms */ +- if (itype_sz == 4 || interim_f32) { ++ if (itype_sz_ == 4 || interim_f32) { + for (int ur = 0; ur < reg_unroll; ur += load_step) + for (int r = 1; r < load_step; ++r) { + VReg4S v(ur); +@@ -747,7 +797,18 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { + for (int ur = 0; ur < reg_unroll; ur += load_step) + for (int r = 1; r < load_step; ++r) + ext(VReg16B(ur + r), VReg16B(ur), VReg16B(ur), +- itype_sz * r); ++ itype_sz_ * r); ++ } ++ } ++ ++ /* src zero point application */ ++ if (prb_.req_src_zp) { ++ for (int ur = 0; ur < reg_unroll; ur += ur_step) { ++ const auto xmm = VReg4S(ur); ++ if (interim_f32) ++ fsub(xmm, xmm, xmm_src_zp_); ++ else ++ sub(xmm, xmm, xmm_src_zp_); + } + } + +@@ -756,7 +817,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { + /* xmm <-- scale * xmm[:] */ + if (prb_.scale_type == scale_type_t::COMMON) { + for (int ur = 0; ur < reg_unroll; ur += ur_step) +- fmul(VReg4S(ur), VReg4S(ur), xmm_scale); ++ fmul(VReg4S(ur), VReg4S(ur), xmm_scale_); + } else if (prb_.scale_type == scale_type_t::MANY) { + enum class scale_load_type_t { bcast, load, gather }; + +@@ -769,13 +830,12 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { + scale_load_type = scale_load_type_t::load; + + if (scale_load_type == scale_load_type_t::bcast +- && !h_padded) { +- VReg4S v(xmm_scale.getIdx()); ++ && !tail_processing) { ++ VReg4S v(xmm_scale_.getIdx()); + VReg4S v_dst(ur); +- add_imm(X_TMP_0, x_ptr_scale_off, s_off[ur] * stype_sz, ++ add_imm(X_TMP_0, x_ptr_scale_off, s_off[ur] * stype_sz_, + X_DEFAULT_ADDR); +- ldr(W_TMP_0, ptr(X_TMP_0)); +- dup(v, W_TMP_0); ++ ld1r(v, ptr(X_TMP_0)); + fmul(v_dst, v_dst, v); + continue; + } +@@ -786,10 +846,10 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { + scale_load_type = scale_load_type_t::gather; + + if (scale_load_type == scale_load_type_t::load +- && !h_padded) { +- uint32_t idx = xmm_scale.getIdx(); ++ && !tail_processing) { ++ uint32_t idx = xmm_scale_.getIdx(); + VReg4S v_dst(ur); +- add_imm(X_TMP_0, x_ptr_scale_off, s_off[ur] * stype_sz, ++ add_imm(X_TMP_0, x_ptr_scale_off, s_off[ur] * stype_sz_, + X_DEFAULT_ADDR); + + ldr(QReg {idx}, ptr(X_TMP_0)); +@@ -799,22 +859,15 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { + + // load doesn't work as well + // so gather the scale factors one by one +- /*ur_step is 1 or 4. */ +- for (int r = ur; r < ur + ur_step; ++r) { +- if (ip_padding[r] == 0 || !h_padded) { +- /* x_tmp_vec = X_TMP_0 - X_TMP_4 +- Do not use X_TMP_? as the last arg. */ ++ for (int r = ur; r < ur + ur_step; ++r) ++ if (zero_padding[r] == 0 || !tail_processing) { + add_imm(x_tmp_vec[r - ur], x_ptr_scale_off, +- s_off[r] * stype_sz, X_DEFAULT_ADDR); +- } +- } +- for (int r = ur; r < ur + ur_step; ++r) { +- if (ip_padding[r] == 0 || !h_padded) { +- VReg4S v(xmm_scale.getIdx()); +- ld1(v[r - ur], ptr(x_tmp_vec[r - ur])); ++ s_off[r] * stype_sz_, X_DEFAULT_ADDR); + } +- } +- fmul(VReg4S(ur), VReg4S(ur), xmm_scale); ++ for (int r = ur; r < ur + ur_step; ++r) ++ if (zero_padding[r] == 0 || !tail_processing) ++ ld1(xmm_scale_[r - ur], ptr(x_tmp_vec[r - ur])); ++ fmul(VReg4S(ur), VReg4S(ur), xmm_scale_); + } + } + +@@ -829,7 +882,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { + + do { + add_imm(x_tmp_vec[count++], x_ptr_out_off, +- o_off[ur] * otype_sz, X_DEFAULT_ADDR); ++ o_off[ur] * otype_sz_, X_DEFAULT_ADDR); + ur += ur_step; + } while (ur < reg_unroll && count < x_tmp_vec_size); + +@@ -873,7 +926,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { + if (prb_.scale_type == scale_type_t::COMMON) { + for (int ur = 0; ur < reg_unroll; ur += ur_step) { + VReg4S tmp(ur); +- fmul(tmp, tmp, VReg4S(xmm_scale.getIdx())); ++ fmul(tmp, tmp, VReg4S(xmm_scale_.getIdx())); + } + } else if (prb_.scale_type == scale_type_t::MANY) { + int ur = 0; +@@ -883,7 +936,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { + + do { + add_imm(x_tmp_vec[count++], x_ptr_scale_off, +- s_off[ur] * stype_sz, X_DEFAULT_ADDR); ++ s_off[ur] * stype_sz_, X_DEFAULT_ADDR); + ur += ur_step; + } while (ur < reg_unroll && count < x_tmp_vec_size); + +@@ -908,7 +961,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { + + do { + add_imm(x_tmp_vec[count++], x_ptr_out_off, +- o_off[ur] * otype_sz, X_DEFAULT_ADDR); ++ o_off[ur] * otype_sz_, X_DEFAULT_ADDR); + ur += ur_step; + } while (ur < reg_unroll && count < (x_tmp_vec_size / 2)); + +@@ -951,94 +1004,272 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { + } + } + +- if (need_saturation) { +- init_saturate_f32( +- xmm_zero, xmm_saturation_ubound, reg_tmp, f32, prb_.otype); ++ /* dst zero point application */ ++ if (prb_.req_dst_zp) { + for (int ur = 0; ur < reg_unroll; ur += ur_step) { +- saturate_f32(VReg4S(ur), xmm_zero, xmm_saturation_ubound, +- prb_.otype, p_all); ++ const auto xmm = VReg4S(ur); ++ if (interim_f32) ++ fadd(xmm, xmm, xmm_dst_zp_); ++ else ++ add(xmm, xmm, xmm_dst_zp_); + } + } + +- for (int ur = 0; ur < reg_unroll; ur += ur_step) { +- if (prb_.otype != f32) +- cvt2odt(ur, 1, prb_.otype, interim_f32 ? f32 : prb_.itype); ++ /* adjust scale application */ ++ if (prb_.scale_adjust != 1.f) { ++ dup(xmm_tmp_, reg_scale_adjust_); ++ for (int ur = 0; ur < reg_unroll; ur += ur_step) { ++ fmul(VReg4S(ur), VReg4S(ur), xmm_tmp_); ++ } ++ } ++ ++ if (need_saturation) { ++ init_saturate_f32(xmm_zero_, xmm_saturation_ubound_, reg_tmp_, f32, ++ prb_.otype); ++ for (int ur = 0; ur < reg_unroll; ur += ur_step) { ++ saturate_f32(VReg4S(ur), xmm_zero_, xmm_saturation_ubound_, ++ prb_.otype, P_ALL_ONE); ++ } ++ ++ // reset back xmm_zero_ if needed. ++ if (compensation_needed_ && (prb_.req_src_zp || prb_.req_dst_zp)) ++ uni_clear(VReg(xmm_zero_.getIdx())); + } + +- int ur = 0; +- int tmp_ur = 0; +- while (ur < reg_unroll) { +- int count = 0; ++ if (compensation_needed_) { ++ const uint32_t xmm_begin = 9; ++ const uint32_t xmm_end = 11; ++ uint32_t xmm_id = xmm_begin; ++ const auto get_temp_xmm = [&] { ++ const Xbyak_aarch64::VReg temp {xmm_id++}; ++ ++ if (xmm_id > xmm_end) { xmm_id = xmm_begin; } ++ ++ return temp; ++ }; ++ if (can_store_xmm) { ++ enum class comp_load_type_t { bcast, load, gather }; ++ ++ for (int ur = 0; ur < reg_unroll; ur += ur_step) { ++ ++ bool all_ip_padding_one = true; ++ bool all_ip_padding_zero = true; ++ for (int r = ur; r < ur + ur_step; r++) { ++ if (zero_padding[r] != 1) ++ all_ip_padding_one = false; ++ else ++ all_ip_padding_zero = false; ++ } ++ if (all_ip_padding_one) continue; ++ ++ comp_load_type_t comp_load_type = comp_load_type_t::bcast; ++ ++ for (int r = ur + 1; r < ur + ur_step; ++r) ++ if (c_off[r] != c_off[r - 1] + 0) { ++ comp_load_type = comp_load_type_t::load; ++ break; ++ } + +- do { +- add_imm(x_tmp_vec[count++], x_ptr_out_off, o_off[ur] * otype_sz, +- X_DEFAULT_ADDR); +- ur += ur_step; +- } while (ur < reg_unroll && count < x_tmp_vec_size); ++ if (comp_load_type == comp_load_type_t::bcast ++ && all_ip_padding_zero) { ++ const auto reduction_xmm = get_temp_xmm().s4; ++ const auto xmm_reorder_result = VReg4S(ur); ++ frinti(reduction_xmm, xmm_reorder_result); ++ addv(SReg(reduction_xmm.getIdx()), reduction_xmm); ++ const auto comp_addr = c_addr(c_off[ur]); ++ const auto xmm_tmp_ = get_temp_xmm().s4; ++ ldr(SReg(xmm_tmp_.getIdx()), ptr(comp_addr)); ++ add(xmm_tmp_, xmm_tmp_, reduction_xmm); ++ str(SReg(xmm_tmp_.getIdx()), ptr(comp_addr)); ++ continue; ++ } ++ ++ if (comp_load_type == comp_load_type_t::load) ++ for (int r = ur + 1; r < ur + ur_step; ++r) ++ if (c_off[r] != c_off[r - 1] + 1) { ++ comp_load_type = comp_load_type_t::gather; ++ break; ++ } ++ ++ if (comp_load_type == comp_load_type_t::load ++ && all_ip_padding_zero) { ++ const auto xmm_reorder_result_dq = get_temp_xmm().s4; ++ const auto xmm_reorder_result = VReg4S(ur); ++ const auto comp_addr = c_addr(c_off[ur]); ++ frinti(xmm_reorder_result_dq, xmm_reorder_result); ++ const auto xmm_tmp_ = get_temp_xmm().s4; ++ ldr(SReg(xmm_tmp_.getIdx()), ptr(comp_addr)); ++ add(xmm_reorder_result_dq, xmm_reorder_result_dq, ++ xmm_tmp_); ++ str(SReg(xmm_tmp_.getIdx()), ptr(comp_addr)); ++ continue; ++ } + +- for (int i = 0; i < count; i++) { ++ const auto xmm_reorder_result_dq = get_temp_xmm().s4; ++ const auto xmm_reorder_result = VReg4S(ur); ++ frinti(xmm_reorder_result_dq, xmm_reorder_result); + +- switch (ur_step * otype_sz) { +- case 16: str(QReg(tmp_ur), ptr(x_tmp_vec[i])); break; +- case 8: str(DReg(tmp_ur), ptr(x_tmp_vec[i])); break; +- case 4: str(SReg(tmp_ur), ptr(x_tmp_vec[i])); break; +- case 2: str(HReg(tmp_ur), ptr(x_tmp_vec[i])); break; +- case 1: str(BReg(tmp_ur), ptr(x_tmp_vec[i])); break; +- default: assert(!"unreachable"); ++ for (int r = ur; r < ur + ur_step; ++r) { ++ if (zero_padding[r] == 0 || !tail_processing) { ++ mov(W_TMP_0, xmm_reorder_result_dq[r]); ++ const auto comp_addr = c_addr(c_off[ur]); ++ str(W_TMP_0, ptr(comp_addr)); ++ } ++ } ++ } ++ } else { ++ for (int ur = 0; ur < reg_unroll; ur += ur_step) { ++ if (zero_padding[ur] == 0 || !tail_processing) { ++ const auto xmm_reorder_result_dq = get_temp_xmm().s4; ++ const auto xmm_reorder_result = VReg4S(ur); ++ const auto comp_addr = c_addr(c_off[ur]); ++ frinti(xmm_reorder_result_dq, xmm_reorder_result); ++ const auto xmm_tmp_ = get_temp_xmm().s4; ++ ldr(SReg(xmm_tmp_.getIdx()), ptr(comp_addr)); ++ add(xmm_reorder_result_dq, xmm_reorder_result_dq, ++ xmm_tmp_); ++ str(SReg(xmm_tmp_.getIdx()), ptr(comp_addr)); ++ } + } +- tmp_ur += ur_step; + } + } ++ ++ for (int ur = 0; ur < reg_unroll; ur += ur_step) { ++ if (prb_.req_src_zp || prb_.req_dst_zp) { ++ const bool use_store_masks = !store_masks.empty(); ++ if (use_store_masks) { ++ const auto mask = (~store_masks[ur / ur_step]) & 0xF; ++ switch (mask) { ++ case 0x0: ++ /* Do nothing */ ++ break; ++ case 0x1: ins(VReg4S(ur)[0], xmm_zero_[0]); break; ++ case 0x2: ins(VReg4S(ur)[1], xmm_zero_[1]); break; ++ case 0x3: ++ ins(VReg2D(ur)[0], VReg2D(xmm_zero_.getIdx())[0]); ++ break; ++ case 0x4: ins(VReg4S(ur)[2], xmm_zero_[2]); break; ++ case 0x5: ++ ins(VReg4S(ur)[0], xmm_zero_[0]); ++ ins(VReg4S(ur)[2], xmm_zero_[2]); ++ break; ++ case 0x6: ++ ins(VReg4S(ur)[1], xmm_zero_[1]); ++ ins(VReg4S(ur)[2], xmm_zero_[2]); ++ break; ++ case 0x7: ++ ins(VReg2D(ur)[0], VReg2D(xmm_zero_.getIdx())[0]); ++ ins(VReg4S(ur)[2], xmm_zero_[2]); ++ break; ++ case 0x8: ins(VReg4S(ur)[3], xmm_zero_[3]); break; ++ case 0x9: ++ ins(VReg4S(ur)[0], xmm_zero_[0]); ++ ins(VReg4S(ur)[3], xmm_zero_[3]); ++ break; ++ case 0xa: ++ ins(VReg4S(ur)[1], xmm_zero_[1]); ++ ins(VReg4S(ur)[3], xmm_zero_[3]); ++ break; ++ case 0xb: ++ ins(VReg2D(ur)[0], VReg2D(xmm_zero_.getIdx())[0]); ++ ins(VReg4S(ur)[3], xmm_zero_[3]); ++ break; ++ case 0xc: ++ ins(VReg2D(ur)[1], VReg2D(xmm_zero_.getIdx())[1]); ++ break; ++ case 0xd: ++ ins(VReg4S(ur)[0], xmm_zero_[0]); ++ ins(VReg2D(ur)[1], VReg2D(xmm_zero_.getIdx())[1]); ++ break; ++ case 0xe: ++ ins(VReg4S(ur)[1], xmm_zero_[1]); ++ ins(VReg2D(ur)[1], VReg2D(xmm_zero_.getIdx())[1]); ++ break; ++ case 0xf: movi(VReg16B(ur), 0); break; ++ default: assert(!"unreachable"); ++ } ++ } ++ } ++ if (prb_.otype != f32) ++ cvt2odt(ur, 1, prb_.otype, interim_f32 ? f32 : prb_.itype); ++ ++ store(o_addr(o_off[ur]), VReg(ur), ur_step * otype_sz_); ++ } + } + +- void comp_padding_flag(int ndims, int off, int len, int &i_tail) { +- const int ip_without_padding +- = ndims == 0 ? len - ip_padding() : prb_.ip_tail; +- if ((ndims == 0 && off >= ip_without_padding) +- || (ndims > 0 && (off % prb_.oblock) >= ip_without_padding)) +- i_tail = 1; ++ bool interim_f32_needed() { ++ using namespace data_type; ++ ++ return utils::one_of(f32, prb_.itype, prb_.otype) ++ || prb_.scale_type != scale_type_t::NONE || prb_.beta != 0.f ++ || ((prb_.req_src_zp || prb_.req_dst_zp) ++ ? !(prb_.itype == s32 && prb_.otype == s32) ++ : false) ++ || (prb_.itype != f32 && compensation_needed_) ++ || prb_.scale_adjust != 1.f; + } + +- void process_unroll_generic(const int ndims, int len, const bool h_padded) { ++ void process_unroll_generic( ++ const int ndims, int len, const bool tail_processing) { ++ assert(IMPLICATION(prb_.nodes[0].tail_size > 0, ++ len == static_cast(prb_.nodes[0].n) ++ || len == static_cast(prb_.nodes[0].tail_size))); ++ + const int blk = 8; + + int i_off[2 * blk] = {0}; + int o_off[2 * blk] = {0}; + int s_off[2 * blk] = {0}; ++ int c_off[2 * blk] = {0}; + + int curr = 0; // will switch between 0 and 1 + ++ const bool interim_f32 = interim_f32_needed(); ++ ++ if (prb_.req_src_zp) { ++ add_imm(X_DEFAULT_ADDR, PARAM(src_zp), X_TMP_0); ++ ld1r(xmm_src_zp_, ptr(X_DEFAULT_ADDR)); ++ if (interim_f32) scvtf(xmm_src_zp_, xmm_src_zp_); ++ } ++ if (prb_.req_dst_zp) { ++ add_imm(X_DEFAULT_ADDR, PARAM(dst_zp), X_TMP_0); ++ ld1r(xmm_dst_zp_, ptr(X_DEFAULT_ADDR)); ++ if (interim_f32) scvtf(xmm_dst_zp_, xmm_dst_zp_); ++ } ++ + for (int off = 0; off < len; off += blk) { + const int reg_unroll = nstl::min(off + blk, len) - off; +- int ip_padding[blk] = {0}; ++ int zero_padding[blk] = {0}; ++ const auto curr_blk = curr * blk; + + /* compute offsets and tail*/ + for (int ur = off != 0 ? 0 : 1; ur < reg_unroll; ++ur) { +- const int ur_c = curr * blk + ur; ++ const int ur_c = curr_blk + ur; + const int ur_p = (ur_c - 1 + 2 * blk) % (2 * blk); // prev ur ++ const bool is_tail ++ = off + ur >= static_cast(prb_.nodes[0].tail_size); + step(off + ur, i_off[ur_p], o_off[ur_p], s_off[ur_p], +- i_off[ur_c], o_off[ur_c], s_off[ur_c]); +- if (h_padded) +- comp_padding_flag(ndims, off + ur, len, ip_padding[ur]); ++ c_off[ur_p], i_off[ur_c], o_off[ur_c], s_off[ur_c], ++ c_off[ur_c]); ++ if (tail_processing && is_tail) zero_padding[ur] = 1; + } +- process_unroll_generic_step(reg_unroll, i_off + curr * blk, +- o_off + curr * blk, s_off + curr * blk, ip_padding, +- h_padded); ++ ++ process_unroll_generic_step(reg_unroll, i_off + curr_blk, ++ o_off + curr_blk, s_off + curr_blk, c_off + curr_blk, ++ zero_padding, tail_processing); + + curr = 1 - curr; + } + } + + void compute_ker( +- const int ndims, const int len_unroll, const bool h_padded) { ++ const int ndims, const int len_unroll, const bool tail_processing) { + bool optimized = false; +- optimized = optimized +- || (process_direct_copy(len_unroll) && !h_padded); +- optimized = optimized +- || (process_direct_copy(len_unroll) && !h_padded); +- optimized +- = optimized || (process_unroll_tr8x8(len_unroll) && !h_padded); +- if (!optimized) process_unroll_generic(ndims, len_unroll, h_padded); ++ optimized = optimized || process_direct_copy(ndims, len_unroll) ++ || process_direct_copy(ndims, len_unroll) ++ || process_unroll_tr8x8(ndims, len_unroll); ++ if (!optimized) ++ process_unroll_generic(ndims, len_unroll, tail_processing); + } + + void loop_begin(Label &l, XReg reg_cnt, int len) { +@@ -1046,97 +1277,287 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { + L(l); + } + ++ void check_if_this_is_last_chunk(const XReg reg_curr_chunk, int node_id) { ++ // Chunks are backwards numered i.e: ++ // [0] -> [node_size] ++ // [1] -> [node_size - 1] ++ // ... ++ // [node_size - 1] -> [1] ++ ++ // It is done like this, because it is easier to decrement counter ++ // and check if it is equal to zero than increment and check ++ // if it is equal to node_size. ++ static constexpr int64_t last_chunk = 1; ++ cmp(reg_curr_chunk, last_chunk); ++ } ++ ++ void zero_dst_memory(const int bytes_to_zeroing) { ++ static constexpr int num_of_bytes_in_xmm = 128 / 8; ++ ++ const int xmms_to_zeroing ++ = std::div(bytes_to_zeroing, num_of_bytes_in_xmm).quot; ++ const int tail_to_zeroing ++ = std::div(bytes_to_zeroing, num_of_bytes_in_xmm).rem; ++ ++ movi(xmm_tmp_, 0); ++ ++ if (xmms_to_zeroing > 0) { ++ Label loop; ++ ++ mov(reg_tmp_, xmms_to_zeroing); ++ L(loop); ++ str(QReg(xmm_tmp_.getIdx()), ptr(o_addr(0))); ++ add_imm(reg_off_out_, reg_off_out_, num_of_bytes_in_xmm, X_TMP_0); ++ add_imm(x_ptr_out_off, x_ptr_out_off, num_of_bytes_in_xmm, X_TMP_0); ++ subs(reg_tmp_, reg_tmp_, 1); ++ mov(X_TMP_0, 32); ++ b(NE, loop); ++ } ++ ++ if (tail_to_zeroing) mov_imm(W_TMP_0, 0); ++ ++ for (int i = 0; i < tail_to_zeroing; i++) ++ strb(W_TMP_0, ptr(o_addr(i, false))); ++ ++ // Restore dst offset to initial value ++ if (xmms_to_zeroing > 0) { ++ sub_imm(reg_off_out_, reg_off_out_, ++ num_of_bytes_in_xmm * xmms_to_zeroing, X_TMP_0); ++ sub_imm(x_ptr_out_off, x_ptr_out_off, ++ num_of_bytes_in_xmm * xmms_to_zeroing, X_TMP_0); ++ } ++ } ++ ++ void finalize_tail_loop(int i_step, int o_step, int s_step, int c_step, ++ const int curr_node_id) { ++ static constexpr int empty_chunk_info = -1; ++ ++ mov(reg_tmp_, empty_chunk_info); ++ str(reg_tmp_, ptr(data_chunk_addr(curr_node_id))); ++ ++ const int padded_area = prb_.nodes[curr_node_id].n ++ - prb_.nodes[curr_node_id].tail_size; ++ ++ if (prb_.nodes[curr_node_id].is_zero_pad_needed) { ++ int num_of_zero_padded_values = padded_area; ++ for (int i = curr_node_id - 1; i >= 0; i--) { ++ num_of_zero_padded_values *= prb_.nodes[i].n; ++ } ++ ++ const int bytes_to_zeroing = num_of_zero_padded_values * otype_sz_; ++ zero_dst_memory(bytes_to_zeroing); ++ } ++ ++ // This function is called by loop_end. At the end ++ // of loop_end is section that is responsible for ++ // restoring offset values. Restoring is based on ++ // len value which is equal to prb.nodes[x].n. ++ // If fill_zero_padded_area is called then it means ++ // offsets were shifted prb.nodes[x].tail_size times. ++ // Therefore, this function has to shift offsets by ++ // zero pad area. ++ add_imm(reg_off_in_, reg_off_in_, padded_area * i_step * itype_sz_, ++ X_TMP_0); ++ add_imm(reg_off_out_, reg_off_out_, padded_area * o_step * otype_sz_, ++ X_TMP_0); ++ add_imm(x_ptr_in_off, x_ptr_in_off, padded_area * i_step * itype_sz_, ++ X_TMP_0); ++ add_imm(x_ptr_out_off, x_ptr_out_off, padded_area * o_step * otype_sz_, ++ X_TMP_0); ++ if (prb_.scale_type == scale_type_t::MANY) { ++ add_imm(reg_off_scale_, reg_off_scale_, ++ padded_area * s_step * stype_sz_, X_TMP_0); ++ add_imm(x_ptr_scale_off, x_ptr_scale_off, ++ padded_area * s_step * stype_sz_, X_TMP_0); ++ } ++ if (compensation_needed_) { ++ add_imm(reg_off_comp_, reg_off_comp_, ++ padded_area * c_step * sizeof(int32_t), X_TMP_0); ++ add_imm(x_ptr_comp_off, x_ptr_comp_off, ++ padded_area * c_step * sizeof(int32_t), X_TMP_0); ++ } ++ } ++ + void loop_end(Label &l, XReg reg_cnt, int len, int i_step, int o_step, +- int s_step) { +- add_imm(reg_off_in, reg_off_in, i_step * itype_sz, X_TMP_0); +- add_imm(reg_off_out, reg_off_out, o_step * otype_sz, X_TMP_0); +- add_imm(x_ptr_in_off, x_ptr_in_off, i_step * itype_sz, X_TMP_0); +- add_imm(x_ptr_out_off, x_ptr_out_off, o_step * otype_sz, X_TMP_0); ++ int s_step, int c_step, const int curr_node_id) { ++ add_imm(reg_off_in_, reg_off_in_, i_step * itype_sz_, X_TMP_0); ++ add_imm(reg_off_out_, reg_off_out_, o_step * otype_sz_, X_TMP_0); ++ add_imm(x_ptr_in_off, x_ptr_in_off, i_step * itype_sz_, X_TMP_0); ++ add_imm(x_ptr_out_off, x_ptr_out_off, o_step * otype_sz_, X_TMP_0); + + if (prb_.scale_type == scale_type_t::MANY) { +- add_imm(reg_off_scale, reg_off_scale, s_step * stype_sz, X_TMP_0); +- add_imm(x_ptr_scale_off, x_ptr_scale_off, s_step * stype_sz, ++ add_imm(reg_off_scale_, reg_off_scale_, s_step * stype_sz_, ++ X_TMP_0); ++ add_imm(x_ptr_scale_off, x_ptr_scale_off, s_step * stype_sz_, + X_TMP_0); + } ++ ++ if (compensation_needed_) { ++ add_imm(reg_off_comp_, reg_off_comp_, c_step * sizeof(int32_t), ++ X_TMP_0); ++ add_imm(x_ptr_comp_off, x_ptr_comp_off, c_step * sizeof(int32_t), ++ X_TMP_0); ++ } ++ + subs(reg_cnt, reg_cnt, 1); + b(NE, l); + +- sub_imm(reg_off_in, reg_off_in, len * i_step * itype_sz, X_TMP_0); +- sub_imm(reg_off_out, reg_off_out, len * o_step * otype_sz, X_TMP_0); +- sub_imm(x_ptr_in_off, x_ptr_in_off, len * i_step * itype_sz, X_TMP_0); +- sub_imm(x_ptr_out_off, x_ptr_out_off, len * o_step * otype_sz, X_TMP_0); ++ if (prb_.tail(curr_node_id) != 0) { ++ Label if_end; ++ ++ // On the stack should be an information if node ++ // was processed with tail or not. ++ ldr(reg_tmp_, post_ptr(X_SP, reg_tmp_.getBit() / 8)); ++ ++ cmp(reg_tmp_, with_tail_info_); ++ b(NE, if_end); ++ finalize_tail_loop(i_step, o_step, s_step, c_step, curr_node_id); ++ L(if_end); ++ } ++ ++ // Restore offset to initial values. It means before ++ // loop execution. ++ sub_imm(reg_off_in_, reg_off_in_, len * i_step * itype_sz_, X_TMP_0); ++ sub_imm(reg_off_out_, reg_off_out_, len * o_step * otype_sz_, X_TMP_0); ++ sub_imm(x_ptr_in_off, x_ptr_in_off, len * i_step * itype_sz_, X_TMP_0); ++ sub_imm(x_ptr_out_off, x_ptr_out_off, len * o_step * otype_sz_, ++ X_TMP_0); + + if (prb_.scale_type == scale_type_t::MANY) { +- sub_imm(reg_off_scale, reg_off_scale, len * s_step * stype_sz, ++ sub_imm(reg_off_scale_, reg_off_scale_, len * s_step * stype_sz_, + X_TMP_0); +- sub_imm(x_ptr_scale_off, x_ptr_scale_off, len * s_step * stype_sz, ++ sub_imm(x_ptr_scale_off, x_ptr_scale_off, len * s_step * stype_sz_, + X_TMP_0); + } ++ if (compensation_needed_) { ++ sub_imm(reg_off_comp_, reg_off_comp_, ++ len * c_step * sizeof(int32_t), X_TMP_0); ++ sub_imm(x_ptr_comp_off, x_ptr_comp_off, ++ len * c_step * sizeof(int32_t), X_TMP_0); ++ } + } + +- void compute_blk_ker(const int len_unroll) { ++ void compute_blk_ker(const simple_impl_desc_t &desc) { ++ static constexpr bool with_tail_processing = true; ++ Label no_last_chunk, end_label; + int omp_ndims = prb_.full_ndims - prb_.ndims; +- Label no_last_blk, end_label; + +- if (prb_.ip_tail > 0 && prb_.op_tail == 0) { +- if (omp_ndims == 0) { +- cmp(reg_last_loop_cnt, 1); +- bne(no_last_blk); +- compute_ker(omp_ndims, len_unroll, true); +- } else { +- cmp(reg_blk_chunks, blk_cnt()); +- bne(no_last_blk); +- compute_ker(omp_ndims, len_unroll, true); ++ if (prb_.nodes[0].tail_size > 0) { ++ if (!prb_.nodes[0].is_parent_empty()) { ++ const int parent_node_id = prb_.nodes[0].parent_node_id; ++ ldr(reg_tmp_, ptr(data_chunk_addr(parent_node_id))); ++ check_if_this_is_last_chunk(reg_tmp_, parent_node_id); ++ b(NE, no_last_chunk); + } ++ ++ const int len_unroll = desc.tail_len_unroll > 0 ++ ? desc.tail_len_unroll ++ : desc.len_unroll; ++ compute_ker(omp_ndims, len_unroll, with_tail_processing); + b(end_label); + } + +- L(no_last_blk); +- compute_ker(omp_ndims, len_unroll, false); ++ L(no_last_chunk); ++ compute_ker(omp_ndims, desc.len_unroll, !with_tail_processing); + L(end_label); + } + ++ void create_loops(const simple_impl_desc_t &desc, ++ const std::array ®_cnt, int jit_loop) { ++ assert(jit_loop <= ndims_jit_loop_max); ++ ++ if (jit_loop > 0) { ++ const int nfu = desc.ndims_full_unroll; ++ const int unroll_factor ++ = jit_loop == 1 ? desc.len_last_dim_unroll : 1; ++ const int curr_node_id = nfu + (jit_loop - 1); ++ const int parent_node_id = prb_.nodes[curr_node_id].parent_node_id; ++ const int tail_size = prb_.tail(curr_node_id) / unroll_factor; ++ const int node_size = prb_.n(curr_node_id) / unroll_factor; ++ const XReg reg_loop_cnt = reg_cnt[jit_loop - 1]; ++ const bool curr_node_has_tail = prb_.tail(curr_node_id) != 0; ++ Label loop, if_no_tail, if_end; ++ ++ if (curr_node_has_tail) { ++ const size_t reg_bytes = reg_tmp_.getBit() / 8; ++ if (prb_.nodes[curr_node_id].is_parent_empty()) { ++ mov(reg_loop_cnt, tail_size); ++ // Put info that node is being processed with tail. ++ mov(reg_tmp_, with_tail_info_); ++ str(reg_tmp_, pre_ptr(X_SP, -reg_bytes)); ++ } else { ++ ldr(reg_tmp_, ptr(data_chunk_addr(parent_node_id))); ++ check_if_this_is_last_chunk(reg_tmp_, parent_node_id); ++ b(NE, if_no_tail); ++ mov(reg_loop_cnt, tail_size); ++ // Put info that node is being processed with tail. ++ mov(reg_tmp_, with_tail_info_); ++ str(reg_tmp_, pre_ptr(X_SP, -reg_bytes)); ++ b(if_end); ++ ++ L(if_no_tail); ++ mov(reg_loop_cnt, node_size); ++ // Put info that node is being processed without tail. ++ mov(reg_tmp_, without_tail_info_); ++ str(reg_tmp_, pre_ptr(X_SP, -reg_bytes)); ++ L(if_end); ++ } ++ } ++ ++ if (prb_.is_tail_in_one_of_child_nodes(curr_node_id)) { ++ if (!curr_node_has_tail) { ++ mov(reg_loop_cnt, node_size); ++ str(reg_loop_cnt, ptr(data_chunk_addr(curr_node_id))); ++ } ++ L(loop); ++ if (!prb_.nodes[curr_node_id].is_parent_empty()) { ++ Label if_no_tail_in_child_node; ++ ldr(reg_tmp_, ptr(data_chunk_addr(parent_node_id))); ++ check_if_this_is_last_chunk(reg_tmp_, parent_node_id); ++ b(NE, if_no_tail_in_child_node); ++ str(reg_loop_cnt, ptr(data_chunk_addr(curr_node_id))); ++ L(if_no_tail_in_child_node); ++ } else { ++ str(reg_loop_cnt, ptr(data_chunk_addr(curr_node_id))); ++ } ++ } else if (curr_node_has_tail) { ++ L(loop); ++ } else { ++ loop_begin(loop, reg_loop_cnt, node_size); ++ } ++ create_loops(desc, reg_cnt, jit_loop - 1); ++ ++ loop_end(loop, reg_loop_cnt, node_size, ++ prb_.is(curr_node_id) * unroll_factor, ++ prb_.os(curr_node_id) * unroll_factor, ++ prb_.ss(curr_node_id) * unroll_factor, ++ prb_.cs(curr_node_id) * unroll_factor, curr_node_id); ++ } else { ++ compute_blk_ker(desc); ++ } ++ } ++ + bool simple_impl() { + simple_impl_desc_t d; + if (!simple_impl_desc_init(prb_, &d)) return false; + +- const int nfu = d.ndims_full_unroll; +- const int ldu = d.len_last_dim_unroll; +- const int n_jit_loops = prb_.ndims - d.ndims_full_unroll; +- assert(n_jit_loops <= ndims_jit_loop_max); +- +- eor(reg_off_in, reg_off_in, reg_off_in); +- eor(reg_off_out, reg_off_out, reg_off_out); +- mov(x_ptr_in_off, XReg(reg_ptr_in.getIdx())); +- mov(x_ptr_out_off, XReg(reg_ptr_out.getIdx())); ++ eor(reg_off_in_, reg_off_in_, reg_off_in_); ++ eor(reg_off_out_, reg_off_out_, reg_off_out_); ++ mov(x_ptr_in_off, reg_ptr_in_); ++ mov(x_ptr_out_off, reg_ptr_out_); + if (prb_.scale_type == scale_type_t::MANY) { +- eor(reg_off_scale, reg_off_scale, reg_off_scale); +- mov(x_ptr_scale_off, XReg(reg_ptr_scale.getIdx())); ++ mov(reg_off_scale_, 0); ++ mov(x_ptr_scale_off, reg_ptr_scale_); ++ } ++ if (compensation_needed_) { ++ eor(reg_off_comp_, reg_off_comp_, reg_off_comp_); ++ mov(x_ptr_comp_off, reg_off_comp_); + } + +- Label l_loop[3]; +- XReg reg_cnt[3] = {x15, x14, x13}; +- +- if (n_jit_loops > 2) loop_begin(l_loop[2], reg_cnt[2], n(nfu + 2)); +- +- if (n_jit_loops > 1) loop_begin(l_loop[1], reg_cnt[1], n(nfu + 1)); +- +- if (n_jit_loops > 0) +- loop_begin(l_loop[0], reg_cnt[0], n(nfu + 0) / ldu); +- +- compute_blk_ker(d.len_unroll); +- +- if (n_jit_loops > 0) +- loop_end(l_loop[0], reg_cnt[0], n(nfu + 0) / ldu, is(nfu + 0) * ldu, +- os(nfu + 0) * ldu, ss(nfu + 0) * ldu); +- +- if (n_jit_loops > 1) +- loop_end(l_loop[1], reg_cnt[1], n(nfu + 1), is(nfu + 1), +- os(nfu + 1), ss(nfu + 1)); ++ std::array reg_cnt({{x15, x14, x13}}); + +- if (n_jit_loops > 2) +- loop_end(l_loop[2], reg_cnt[2], n(nfu + 2), is(nfu + 2), +- os(nfu + 2), ss(nfu + 2)); ++ const int n_jit_loops = prb_.ndims - d.ndims_full_unroll; ++ create_loops(d, reg_cnt, n_jit_loops); + + return true; + } +@@ -1156,7 +1577,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { + inst(__VA_ARGS__); + + void cvt_z_s32_f32(const size_t startIdx, const size_t regNum) { +- UNROLL_INST(scvtf, ZRegS, tmp, p_all / T_m, tmp); ++ UNROLL_INST(scvtf, ZRegS, tmp, P_ALL_ONE / T_m, tmp); + } + + void cvt_v_s32_f32(const size_t startIdx, const size_t regNum) { +@@ -1164,8 +1585,8 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { + } + + void cvt_z_f32_s32(const size_t startIdx, const size_t regNum) { +- UNROLL_INST(frinti, ZRegS, tmp, p_all / T_m, tmp); +- UNROLL_INST(fcvtzs, ZRegS, tmp, p_all / T_m, tmp); ++ UNROLL_INST(frinti, ZRegS, tmp, P_ALL_ONE / T_m, tmp); ++ UNROLL_INST(fcvtzs, ZRegS, tmp, P_ALL_ONE / T_m, tmp); + } + + void cvt_v_f32_s32(const size_t startIdx, const size_t regNum) { +@@ -1175,7 +1596,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { + + void cvt_z_s8_s32(const size_t startIdx, const size_t regNum) { + cvt_z_b_s(startIdx, regNum); +- UNROLL_INST(sxtb, ZRegS, tmp, p_all / T_m, tmp); ++ UNROLL_INST(sxtb, ZRegS, tmp, P_ALL_ONE / T_m, tmp); + } + + void cvt_v_s8_s32(const size_t startIdx, const size_t regNum) { +@@ -1214,7 +1635,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { + + void cvt_z_u8_s32(const size_t startIdx, const size_t regNum) { + cvt_z_b_s(startIdx, regNum); +- UNROLL_INST(uxtb, ZRegS, tmp, p_all / T_m, tmp); ++ UNROLL_INST(uxtb, ZRegS, tmp, P_ALL_ONE / T_m, tmp); + } + + void cvt_v_u8_s32(const size_t startIdx, const size_t regNum) { +@@ -1285,7 +1706,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { + + dupm(z_tmp7.s, 255); + UNROLL_INST2(smax, ZRegS(i), 0); +- UNROLL_INST2(smin, ZRegS(i), p_all / T_m, z_tmp7.s); ++ UNROLL_INST2(smin, ZRegS(i), P_ALL_ONE / T_m, z_tmp7.s); + UNROLL_INST(uzp1, ZRegH, tmp, tmp, tmp); + UNROLL_INST(uzp1, ZRegB, tmp, tmp, tmp); + UNROLL_INST2(mov, ZRegB(i), P_NOT_128 / T_m, 0); +@@ -1320,107 +1741,514 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { + #undef UNROLL_INST + #undef UNROLL_INST + +- jit_uni_reorder_kernel_f32_t(const desc_t &desc) : kernel_t(desc) { +- itype_sz = data_type_size(prb_.itype); +- otype_sz = data_type_size(prb_.otype); +- stype_sz = sizeof(float); ++ jit_uni_reorder_kernel_f32_t(const desc_t &desc) ++ : kernel_t(desc), isa_(get_max_cpu_isa()) { ++ assert(!utils::one_of(isa_, isa_undef, isa_all)); ++ itype_sz_ = data_type_size(prb_.itype); ++ otype_sz_ = data_type_size(prb_.otype); ++ stype_sz_ = sizeof(float); + } + + void generate() override { + using namespace Xbyak_aarch64::util; + uint64_t sveLen = get_sve_length(); ++ Label end_of_kernel; + + preamble(); +-#define PARAM(x) offsetof(call_param_t, x) ++ + if (prb_.scale_type == scale_type_t::COMMON) { +- add_imm(X_DEFAULT_ADDR, abi_param1, PARAM(scale), X_TMP_1); ++ add_imm(X_DEFAULT_ADDR, PARAM(scale), X_TMP_1); + ldr(X_TMP_0, ptr(X_DEFAULT_ADDR)); +- ldr(W_TMP_1, ptr(X_TMP_0)); +- dup(xmm_scale, W_TMP_1); ++ ld1r(xmm_scale_, ptr(X_TMP_0)); + } else if (prb_.scale_type == scale_type_t::MANY) { +- add_imm(X_DEFAULT_ADDR, abi_param1, PARAM(scale), X_TMP_0); +- ldr(reg_ptr_scale, ptr(X_DEFAULT_ADDR)); ++ add_imm(X_DEFAULT_ADDR, PARAM(scale), X_TMP_0); ++ ldr(reg_ptr_scale_, ptr(X_DEFAULT_ADDR)); + } +- add_imm(X_TMP_0, abi_param1, PARAM(in), X_TMP_2); +- add_imm(X_TMP_1, abi_param1, PARAM(out), X_TMP_2); +- add_imm(reg_blk, abi_param1, PARAM(blk_chunks), reg_blk); +- ldr(reg_ptr_in, ptr(X_TMP_0)); +- ldr(reg_ptr_out, ptr(X_TMP_1)); +- ldr(reg_blk_chunks, ptr(reg_blk)); +- +-#undef PARAM +- mov_imm(reg_last_loop_cnt, 1); ++ if (compensation_needed_) { ++ add_imm(X_DEFAULT_ADDR, PARAM(compensation_scratch), X_TMP_0); ++ ldr(reg_ptr_comp_, ptr(X_DEFAULT_ADDR)); ++ } ++ if (prb_.scale_adjust == 0.5f) { mov(reg_scale_adjust_, 0x3f000000); } ++ add_imm(X_TMP_0, PARAM(in), X_TMP_2); ++ add_imm(X_TMP_1, PARAM(out), X_TMP_2); ++ ldr(reg_ptr_in_, ptr(X_TMP_0)); ++ ldr(reg_ptr_out_, ptr(X_TMP_1)); + +- mov(x_ptr_in_off, XReg(reg_ptr_in.getIdx())); +- mov(x_ptr_out_off, XReg(reg_ptr_out.getIdx())); +- mov(x_ptr_scale_off, XReg(reg_ptr_scale.getIdx())); ++ mov(x_ptr_in_off, reg_ptr_in_); ++ mov(x_ptr_out_off, reg_ptr_out_); ++ mov(x_ptr_scale_off, reg_ptr_scale_); ++ mov(x_ptr_comp_off, reg_ptr_comp_); + + if (sveLen) { /* SVE is available. */ + ptrue(p_lsb_256.b, VL32); +- ptrue(p_all.b); ++ ptrue(p_lsb_128.b, VL16); ++ ptrue(p_lsb_64.b, VL8); + } + +- if (can_do_tr8x8()) { +- dup(ymm_zero, 0); +- +- if (prb_.itype == data_type::u8 && prb_.otype == data_type::s8) { +- mov_imm(reg_tmp, 0x7f7f7f7f7f7f7f7f); +- mov(VReg4S(ymm_8x127b.getIdx())[0], WReg(reg_tmp.getIdx())); ++ bool is_tail_in_drv_dims = false; ++ for (int i = prb_.ndims; i < prb_.full_ndims; i++) ++ if (prb_.nodes[i].tail_size > 0) { ++ is_tail_in_drv_dims = true; ++ break; + } +- } else if (mayiuse(sve_512)) { +- movi(xmm_zero, 0); + +- if (prb_.itype == data_type::u8 && prb_.otype == data_type::s8) { +- mov(WReg(reg_tmp.getIdx()), 0x7f7f7f7f); +- mov(xmm_4x127b[0], WReg(reg_tmp.getIdx())); ++ if (is_tail_in_drv_dims) { ++ Label reorder_kernel; ++ add_imm(X_DEFAULT_ADDR, TAIL_PARAM(skip_kernel_execution), X_TMP_0); ++ ldr(reg_tmp_, ptr(X_DEFAULT_ADDR)); ++ cmp(reg_tmp_, static_cast(true)); ++ b(EQ, end_of_kernel); ++ ++ add_imm(X_DEFAULT_ADDR, TAIL_PARAM(zeroing_data), X_TMP_0); ++ ldr(reg_tmp_, ptr(X_DEFAULT_ADDR)); ++ cmp(reg_tmp_, static_cast(false)); ++ b(EQ, reorder_kernel); ++ // If zeroing data is set then all dst memory ++ // will be zeroed and nothing more will be done. ++ int bytes_to_zeroing = otype_sz_; ++ for (int i = 0; i < prb_.ndims; i++) { ++ bytes_to_zeroing *= prb_.nodes[i].n; + } ++ eor(reg_off_out_, reg_off_out_, reg_off_out_); ++ mov(x_ptr_out_off, reg_ptr_out_); ++ zero_dst_memory(bytes_to_zeroing); ++ b(end_of_kernel); ++ L(reorder_kernel); ++ } ++ ++ if (can_do_tr8x8()) { ++ dup(ymm_zero_, 0); ++ } else { ++ movi(xmm_zero_, 0); + } + + impl(); ++ ++ L(end_of_kernel); + postamble(); + } + ++ ~jit_uni_reorder_kernel_f32_t() override = default; ++ ++#undef TAIL_PARAM ++#undef PARAM ++ + private: +- int itype_sz; +- int otype_sz; +- int stype_sz; ++ static constexpr int64_t with_tail_info_ = static_cast(true); ++ static constexpr int64_t without_tail_info_ = static_cast(false); ++ ++ int itype_sz_; ++ int otype_sz_; ++ int stype_sz_; + +- XReg reg_ptr_in = x6; +- XReg reg_ptr_out = x2; +- XReg reg_ptr_scale = abi_not_param1; ++ const cpu_isa_t isa_; + +- XReg reg_off_in = x8; +- XReg reg_off_out = x9; +- XReg reg_off_scale = x10; ++ const XReg reg_ptr_in_ = x6; ++ const XReg reg_ptr_out_ = x2; ++ const XReg reg_ptr_scale_ = abi_not_param1; ++ const XReg reg_ptr_comp_ = x3; ++ const WReg ®_scale_adjust_ = w5; + +- XReg reg_blk = x11; +- XReg reg_blk_chunks = x12; +- XReg reg_last_loop_cnt = x11; ++ const XReg reg_off_in_ = x8; ++ const XReg reg_off_out_ = x9; ++ const XReg reg_off_scale_ = x10; ++ const XReg reg_off_comp_ = x11; + +- XReg reg_tmp = x0; ++ XReg reg_tmp_ = x12; + +- VReg4S xmm_scale = v15.s; +- VReg4S xmm_zero = v14.s; +- VReg4S xmm_4x127b = v13.s; // TODO: unite with ymm_zero +- ZRegS ymm_zero = z14.s; +- ZRegS ymm_8x127b = z13.s; +- VReg4S xmm_tmp = v12.s; +- VReg4S xmm_saturation_ubound = v12.s; +- ZRegS ymm_saturation_ubound = z12.s; ++ VReg4S xmm_scale_ = v15.s; ++ VReg4S xmm_zero_ = v14.s; ++ ZRegS ymm_zero_ = z14.s; ++ VReg4S xmm_tmp_ = v12.s; ++ const VReg4S xmm_src_zp_ = v9.s; ++ const VReg4S xmm_dst_zp_ = v11.s; ++ VReg4S xmm_saturation_ubound_ = v12.s; ++ ZRegS ymm_saturation_ubound_ = z12.s; + + /* Note: x22 - x28 are already used as temporal registgers + in jit_generator.hpp. +- x_ptr_(in|out|scale)_off keeps (base + offset) address. */ ++ x_ptr_(in|out|scale|comp)_off keeps (base + offset) address. */ + XReg x_ptr_in_off = x16; + XReg x_ptr_out_off = x18; + XReg x_ptr_scale_off = x20; ++ XReg x_ptr_comp_off = x17; + + /* Caution: Chose predicate registers not used by x64's implementation. */ + PReg p_lsb_256 = p7; +- PReg p_all = p6; ++ PReg p_lsb_128 = p6; ++ PReg p_lsb_64 = p4; + PReg p_tmp0 = p5; + + const std::vector tmp_vec_idx = {20, 21, 22, 23, 24, 25, 26, 27}; ++ VReg v_tmp0 = v20; ++ ZReg z_tmp0 = z20; ++ ZReg z_tmp1 = z21; ++ ZReg z_tmp2 = z22; ++ ZReg z_tmp3 = z23; ++ ZReg z_tmp4 = z24; ++ ZReg z_tmp5 = z25; ++ ZReg z_tmp6 = z26; ++ ZReg z_tmp7 = z27; ++ VReg v_tmp7 = v27; ++ ++ const std::vector z_tmp_vec ++ = {z_tmp0, z_tmp1, z_tmp2, z_tmp3, z_tmp4, z_tmp5, z_tmp6, z_tmp7}; ++ constexpr static int z_tmp_vec_size = 8; ++}; ++ ++// Seperate class for no unroll/threading burden ++struct jit_single_blk_kernel_t : public jit_generator { ++ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_single_blk_kernel) ++ static bool applicable(const prb_t &p) { ++ using namespace data_type; ++ ++ bool ok = p.ndims >= 2 && mayiuse(sve_256) ++ && p.scale_type == scale_type_t::NONE ++ && utils::one_of(p.itype, f32) && utils::one_of(p.otype, f32) ++ && utils::everyone_is(0, p.ioff, p.ooff) && p.beta == 0.f ++ && prb_has_small_strides(p); ++ if (!ok) return false; ++ ++ int64_t n0 = p.nodes[0].n; ++ auto i0 = p.nodes[0].is; ++ auto o0 = p.nodes[0].os; ++ int64_t n1 = p.nodes[1].n; ++ auto i1 = p.nodes[1].is; ++ auto o1 = p.nodes[1].os; ++ ++ /* ++ * for a transpose of plain to 8c case, nodes would be like: ++ * n is os ++ * m 1 8 ++ * 8 m 1 ++ * or ++ * 8 m 1 ++ * m 1 8 ++ */ ++ ok = (utils::one_of(n0, 8, 16) || utils::one_of(n1, 8, 16)) ++ && ((i0 == 1 && o1 == 1 && n0 == i1 && o0 == n1) ++ || (o0 == 1 && i1 == 1 && n0 == o1 && i0 == n1)); ++ if (!ok) return false; ++ ++ // Do not handle transpose of dimensions other than last 2 ++ for (int i = 2; i < p.ndims; ++i) { ++ if (p.nodes[i].is != p.nodes[i].os) { ++ ok = false; ++ break; ++ } ++ } ++ ++ return ok; ++ } ++ ++ jit_single_blk_kernel_t(const tr::prb_t &prb) ++ : jit_generator() ++ , prb_(prb) ++ , itype_sz_(data_type_size(prb_.itype)) ++ , otype_sz_(data_type_size(prb_.otype)) ++ , block_sz(prb.nodes[0].n) {} ++ ++ void generate() override { ++ auto input_stride ++ = prb_.nodes[0].is != 1 ? prb_.nodes[0].is : prb_.nodes[1].is; ++ auto output_stride ++ = prb_.nodes[0].os != 1 ? prb_.nodes[0].os : prb_.nodes[1].os; ++ ++ Label tail_processing; ++ ++ const auto load_zp = [&](const ZRegS ymm_zp, const XReg reg_zp) { ++ dup(ymm_zp, WReg(reg_zp.getIdx())); ++ scvtf(ymm_zp, P_ALL_ONE / T_m, ymm_zp); ++ }; ++ ++ preamble(); ++ ++ if (prb_.req_src_zp) load_zp(ymm_src_zp, reg_src_zp); ++ ++ if (prb_.req_dst_zp) load_zp(ymm_dst_zp, reg_dst_zp); ++ ++ cmp(reg_ptr_tail, true); ++ b(EQ, tail_processing); ++ ++ if (block_sz == 8) { ++ gen_ker8x8(0, 0, input_stride, output_stride, 8, 8); ++ block_sz = 8; ++ } else if (block_sz == 16) { ++ gen_ker16x16_in_8x8(input_stride, output_stride); ++ block_sz = 16; ++ } else { ++ assert(!"unimplemented"); ++ } ++ ++ postamble(); ++ ++ L(tail_processing); ++ ++ if (block_sz == 8) { ++ auto i_tail = input_stride % 8 != 0 ? input_stride % 8 : 8; ++ auto o_tail = output_stride % 8 != 0 ? output_stride % 8 : 8; ++ if (i_tail != o_tail) { ++ auto t_mask = i_tail == 8 ? o_tail : i_tail; ++ gen_setmask(t_mask); ++ gen_ker8x8(0, 0, input_stride, output_stride, i_tail, o_tail); ++ } ++ } else if (block_sz == 16) { ++ auto i_tail = input_stride % 16 != 0 ? input_stride % 16 : 16; ++ auto o_tail = output_stride % 16 != 0 ? output_stride % 16 : 16; ++ if (i_tail != o_tail) { ++ auto t_mask = i_tail == 16 ? o_tail : i_tail; ++ t_mask %= 8; ++ if (t_mask != 0) gen_setmask(t_mask); ++ gen_ker16x16_in_8x8( ++ input_stride, output_stride, i_tail, o_tail); ++ } ++ } else { ++ assert(!"unimplemented"); ++ } ++ ++ postamble(); ++ } ++ ++ void gen_loadu(const ZRegS ymm, const XReg &addr, int size) { ++ QReg xmm(ymm.getIdx()); ++ switch (size) { ++ case 32: ld1w(ymm, p_lsb_256 / T_z, ptr(addr)); break; ++ case 16: ldr(xmm, ptr(addr)); break; ++ default: assert(!"unreachable"); ++ } ++ } ++ ++ void gen_storeu(const XReg &addr, const ZRegS ymm, int size) { ++ QReg xmm(ymm.getIdx()); ++ switch (size) { ++ case 32: st1w(ymm, p_lsb_256, ptr(addr)); break; ++ case 16: str(xmm, ptr(addr)); break; ++ default: assert(!"unreachable"); ++ } ++ } ++ ++ void gen_maskloadu( ++ const ZRegS ymm, const XReg &addr, const PReg mask, int size) { ++ switch (size) { ++ case 32: ++ case 16: ld1w(ymm, mask / T_z, ptr(addr)); break; ++ default: assert(!"unreachable"); ++ } ++ } ++ ++ void gen_maskstoreu( ++ const XReg &addr, const ZRegS ymm, const PReg mask, int size) { ++ switch (size) { ++ case 32: ++ case 16: st1w(ymm, mask, ptr(addr)); break; ++ default: assert(!"unreachable"); ++ } ++ } ++ ++ // Register allocation xmm0~11 ++ void gen_transpose_8x8() { ++ const uint64_t sveLen = get_sve_length(); ++ constexpr int lane = 8; ++ ++#if 0 ++ /* Debug code ++ z0: 7, 6, 5, 4, 3, 2, 1, 0 ++ z1: 15, 14, 13, 12, 11, 10, 9, 8 ++ ... ++ z17: 63, 62, 61, 60, 59, 58, 57, 56 ++ */ ++ ptrue(P_ALL_ONE.b); ++ ptrue(P_TMP.s, VL8); ++ not_(P_TMP.b, P_ALL_ONE/T_z, P_TMP.b); ++ index(z0.s, 0, 1); ++ mov(z0.s, P_TMP/T_m, 0); ++ mov(z_tmp_vec[0].s, 8); ++ mov(z_tmp_vec[0].s, P_TMP/T_m, 0); ++ for(uint32_t i=1; i nChw()C ++ // or nChw()C -> nchw ++ void gen_setmask(int mask) { ++ mov_imm(x_tmp_0, 0); ++ mov_imm(x_tmp_1, mask); ++ whilelt(p_mask.s, x_tmp_0, x_tmp_1); ++ } ++ ++ // TODO: Mark parameter with type information ++ // XXX: ! ++ // offset in byte offset ++ // stride in element number ++ // ++ // Gen specific 8x8 transform respect to certain tail condition ++ void gen_tr8x8(int i_off, int o_off, int input_stride, int output_stride, ++ int in_tail, int out_tail) { ++ constexpr int lane = 8; ++ ++ if (in_tail == 0 || out_tail == 0) return; ++ ++ for (int i = 0; i < out_tail; ++i) { ++ if (in_tail != lane) { ++ add_imm(x_addr, reg_ptr_in_, ++ i_off + i * input_stride * itype_sz_, x_tmp_0); ++ gen_maskloadu(ZRegS(i), x_addr, p_mask, lane * itype_sz_); ++ } else { ++ add_imm(x_addr, reg_ptr_in_, ++ i_off + i * input_stride * itype_sz_, x_tmp_0); ++ gen_loadu(ZRegS(i), x_addr, lane * itype_sz_); ++ } ++ if (prb_.req_src_zp) { fsub(ZRegS(i), ZRegS(i), ymm_src_zp); } ++ } ++ ++ gen_transpose_8x8(); ++ ++ for (int i = 0; i < in_tail; ++i) { ++ if (prb_.req_dst_zp) { fadd(ZRegS(i), ZRegS(i), ymm_dst_zp); } ++ if (out_tail == lane) { ++ add_imm(x_addr, reg_ptr_out_, ++ o_off + i * output_stride * otype_sz_, x_tmp_0); ++ gen_storeu(x_addr, ZRegS(i), lane * otype_sz_); ++ } else { ++ add_imm(x_addr, reg_ptr_out_, ++ o_off + i * output_stride * otype_sz_, x_tmp_0); ++ gen_maskstoreu(x_addr, ZRegS(i), p_mask, lane * otype_sz_); ++ } ++ } ++ } ++ ++ // tail: 0 ~ 8 ++ // support: either in_tail or out_tail is not 8, but not both ++ void gen_ker8x8(int i_off, int o_off, int input_stride, int output_stride, ++ int in_tail, int out_tail) { ++ gen_tr8x8(i_off, o_off, input_stride, output_stride, in_tail, out_tail); ++ } ++ ++ void gen_ker16x16_in_8x8(int input_stride, int output_stride) { ++ const auto lane = 16; ++ const auto sub_lane = lane / 2; ++ gen_tr8x8(0, 0, input_stride, output_stride, sub_lane, sub_lane); ++ gen_tr8x8(input_stride * sub_lane * itype_sz_, sub_lane * otype_sz_, ++ input_stride, output_stride, sub_lane, sub_lane); ++ gen_tr8x8(sub_lane * itype_sz_, output_stride * sub_lane * otype_sz_, ++ input_stride, output_stride, sub_lane, sub_lane); ++ gen_tr8x8((input_stride * sub_lane + sub_lane) * itype_sz_, ++ (output_stride * sub_lane + sub_lane) * otype_sz_, input_stride, ++ output_stride, sub_lane, sub_lane); ++ } ++ ++ // tail can be 1 ~ 16, using avx2 for now ++ void gen_ker16x16_in_8x8( ++ int input_stride, int output_stride, int in_tail, int out_tail) { ++ constexpr auto lane = 16; ++ constexpr auto sub_lane = lane / 2; ++ auto tail = in_tail != lane ? in_tail : out_tail; ++ ++ const auto l_tail = tail < sub_lane ? tail : sub_lane; ++ const auto u_tail = tail < sub_lane ? 0 : tail - sub_lane; ++ ++ if (tail == in_tail) { ++ gen_tr8x8(0, 0, input_stride, output_stride, l_tail, sub_lane); ++ gen_tr8x8(input_stride * sub_lane * itype_sz_, sub_lane * otype_sz_, ++ input_stride, output_stride, l_tail, sub_lane); ++ gen_tr8x8(sub_lane * itype_sz_, ++ output_stride * sub_lane * otype_sz_, input_stride, ++ output_stride, u_tail, sub_lane); ++ gen_tr8x8(itype_sz_ * (input_stride * sub_lane + sub_lane), ++ otype_sz_ * (output_stride * sub_lane + sub_lane), ++ input_stride, output_stride, u_tail, sub_lane); ++ } else { ++ gen_tr8x8(0, 0, input_stride, output_stride, sub_lane, l_tail); ++ gen_tr8x8(input_stride * sub_lane * itype_sz_, sub_lane * otype_sz_, ++ input_stride, output_stride, sub_lane, u_tail); ++ gen_tr8x8(sub_lane * itype_sz_, ++ output_stride * sub_lane * itype_sz_, input_stride, ++ output_stride, sub_lane, l_tail); ++ gen_tr8x8(itype_sz_ * (input_stride * sub_lane + sub_lane), ++ otype_sz_ * (output_stride * sub_lane + sub_lane), ++ input_stride, output_stride, sub_lane, u_tail); ++ } ++ } ++ ++private: ++ // 6 ~ 12 ++ constexpr static int xmm_save_for_windows = 0; ++ constexpr static int xmm_save_start_from = 6; ++ constexpr static int xmm_width = 16; ++ ++ void preamble() { ptrue(p_lsb_256.b, VL32); } ++ ++ void postamble() { ret(); } ++ ++ const prb_t &prb_; ++ ++ int itype_sz_; ++ int otype_sz_; ++ int block_sz; ++ ++ XReg reg_ptr_in_ = abi_param1; ++ XReg reg_ptr_out_ = abi_param2; ++ XReg reg_ptr_tail = abi_param3; ++ XReg reg_src_zp = abi_param4; ++ XReg reg_dst_zp = abi_param5; ++ ++ XReg x_addr = x10; ++ XReg x_tmp_0 = x11; ++ XReg x_tmp_1 = x12; ++ ++ /* Avoid P_TMP(p7) in jit_generator.hpp. */ ++ PReg p_lsb_256 = p6; ++ PReg p_mask = p5; ++ ++ ZRegS ymm_tmp = z0.s; ++ ZRegS ymm_src_zp = z14.s; ++ ZRegS ymm_dst_zp = z15.s; ++ ++ const std::vector tmp_vec_idx = {20, 21, 22, 23, 24, 25, 26, 27}; ++ VReg v_tmp0 = v20; + ZReg z_tmp0 = z20; + ZReg z_tmp1 = z21; + ZReg z_tmp2 = z22; +@@ -1472,15 +2300,31 @@ kernel_t *kernel_t::create(const kernel_t::desc_t &desc) { + + return nullptr; + } ++ + } // namespace tr + + static void prb_block_for_cache(tr::prb_t &prb) { + /* If strides for 0th and 1st nodes are cache friendly + * then one can altogether do away with blocking ! */ +- const bool cache_blocking_needed = false +- || (prb.nodes[0].is % 64 == 0 && prb.nodes[0].n > 16) +- || (prb.ndims > 1 && prb.nodes[1].is % 64 == 0 +- && prb.nodes[1].n > 16); ++ static constexpr int num_elems_thr = 16; ++ const bool stride_cache_friendly ++ = ((prb.nodes[0].is % 64 == 0 && prb.nodes[0].n > num_elems_thr) ++ || (prb.ndims > 1 && prb.nodes[1].is % num_elems_thr == 0 ++ && prb.nodes[1].n > num_elems_thr)) ++ && !prb.is_tail_present; ++ ++ // performance improvement for shapes with large inner-most dimension ++ const size_t L1_cache_sz ++ = size_t(3) * platform::get_per_core_cache_size(1) / 4; ++ const size_t itype_sz_ = data_type_size(prb.itype); ++ const size_t inner_block_sz = prb.nodes[0].n * itype_sz_; ++ const bool requires_inner_blocking = inner_block_sz > L1_cache_sz ++ // 'is_tail_present' is not supported for cache_blocking when ++ // asymmetric_comp is executed. ++ && IMPLICATION(prb.req_asymmetric_comp, !prb.is_tail_present); ++ ++ const bool cache_blocking_needed ++ = stride_cache_friendly || requires_inner_blocking; + if (!cache_blocking_needed) return; + + int unit_input_stride_idx = -1; +@@ -1496,28 +2340,58 @@ static void prb_block_for_cache(tr::prb_t &prb) { + const auto output_stride = prb.nodes[unit_input_stride_idx].os; + const auto num_elems = prb.nodes[unit_input_stride_idx].n; + +- const bool split_needed = (num_elems > 16) && (num_elems % 16 == 0); ++ const bool split_needed = (num_elems > num_elems_thr) ++ && (num_elems % num_elems_thr == 0); + const int move_location = (output_stride % 4 != 0) ? 0 : 1; +- if (split_needed) prb_node_split(prb, unit_input_stride_idx, 16); ++ if (split_needed) ++ prb_node_split(prb, unit_input_stride_idx, num_elems_thr); + + /* Because of cache-unfriendly nature of unit-output stride node, let + * us move unit-input stride node on or near front! */ +- prb_node_move(prb, unit_input_stride_idx, move_location); ++ if (unit_input_stride_idx != move_location) ++ prb_node_move(prb, unit_input_stride_idx, move_location); + } + + /* Potentially, split the node with os=1 in two and pull in the node with + * is=1 between them for better cache reuse: + * [n0:is0:1][n1:1:os1] --> [16n0:is0:1][n1:1:os1][n0/16:is0*16:16] */ + if (prb.ndims >= 2 && prb.nodes[0].os == 1 && prb.nodes[1].is == 1) { +- const auto input_stride = prb.nodes[0].is; + const auto num_elems = prb.nodes[0].n; + +- const bool split_needed = true && (num_elems > 16) +- && (num_elems % 16 == 0) && (input_stride >= 256) +- && (input_stride % 64 == 0); ++ const bool split_needed = (num_elems > num_elems_thr) ++ && (num_elems % num_elems_thr == 0); + if (split_needed) { +- prb_node_split(prb, 0, 16); ++ prb_node_split(prb, 0, num_elems_thr); + prb_node_move(prb, 1, 2); ++ ++ // Update node information ++ prb_node_dependency(prb); ++ ++ // heuristics - looping over the unrolled dims should maximize reuse ++ // of the already cached data; observation is choosing the smallest ++ // dim from the remaining (from 2 up to ndims) gives good results ++ constexpr int new_position = 2; ++ const auto dim_beg_it = std::begin(prb.nodes); ++ const auto dim_two_it = dim_beg_it + new_position; ++ const auto dim_last_it = dim_beg_it + prb.ndims; ++ const auto min_n_node_it = std::min_element(dim_two_it, dim_last_it, ++ [](const tr::node_t &lhs, const tr::node_t &rhs) { ++ return lhs.n < rhs.n; ++ }); ++ const auto min_idx = std::distance(dim_beg_it, min_n_node_it); ++ // check if min_idx node is parent of node with tail processing which ++ // is currently unsupported (i.e. tail processing can only be handled ++ // at the inner-most dimension) ++ bool inner_block_has_tail = false; ++ for (int idx = min_idx - 1; idx >= new_position; idx--) { ++ if (prb.nodes[idx].parent_node_id == min_idx) { ++ inner_block_has_tail = true; ++ break; ++ } ++ } ++ ++ if (min_idx > new_position && (!inner_block_has_tail)) ++ prb_node_move(prb, min_idx, new_position); + } + } + } +@@ -1527,73 +2401,76 @@ static void prb_block_for_cache(tr::prb_t &prb) { + * parallel driver and the kernel. */ + static void prb_thread_kernel_balance( + tr::prb_t &prb, int &ndims_ker_max, int nthr) { +- size_t sz_total = 1; ++ size_t size_total = 1; + for (int d = 0; d < prb.ndims; ++d) +- sz_total *= prb.nodes[d].n; ++ size_total *= prb.nodes[d].n; + +- /* The general expression for sz_drv_thr can be written as +- * sz_drv_min = C0 + FC * (nthr > 1 ? 1 : 0) + VC * (nthr - 1) ++ /* The general expression for size_drv_thr can be written as ++ * size_drv_min = C0 + FC * (nthr > 1 ? 1 : 0) + VC * (nthr - 1) + * where FC and VC are fixed and variable costs respectively. + * Though for now, the below heuristic seems to be good enough */ +- const size_t sz_drv_thr = (nthr > 1) ? 16 * nthr : 1; ++ const size_t size_drv_thr = (nthr > 1) ? 16 * nthr : 1; + +- /* sz_drv_min is the minimal size for the parallel ++ /* size_drv_min is the minimal size for the parallel + * driver required for good parallelization */ +- const size_t sz_drv_min +- = nstl::min(sz_drv_thr, utils::div_up(sz_total, 1024)); ++ const size_t size_drv_min ++ = nstl::min(size_drv_thr, utils::div_up(size_total, 1024)); + + /* kdims -- # of dimensions processed by a kernel +- * sz_ker_cur -- product of the dimension processed by a kernel +- * sz_drv_cur -- product of the dimension processed by a driver */ ++ * size_ker_cur -- product of the dimension processed by a kernel ++ * size_drv_cur -- product of the dimension processed by a driver */ + + int kdims = prb.ndims; +- size_t sz_drv_cur = 1; +- for (; kdims > 1 && sz_drv_cur < sz_drv_min; --kdims) +- sz_drv_cur *= prb.nodes[kdims - 1].n; ++ size_t size_drv_cur = 1; ++ for (; kdims > 1 && size_drv_cur < size_drv_min; --kdims) ++ size_drv_cur *= prb.nodes[kdims - 1].n; + +- size_t sz_ker_cur = 1; ++ size_t size_ker_cur = 1; + for (int d = 0; d < kdims; ++d) +- sz_ker_cur *= prb.nodes[d].n; ++ size_ker_cur *= prb.nodes[d].n; + +- /* Initially kdims is chosen so that sz_drv_cur >= sz_drv_min. ++ /* Initially kdims is chosen so that size_drv_cur >= size_drv_min. + * +- * It might happen that for chosen kdims the sz_ker_cur is too small ++ * It might happen that for chosen kdims the size_ker_cur is too small + * (less than tr::ker_prb_size_min). In that case try to split the +- * innermost driver dimension into two, to increase sz_ker_cur. */ +- bool want_borrow_ker_from_drv = true && kdims < prb.ndims +- && sz_ker_cur < tr::ker_prb_size_min && sz_drv_cur > sz_drv_min +- && kdims != prb.blk_chunk_idx; ++ * innermost driver dimension into two, to increase size_ker_cur. */ ++ const bool want_borrow_ker_from_drv = kdims < prb.ndims ++ && size_ker_cur < tr::ker_prb_size_min ++ && size_drv_cur > size_drv_min; + if (want_borrow_ker_from_drv) { +- /* sz_want_borrow is the minimal sz, so that: +- * o) sz_ker_cur * sz_want_borrow >= tr::ker_prb_size_min ++ /* size_want_borrow is the minimal size, so that: ++ * o) size_ker_cur * size_want_borrow >= tr::ker_prb_size_min + * o) current innermost driver dimension is divisible by +- * sz_want_borrow (so that we can evenly split that ++ * size_want_borrow (so that we can evenly split that + * dimension into two) + * +- * In the worst case the minimal sz_want_borrow is equal ++ * In the worst case the minimal size_want_borrow is equal + * to the innermost driver dimension itself. In that case + * we will sacrifice it in favor of kernel (is it fine?). */ +- size_t sz_want_borrow = utils::div_up(tr::ker_prb_size_min, sz_ker_cur); +- for (; prb.nodes[kdims].n % sz_want_borrow; ++sz_want_borrow) ++ size_t size_want_borrow ++ = utils::div_up(tr::ker_prb_size_min, size_ker_cur); ++ for (; prb.nodes[kdims].n % size_want_borrow; ++size_want_borrow) + ; +- if (sz_want_borrow != prb.nodes[kdims].n) +- prb_node_split(prb, kdims, sz_want_borrow); ++ ++ if (size_want_borrow != prb.nodes[kdims].n) ++ prb_node_split(prb, kdims, size_want_borrow); + kdims += 1; + } + + /* On the other hand it might happen that for chosen kdims +- * the sz_drv_cur is too small (less than sz_drv_min). In that case ++ * the size_drv_cur is too small (less than size_drv_min). In that case + * try to split the outermost kernel dimension into two, to increase +- * sz_drv_cur. */ +- bool want_borrow_drv_from_ker = true && sz_ker_cur > tr::ker_prb_size_min +- && sz_drv_cur < sz_drv_min && kdims != prb.blk_chunk_idx; ++ * size_drv_cur. */ ++ const bool want_borrow_drv_from_ker = size_ker_cur > tr::ker_prb_size_min ++ && size_drv_cur < size_drv_min; + if (want_borrow_drv_from_ker) { +- size_t sz_want_borrow = utils::div_up(sz_drv_min, sz_drv_cur); +- for (; prb.nodes[kdims - 1].n % sz_want_borrow; ++sz_want_borrow) ++ size_t size_want_borrow = utils::div_up(size_drv_min, size_drv_cur); ++ for (; prb.nodes[kdims - 1].n % size_want_borrow; ++size_want_borrow) + ; +- if (sz_want_borrow != prb.nodes[kdims - 1].n) ++ ++ if (size_want_borrow != prb.nodes[kdims - 1].n) + prb_node_split( +- prb, kdims - 1, prb.nodes[kdims - 1].n / sz_want_borrow); ++ prb, kdims - 1, prb.nodes[kdims - 1].n / size_want_borrow); + } + + ndims_ker_max = kdims; +@@ -1607,6 +2484,33 @@ static void prb_thread_kernel_balance( + } + } + ++status_t jit_uni_reorder_t::pd_t::init( ++ engine_t *engine, engine_t *src_engine, engine_t *dst_engine) { ++ CHECK(cpu_reorder_pd_t::init(engine, src_engine, dst_engine)); ++ ++ const bool compensation_needed ++ = prb_.req_s8s8_comp || prb_.req_asymmetric_comp; ++ if (compensation_needed) init_scratchpad(); ++ ++ return status::success; ++} ++ ++void jit_uni_reorder_t::pd_t::init_scratchpad() { ++ const memory_desc_wrapper od(dst_md()); ++ const auto G = with_groups_ ? od.padded_dims()[0] : 1; ++ const auto N = od.padded_dims()[with_groups_ ? 1 : 0]; ++ static constexpr int cache_line_size = 16; ++ const auto wspace_per_thr_size ++ = utils::rnd_up(G * N, cache_line_size) * sizeof(int32_t); ++ ++ auto scratchpad = scratchpad_registry().registrar(); ++ const auto compensation_reduce_size = wspace_per_thr_size * nthr_; ++ ++ //every thread gets its own scratchpad space for each N ++ scratchpad.template book(memory_tracking::names::key_reorder_space, ++ compensation_reduce_size); ++} ++ + status_t jit_uni_reorder_t::pd_t::create(reorder_pd_t **reorder_pd, + engine_t *engine, const primitive_attr_t *attr, engine_t *src_engine, + const memory_desc_t *src_md, engine_t *dst_engine, +@@ -1616,36 +2520,18 @@ status_t jit_uni_reorder_t::pd_t::create(reorder_pd_t **reorder_pd, + status_t prb_init_status = prb_init(prb, *src_md, *dst_md, attr); + if (prb_init_status != status::success) return prb_init_status; + +- DEBUG({ +- printf("init : "); +- prb_dump(prb); +- }); +- // Sort the prb array in increasing sizes of the output stride +- prb_normalize(prb); +- DEBUG({ +- printf("norm : "); +- prb_dump(prb); +- }); +- /* Combine the variables, which appear together on both +- * sides of the reorder */ +- prb_simplify(prb); +- DEBUG({ +- printf("smpl : "); +- prb_dump(prb); +- }); +- + prb_block_for_cache(prb); + DEBUG({ + printf("cache: "); + prb_dump(prb); + }); + +- CHECK(prb_check_blk(prb, *dst_md)); +- +- int ndims_ker_max; ++ int ndims_ker_max {}; + int nthr = dnnl_get_max_threads(); + prb_thread_kernel_balance(prb, ndims_ker_max, nthr); + ++ if (prb.is_tail_present) prb_node_dependency(prb); ++ + tr::kernel_t::desc_t ker_desc; + status_t ker_init_status + = tr::kernel_t::desc_init(ker_desc, prb, ndims_ker_max); +@@ -1663,99 +2549,191 @@ status_t jit_uni_reorder_t::pd_t::create(reorder_pd_t **reorder_pd, + auto _pd = new pd_t( + attr, src_engine->kind(), src_md, dst_engine->kind(), dst_md); + if (_pd == nullptr) return status::out_of_memory; ++ ++ _pd->nthr_ = nthr; ++ _pd->prb_ = prb; ++ _pd->with_groups_ ++ = prb.compensation_mask == tr::prb_t::comp_mask_with_groups; + if (_pd->init(engine, src_engine, dst_engine) != status::success) { + delete _pd; + return status::unimplemented; + } +- _pd->prb_ = prb; + _pd->ker_desc_ = ker_desc; + _pd->init_scratchpad_md(); +- _pd->nthr_ = nthr; ++ + return safe_ptr_assign(*reorder_pd, _pd); + } + +-void jit_uni_reorder_t::omp_driver_0d( +- int off, const char *in, char *out, const float *scale) const { +- tr::call_param_t c {in, out, scale, 0}; +- (*kernel_)(&c); ++void jit_uni_reorder_t::omp_driver_0d(int off, const char *in, char *out, ++ const float *scale, int src_zp, int dst_zp, ++ int32_t *compensation_scratch) const { ++ const tr::prb_t &prb = pd()->prb_; ++ ++ tr::call_param_t base_params; ++ base_params.in = in; ++ base_params.out = out; ++ base_params.scale = scale; ++ base_params.src_zp = src_zp; ++ base_params.dst_zp = dst_zp; ++ base_params.compensation_scratch = compensation_scratch; ++ ++ if (prb.is_tail_present) { ++ tr::tail_call_param_t tail_params; ++ tail_params.base_params = base_params; ++ ++ static constexpr int omp_ndims = 0; ++ fill_curr_data_chunks(prb, off, nullptr, omp_ndims, tail_params); ++ (*kernel_)(&tail_params); ++ } else { ++ (*kernel_)(&base_params); ++ } + } + + void jit_uni_reorder_t::omp_driver_1d(int ithr, int nthr, int off, +- const char *in, char *out, const float *scale) const { +- const tr::node_t *ns = pd()->prb_.nodes + off; ++ const char *in, char *out, const float *scale, int src_zp, int dst_zp, ++ int32_t *compensation_scratch) const { ++ const tr::prb_t &prb = pd()->prb_; ++ const tr::node_t *ns = prb.nodes + off; + for_nd(ithr, nthr, (ptrdiff_t)ns[0].n, [&](ptrdiff_t d0) { +- auto c = tr::call_param_t(); +- c.in = in + d0 * ns[0].is * data_type_size(pd()->prb_.itype); +- c.out = out + d0 * ns[0].os * data_type_size(pd()->prb_.otype); +- c.scale = scale + d0 * ns[0].ss; +- c.blk_chunks = d0; +- (*kernel_)(&c); ++ tr::call_param_t base_params; ++ base_params.in = in + d0 * ns[0].is * data_type_size(prb.itype); ++ base_params.out = out + d0 * ns[0].os * data_type_size(prb.otype); ++ base_params.scale = scale + d0 * ns[0].ss; ++ base_params.src_zp = src_zp; ++ base_params.dst_zp = dst_zp; ++ base_params.compensation_scratch = compensation_scratch + d0 * ns[0].cs; ++ ++ if (prb.is_tail_present) { ++ tr::tail_call_param_t tail_params; ++ tail_params.base_params = base_params; ++ ++ static constexpr int omp_ndims = 1; ++ const ptrdiff_t omp_data_chunks[omp_ndims] = {d0}; ++ fill_curr_data_chunks( ++ prb, off, omp_data_chunks, omp_ndims, tail_params); ++ (*kernel_)(&tail_params); ++ } else { ++ (*kernel_)(&base_params); ++ } + }); + } + + void jit_uni_reorder_t::omp_driver_2d(int ithr, int nthr, int off, +- const char *in, char *out, const float *scale) const { +- const tr::node_t *ns = pd()->prb_.nodes + off; +- const int blk_idx_off = pd()->prb_.blk_chunk_idx - off; ++ const char *in, char *out, const float *scale, int src_zp, int dst_zp, ++ int32_t *compensation_scratch) const { ++ const tr::prb_t &prb = pd()->prb_; ++ const tr::node_t *ns = prb.nodes + off; + for_nd(ithr, nthr, (ptrdiff_t)ns[1].n, (ptrdiff_t)ns[0].n, + [&](ptrdiff_t d1, ptrdiff_t d0) { +- auto c = tr::call_param_t(); +- c.in = in ++ tr::call_param_t base_params; ++ base_params.in = in + + (d0 * ns[0].is + d1 * ns[1].is) +- * data_type_size(pd()->prb_.itype); +- c.out = out ++ * data_type_size(prb.itype); ++ base_params.out = out + + (d0 * ns[0].os + d1 * ns[1].os) +- * data_type_size(pd()->prb_.otype); +- c.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss; +- c.blk_chunks = utils::pick(blk_idx_off, d0, d1); +- (*kernel_)(&c); ++ * data_type_size(prb.otype); ++ base_params.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss; ++ base_params.src_zp = src_zp; ++ base_params.dst_zp = dst_zp; ++ base_params.compensation_scratch ++ = compensation_scratch + d0 * ns[0].cs + d1 * ns[1].cs; ++ ++ if (prb.is_tail_present) { ++ tr::tail_call_param_t tail_params; ++ tail_params.base_params = base_params; ++ ++ static constexpr int omp_ndims = 2; ++ const ptrdiff_t omp_data_chunks[omp_ndims] = {d0, d1}; ++ fill_curr_data_chunks( ++ prb, off, omp_data_chunks, omp_ndims, tail_params); ++ ++ (*kernel_)(&tail_params); ++ } else { ++ (*kernel_)(&base_params); ++ } + }); + } + + void jit_uni_reorder_t::omp_driver_3d(int ithr, int nthr, int off, +- const char *in, char *out, const float *scale) const { +- const tr::node_t *ns = pd()->prb_.nodes + off; +- const int blk_idx_off = pd()->prb_.blk_chunk_idx - off; ++ const char *in, char *out, const float *scale, int src_zp, int dst_zp, ++ int32_t *compensation_scratch) const { ++ const tr::prb_t &prb = pd()->prb_; ++ const tr::node_t *ns = prb.nodes + off; + for_nd(ithr, nthr, (ptrdiff_t)ns[2].n, (ptrdiff_t)ns[1].n, + (ptrdiff_t)ns[0].n, [&](ptrdiff_t d2, ptrdiff_t d1, ptrdiff_t d0) { +- auto c = tr::call_param_t(); +- c.in = in ++ tr::call_param_t base_params; ++ base_params.in = in + + (d0 * ns[0].is + d1 * ns[1].is + d2 * ns[2].is) +- * data_type_size(pd()->prb_.itype); +- c.out = out ++ * data_type_size(prb.itype); ++ base_params.out = out + + (d0 * ns[0].os + d1 * ns[1].os + d2 * ns[2].os) +- * data_type_size(pd()->prb_.otype); +- c.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss + d2 * ns[2].ss; +- c.blk_chunks = utils::pick(blk_idx_off, d0, d1, d2); +- (*kernel_)(&c); ++ * data_type_size(prb.otype); ++ base_params.scale ++ = scale + d0 * ns[0].ss + d1 * ns[1].ss + d2 * ns[2].ss; ++ base_params.src_zp = src_zp; ++ base_params.dst_zp = dst_zp; ++ base_params.compensation_scratch = compensation_scratch ++ + d0 * ns[0].cs + d1 * ns[1].cs + d2 * ns[2].cs; ++ ++ if (prb.is_tail_present) { ++ tr::tail_call_param_t tail_params; ++ tail_params.base_params = base_params; ++ ++ static constexpr int omp_ndims = 3; ++ const ptrdiff_t omp_data_chunks[omp_ndims] = {d0, d1, d2}; ++ fill_curr_data_chunks( ++ prb, off, omp_data_chunks, omp_ndims, tail_params); ++ (*kernel_)(&tail_params); ++ } else { ++ (*kernel_)(&base_params); ++ } + }); + } + + void jit_uni_reorder_t::omp_driver_4d(int ithr, int nthr, int off, +- const char *in, char *out, const float *scale) const { +- const tr::node_t *ns = pd()->prb_.nodes + off; +- const int blk_idx_off = pd()->prb_.blk_chunk_idx - off; ++ const char *in, char *out, const float *scale, int src_zp, int dst_zp, ++ int32_t *compensation_scratch) const { ++ const tr::prb_t &prb = pd()->prb_; ++ const tr::node_t *ns = prb.nodes + off; + for_nd(ithr, nthr, (ptrdiff_t)ns[3].n, (ptrdiff_t)ns[2].n, + (ptrdiff_t)ns[1].n, (ptrdiff_t)ns[0].n, + [&](ptrdiff_t d3, ptrdiff_t d2, ptrdiff_t d1, ptrdiff_t d0) { +- auto c = tr::call_param_t(); +- c.in = in ++ tr::call_param_t base_params; ++ base_params.in = in + + (d0 * ns[0].is + d1 * ns[1].is + d2 * ns[2].is + + d3 * ns[3].is) +- * data_type_size(pd()->prb_.itype); +- c.out = out ++ * data_type_size(prb.itype); ++ base_params.out = out + + (d0 * ns[0].os + d1 * ns[1].os + d2 * ns[2].os + + d3 * ns[3].os) +- * data_type_size(pd()->prb_.otype); +- c.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss + d2 * ns[2].ss +- + d3 * ns[3].ss; +- c.blk_chunks = utils::pick(blk_idx_off, d0, d1, d2, d3); +- (*kernel_)(&c); ++ * data_type_size(prb.otype); ++ base_params.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss ++ + d2 * ns[2].ss + d3 * ns[3].ss; ++ base_params.src_zp = src_zp; ++ base_params.dst_zp = dst_zp; ++ base_params.compensation_scratch = compensation_scratch ++ + d0 * ns[0].cs + d1 * ns[1].cs + d2 * ns[2].cs ++ + d3 * ns[3].cs; ++ ++ if (prb.is_tail_present) { ++ tr::tail_call_param_t tail_params; ++ tail_params.base_params = base_params; ++ ++ static constexpr int omp_ndims = 4; ++ const ptrdiff_t omp_data_chunks[omp_ndims] ++ = {d0, d1, d2, d3}; ++ fill_curr_data_chunks( ++ prb, off, omp_data_chunks, omp_ndims, tail_params); ++ (*kernel_)(&tail_params); ++ } else { ++ (*kernel_)(&base_params); ++ } + }); + } + +-void jit_uni_reorder_t::omp_driver( +- const char *in, char *out, const float *scale) const { ++void jit_uni_reorder_t::omp_driver(const char *in, char *out, ++ const float *scale, int src_zp, int dst_zp, ++ const memory_tracking::grantor_t &scratchpad) const { + in += pd()->prb_.ioff * data_type_size(pd()->prb_.itype); + out += pd()->prb_.ooff * data_type_size(pd()->prb_.otype); + +@@ -1770,29 +2748,153 @@ void jit_uni_reorder_t::omp_driver( + + int ndims = pd()->prb_.ndims; + int ndims_ker = pd()->ker_desc_.prb.ndims; ++ const bool req_s8s8_comp = pd()->prb_.req_s8s8_comp; ++ const bool req_asymmetric_comp = pd()->prb_.req_asymmetric_comp; ++ const bool req_compensation = req_s8s8_comp || req_asymmetric_comp; + assert(ndims - ndims_ker <= ndims_driver_max); + ++ int32_t *compensation_reduce_scratch = scratchpad.template get( ++ memory_tracking::names::key_reorder_space); ++ ++ const memory_desc_wrapper od(pd()->dst_md()); ++ const auto G = pd()->with_groups_ ? od.padded_dims()[0] : 1; ++ const auto N = od.padded_dims()[pd()->with_groups_ ? 1 : 0]; ++ static constexpr int cache_line_size = 16; ++ const auto wspace_per_thr_size = utils::rnd_up(G * N, cache_line_size); ++ const auto wspace_per_thr_bytes = wspace_per_thr_size * sizeof(int32_t); ++ + if (ndims - ndims_ker == 0) { +- omp_driver_0d(ndims_ker, in, out, scale); ++ if (req_compensation) ++ std::memset(compensation_reduce_scratch, 0, wspace_per_thr_bytes); ++ ++ omp_driver_0d(ndims_ker, in, out, scale, src_zp, dst_zp, ++ compensation_reduce_scratch); + } else { + parallel(pd()->nthr_, [&](const int ithr, const int nthr) { ++ int32_t *compensation_scratch = nullptr; ++ if (req_compensation) { ++ compensation_scratch = &compensation_reduce_scratch[ithr ++ * wspace_per_thr_size]; ++ std::memset(compensation_scratch, 0, wspace_per_thr_bytes); ++ } ++ + switch (ndims - ndims_ker) { + case 1: +- omp_driver_1d(ithr, nthr, ndims_ker, in, out, scale); ++ omp_driver_1d(ithr, nthr, ndims_ker, in, out, scale, src_zp, ++ dst_zp, compensation_scratch); + break; + case 2: +- omp_driver_2d(ithr, nthr, ndims_ker, in, out, scale); ++ omp_driver_2d(ithr, nthr, ndims_ker, in, out, scale, src_zp, ++ dst_zp, compensation_scratch); + break; + case 3: +- omp_driver_3d(ithr, nthr, ndims_ker, in, out, scale); ++ omp_driver_3d(ithr, nthr, ndims_ker, in, out, scale, src_zp, ++ dst_zp, compensation_scratch); + break; + case 4: +- omp_driver_4d(ithr, nthr, ndims_ker, in, out, scale); ++ omp_driver_4d(ithr, nthr, ndims_ker, in, out, scale, src_zp, ++ dst_zp, compensation_scratch); + break; + default: assert(!"unimplemented"); + } + }); + } ++ ++ //reduction of intermediate compensation results to the final output ++ if (req_compensation) { ++ const int nthr = ndims - ndims_ker == 0 ? 1 : pd()->nthr_; ++ reduce_compensation( ++ out, compensation_reduce_scratch, nthr, wspace_per_thr_size); ++ } ++} ++ ++void jit_uni_reorder_t::reduce_compensation(char *out, ++ const int32_t *compensation_reduce_scratch, const int nthr, ++ const dim_t wspace_per_thr_size) const { ++ ++ const memory_desc_wrapper od(pd()->dst_md()); ++ const size_t offset = od.size() - od.additional_buffer_size(); ++ ++ static constexpr auto comp_dt_size = sizeof(int32_t); ++ static constexpr int32_t comp_s8s8_shift = 128; ++ ++ // Note: We do not need to explicitly zero-out compensation buffer, as the ++ // per_thread buffers are already zeroed out in the padded area. ++ const auto G = pd()->with_groups_ ? od.padded_dims()[0] : 1; ++ const auto N = od.padded_dims()[pd()->with_groups_ ? 1 : 0]; ++ const auto GN = G * N; ++ const bool req_s8s8_comp = pd()->prb_.req_s8s8_comp; ++ const bool req_asymmetric_comp = pd()->prb_.req_asymmetric_comp; ++ const size_t zp_offset ++ = offset + (pd()->prb_.req_s8s8_comp ? GN * comp_dt_size : 0); ++ ++ parallel_nd(GN, [&](int idx) { ++ int32_t acc = 0; ++ for (int ithr = 0; ithr < nthr; ithr++) { ++ acc -= compensation_reduce_scratch[ithr * wspace_per_thr_size ++ + idx]; ++ } ++ if (req_s8s8_comp) { ++ int32_t *out_comp = reinterpret_cast(&out[offset]); ++ out_comp[idx] = comp_s8s8_shift * acc; ++ } ++ if (req_asymmetric_comp) { ++ int32_t *out_asym_comp ++ = reinterpret_cast(&out[zp_offset]); ++ out_asym_comp[idx] = acc; ++ } ++ }); ++} ++ ++void jit_uni_reorder_t::fill_curr_data_chunks(const tr::prb_t &prb, ++ const int off, const ptrdiff_t *omp_data_chunks, const int omp_ndims, ++ tr::tail_call_param_t &c) const { ++ // Chunks are backwards numered i.e: ++ // [0] -> [node_size] ++ // [1] -> [node_size - 1] ++ // ... ++ // [node_size - 1] -> [1] ++ ++ // It is done like this, because it is easier to decrement counter ++ // and check if it is equal to zero than increment and check ++ // if it is equal to node_size in jit kernel. ++ ++ static constexpr int64_t empty_chunk_info = -1; ++ static constexpr int64_t last_chunk = 1; ++ ++ for (int curr_node_id = prb.ndims - 1; curr_node_id >= 0; curr_node_id--) { ++ const int parent_node_id = prb.nodes[curr_node_id].parent_node_id; ++ const bool is_drv_processing_this_node ++ = curr_node_id >= off && curr_node_id <= off + omp_ndims - 1; ++ const bool is_tail_processing ++ = prb.is_tail_in_one_of_child_nodes(curr_node_id) ++ || prb.nodes[curr_node_id].tail_size > 0; ++ ++ if (is_drv_processing_this_node && is_tail_processing) { ++ const int inner_idx = curr_node_id - off; ++ assert(inner_idx < omp_ndims); ++ const int64_t node_size = prb.nodes[curr_node_id].tail_size > 0 ++ ? prb.nodes[curr_node_id].tail_size ++ : prb.nodes[curr_node_id].n; ++ const int64_t data_chunk = node_size - omp_data_chunks[inner_idx]; ++ ++ if (!prb.nodes[curr_node_id].is_parent_empty()) { ++ const bool is_parent_chunk_last ++ = c.curr_data_chunks[parent_node_id] == last_chunk; ++ c.curr_data_chunks[curr_node_id] ++ = is_parent_chunk_last ? data_chunk : empty_chunk_info; ++ c.zeroing_data = static_cast( ++ is_parent_chunk_last && data_chunk <= 0); ++ } else { ++ c.curr_data_chunks[curr_node_id] = data_chunk; ++ c.zeroing_data = static_cast(data_chunk <= 0); ++ } ++ c.skip_kernel_execution = static_cast(c.zeroing_data ++ && !prb.nodes[curr_node_id].is_zero_pad_needed); ++ if (c.zeroing_data || c.skip_kernel_execution) break; ++ } else ++ c.curr_data_chunks[curr_node_id] = empty_chunk_info; ++ } + } + + status_t jit_uni_reorder_t::init(engine_t *engine) { +@@ -1801,13 +2903,98 @@ status_t jit_uni_reorder_t::init(engine_t *engine) { + } + + status_t jit_uni_reorder_t::execute(const exec_ctx_t &ctx) const { +- status_t status = status::success; + auto in = CTX_IN_MEM(const char *, DNNL_ARG_FROM); +- auto out = CTX_OUT_CLEAN_MEM(char *, DNNL_ARG_TO, status); +- CHECK(status); ++ auto out = CTX_OUT_MEM(char *, DNNL_ARG_TO); + DEFINE_SCALES_BUFFER(scales); ++ DEFINE_ZERO_POINT_VALUE(src_zp, DNNL_ARG_FROM); ++ DEFINE_ZERO_POINT_VALUE(dst_zp, DNNL_ARG_TO); ++ const auto &scratchpad = ctx.get_scratchpad_grantor(); ++ ++ omp_driver(in, out, scales, src_zp, dst_zp, scratchpad); ++ ++ return status::success; ++} ++ ++status_t jit_blk_reorder_t::pd_t::create(reorder_pd_t **reorder_pd, ++ engine_t *engine, const primitive_attr_t *attr, engine_t *src_engine, ++ const memory_desc_t *src_md, engine_t *dst_engine, ++ const memory_desc_t *dst_md) { ++ auto prb = tr::prb_t(); ++ ++ status_t prb_init_status = prb_init(prb, *src_md, *dst_md, attr); ++ if (prb_init_status != status::success) return prb_init_status; ++ // only uni_reorder supports tail processing now ++ // TODO: Add tail processing support in blk_reorder ++ if (prb.is_tail_present) return status::unimplemented; ++ ++ prb_tile_normalize(prb); ++ DEBUG({ ++ printf("tile : "); ++ prb_dump(prb); ++ }); ++ ++ if (!tr::jit_single_blk_kernel_t::applicable(prb)) { ++ return status::unimplemented; ++ } + +- omp_driver(in, out, scales); ++ auto _pd = new pd_t( ++ attr, src_engine->kind(), src_md, dst_engine->kind(), dst_md); ++ if (_pd == nullptr) return status::out_of_memory; ++ _pd->prb_ = prb; ++ if (_pd->init(engine, src_engine, dst_engine) != status::success) { ++ delete _pd; ++ return status::unimplemented; ++ } ++ _pd->init_scratchpad_md(); ++ ++ return safe_ptr_assign(*reorder_pd, _pd); ++} ++ ++void jit_blk_reorder_t::pd_t::prb_tile_normalize(tr::prb_t &p) { ++ if (!utils::one_of(p.nodes[0].n, 8ul, 16ul) ++ && utils::one_of(p.nodes[1].n, 8ul, 16ul)) { ++ nstl::swap(p.nodes[0], p.nodes[1]); ++ } ++} ++ ++jit_blk_reorder_t::jit_blk_reorder_t(const pd_t *apd) : primitive_t(apd) {} ++jit_blk_reorder_t::~jit_blk_reorder_t() = default; ++ ++status_t jit_blk_reorder_t::init(engine_t *engine) { ++ kernel_ = utils::make_unique(pd()->prb_); ++ return kernel_->create_kernel(); ++} ++ ++status_t jit_blk_reorder_t::execute(const exec_ctx_t &ctx) const { ++ const auto in = CTX_IN_MEM(const char *, DNNL_ARG_FROM); ++ auto out = CTX_OUT_MEM(char *, DNNL_ARG_TO); ++ DEFINE_ZERO_POINT_VALUE(src_zp, DNNL_ARG_FROM); ++ DEFINE_ZERO_POINT_VALUE(dst_zp, DNNL_ARG_TO); ++ ++ // kernel handle 2-dimension tiles, a tail is possible ++ auto &prb = this->pd()->prb_; ++ ptrdiff_t BH = 1; ++ for (int i = 2; i < prb.ndims; ++i) { ++ BH *= prb.nodes[i].n; ++ } ++ ++ auto block_sz = prb.n(0); ++ auto n1 = prb.n(1); ++ auto i1 = prb.is(1); ++ auto o1 = prb.os(1); ++ auto FL = (n1 + block_sz - 1) / block_sz; ++ auto bh_stride = BH == 1 ? 0 : prb.is(2); ++ ++ auto itype_sz_ = data_type_size(pd()->prb_.itype); ++ auto otype_sz_ = data_type_size(pd()->prb_.otype); ++ ++ parallel_nd(BH, FL, [&](dim_t bh, dim_t fl) { ++ auto fl_b = fl * block_sz; ++ auto bh_b = bh_stride * bh; ++ auto *i = in + (bh_b + fl_b * i1) * itype_sz_; ++ auto *o = out + (bh_b + fl_b * o1) * otype_sz_; ++ (*kernel_)(i, o, n1 - fl_b < block_sz, src_zp, dst_zp); ++ }); + + return status::success; + } +diff --git a/src/cpu/aarch64/jit_uni_reorder.hpp b/src/cpu/aarch64/jit_uni_reorder.hpp +index 2fb6f0f89f3..bf400430ba5 100644 +--- a/src/cpu/aarch64/jit_uni_reorder.hpp ++++ b/src/cpu/aarch64/jit_uni_reorder.hpp +@@ -1,6 +1,6 @@ + /******************************************************************************* +-* Copyright 2018-2020 Intel Corporation +-* Copyright 2020 FUJITSU LIMITED ++* Copyright 2018-2022 Intel Corporation ++* Copyright 2020-2022 FUJITSU LIMITED + * Copyright 2022 Arm Ltd. and affiliates + * + * Licensed under the Apache License, Version 2.0 (the "License"); +@@ -36,15 +36,76 @@ namespace tr { + constexpr int max_ndims = DNNL_MAX_NDIMS; + + struct node_t { +- size_t n; +- ptrdiff_t is; // input stride +- ptrdiff_t os; // output stride +- ptrdiff_t ss; // scale stride ++ static constexpr int64_t empty_field = -1; ++ ++ size_t n = 0; ++ size_t tail_size = 0; ++ int dim_id = empty_field; ++ int parent_node_id = empty_field; ++ bool is_zero_pad_needed = false; ++ ptrdiff_t is = 0; // input stride ++ ptrdiff_t os = 0; // output stride ++ ptrdiff_t ss = 0; // scale stride ++ ptrdiff_t cs = 0; // compensation stride ++ ++ bool is_dim_id_empty() const { return dim_id == empty_field; } ++ bool is_parent_empty() const { return parent_node_id == empty_field; } + }; + + enum class scale_type_t { NONE, COMMON, MANY }; + + struct prb_t { ++ /* The compensation mask value indicates how big an additional buffer should be. ++ * Possible values for reorder: ++ * 1) standard compensation = 1 = 0b01 ++ * 2) asymmetric compensation = 2 = 0b10 ++ * 3) compensation if tensor contains group = 3 = 0b11 */ ++ static constexpr int invalid_comp_mask = 0; ++ static constexpr int standard_comp_mask = 0b1; ++ static constexpr int asymmetric_comp_mask = 0b10; ++ static constexpr int comp_mask_with_groups ++ = standard_comp_mask + asymmetric_comp_mask; ++ ++ bool is_tail_in_one_of_child_nodes(int parent_node_id) const { ++ for (int i = parent_node_id; i >= 0; i--) { ++ if (nodes[i].parent_node_id == parent_node_id) { ++ if (nodes[i].tail_size != 0) ++ return true; ++ else ++ parent_node_id = i; ++ } ++ } ++ ++ return false; ++ } ++ ++ int tail(int d) const { ++ assert(d < ndims); ++ return static_cast(nodes[d].tail_size); ++ } ++ ++ int n(int d) const { ++ assert(d < ndims); ++ return static_cast(nodes[d].n); ++ } ++ int is(int d) const { ++ assert(d < ndims); ++ return static_cast(nodes[d].is); ++ } ++ int os(int d) const { ++ assert(d < ndims); ++ return static_cast(nodes[d].os); ++ } ++ int ss(int d) const { ++ assert(d < ndims); ++ return static_cast(nodes[d].ss); ++ } ++ ++ int cs(int d) const { ++ assert(d < ndims); ++ return static_cast(nodes[d].cs); ++ } ++ + data_type_t itype; + data_type_t otype; + int ndims; +@@ -54,21 +115,24 @@ struct prb_t { + scale_type_t scale_type; + float beta; + int full_ndims; +- int ip_tail; +- int op_tail; +- int iblock; +- int oblock; +- int blk_chunk_idx; ++ bool is_tail_present = false; ++ float scale_adjust = 1.f; ++ int compensation_mask = invalid_comp_mask; ++ bool req_s8s8_comp = false; ++ bool req_asymmetric_comp = false; ++ bool req_src_zp = false; ++ bool req_dst_zp = false; + }; + + status_t prb_init(prb_t &prb, const memory_desc_t &imd, + const memory_desc_t &omd, const primitive_attr_t *attr); + +-status_t prb_check_blk(prb_t &prb, const memory_desc_t &imd); +- + /** sorts the problem nodes so that output strides come in ascending order */ + void prb_normalize(prb_t &p); + ++/** fill parent node info for blocked nodes */ ++void prb_node_dependency(prb_t &p); ++ + /** folds nodes together if possible */ + void prb_simplify(prb_t &p); + +@@ -88,10 +152,24 @@ void prb_node_move(prb_t &p, int d0, int d1); + void prb_dump(const prb_t &p); + + struct call_param_t { +- const void *in; +- void *out; +- const float *scale; +- size_t blk_chunks; ++ const void *in = nullptr; ++ void *out = nullptr; ++ const float *scale = nullptr; ++ int32_t src_zp = 0; ++ int32_t dst_zp = 0; ++ int32_t *compensation_scratch = nullptr; ++}; ++ ++// The additional structure is needed because ++// using a data structure with tail processing ++// data for non-tail cases reduces kernel ++// performance. This is because there is too ++// much data that has to be transferred to the kernel. ++struct tail_call_param_t { ++ call_param_t base_params; ++ int64_t curr_data_chunks[DNNL_MAX_NDIMS] = {-1}; ++ int64_t zeroing_data = static_cast(false); ++ int64_t skip_kernel_execution = static_cast(false); + }; + + struct kernel_t { +@@ -100,8 +178,12 @@ struct kernel_t { + prb_t prb; + }; + +- kernel_t(const desc_t &desc) : desc_(desc) {} ++ kernel_t(const desc_t &desc) ++ : desc_(desc) ++ , compensation_needed_( ++ desc.prb.req_s8s8_comp || desc.prb.req_asymmetric_comp) {} + virtual void operator()(const call_param_t *c) const = 0; ++ virtual void operator()(const tail_call_param_t *c) const = 0; + virtual status_t create_kernel() = 0; + virtual ~kernel_t() {} + +@@ -119,10 +201,13 @@ struct kernel_t { + protected: + const desc_t desc_; + const prb_t &prb_ = desc_.prb; ++ bool compensation_needed_ = false; + }; + + /* TODO: add trans_t class */ + ++struct jit_single_blk_kernel_t; ++ + } // namespace tr + + struct jit_uni_reorder_t : public primitive_t { +@@ -135,8 +220,13 @@ struct jit_uni_reorder_t : public primitive_t { + tr::prb_t prb_; + tr::kernel_t::desc_t ker_desc_; + int nthr_; ++ bool with_groups_ = false; ++ ++ status_t init( ++ engine_t *engine, engine_t *src_engine, engine_t *dst_engine); + + private: ++ void init_scratchpad(); + static status_t create(reorder_pd_t **reorder_pd, engine_t *engine, + const primitive_attr_t *attr, engine_t *src_engine, + const memory_desc_t *src_md, engine_t *dst_engine, +@@ -151,23 +241,66 @@ struct jit_uni_reorder_t : public primitive_t { + enum { ndims_driver_max = 4 }; + + private: +- void omp_driver_0d( +- int off, const char *in, char *out, const float *scale) const; ++ void omp_driver_0d(int off, const char *in, char *out, const float *scale, ++ int src_zp, int dst_zp, int32_t *compensation_scratch) const; + void omp_driver_1d(int ithr, int nthr, int off, const char *in, char *out, +- const float *scale) const; ++ const float *scale, int src_zp, int dst_zp, ++ int32_t *compensation_scratch) const; + void omp_driver_2d(int ithr, int nthr, int off, const char *in, char *out, +- const float *scale) const; ++ const float *scale, int src_zp, int dst_zp, ++ int32_t *compensation_scratch) const; + void omp_driver_3d(int ithr, int nthr, int off, const char *in, char *out, +- const float *scale) const; ++ const float *scale, int src_zp, int dst_zp, ++ int32_t *compensation_scratch) const; + void omp_driver_4d(int ithr, int nthr, int off, const char *in, char *out, +- const float *scale) const; ++ const float *scale, int src_zp, int dst_zp, ++ int32_t *compensation_scratch) const; ++ ++ void omp_driver(const char *in, char *out, const float *scale, int src_zp, ++ int dst_zp, const memory_tracking::grantor_t &scratchpad) const; + +- void omp_driver(const char *in, char *out, const float *scale) const; ++ void fill_curr_data_chunks(const tr::prb_t &prb, const int off, ++ const ptrdiff_t *omp_data_chunks, const int omp_ndims, ++ tr::tail_call_param_t &c) const; ++ ++ void reduce_compensation(char *out, ++ const int32_t *compensation_reduce_scratch, const int nthr, ++ const dim_t wspace_per_thr_size) const; + + const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } + std::unique_ptr kernel_; + }; + ++struct jit_blk_reorder_t : public primitive_t { ++ using primitive_t::primitive_t; ++ struct pd_t : public cpu_reorder_pd_t { ++ using cpu_reorder_pd_t::cpu_reorder_pd_t; ++ DECLARE_COMMON_PD_T("jit:blk", jit_blk_reorder_t); ++ ++ tr::prb_t prb_; ++ ++ private: ++ static status_t create(reorder_pd_t **reorder_pd, engine_t *engine, ++ const primitive_attr_t *attr, engine_t *src_engine, ++ const memory_desc_t *src_md, engine_t *dst_engine, ++ const memory_desc_t *dst_md); ++ ++ // Swap last two nodes, put block 4, 8, 16 nodes to first ++ static void prb_tile_normalize(tr::prb_t &p); ++ friend dnnl::impl::impl_list_item_t; ++ }; ++ ++ status_t init(engine_t *engine) override; ++ status_t execute(const exec_ctx_t &ctx) const override; ++ ++ jit_blk_reorder_t(const pd_t *apd); ++ ~jit_blk_reorder_t(); ++ ++private: ++ const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } ++ std::unique_ptr kernel_; ++}; ++ + } // namespace aarch64 + } // namespace cpu + } // namespace impl +diff --git a/src/cpu/aarch64/jit_uni_reorder_utils.cpp b/src/cpu/aarch64/jit_uni_reorder_utils.cpp +index 7123811f827..28f36a7e2e7 100644 +--- a/src/cpu/aarch64/jit_uni_reorder_utils.cpp ++++ b/src/cpu/aarch64/jit_uni_reorder_utils.cpp +@@ -1,6 +1,6 @@ + /******************************************************************************* +-* Copyright 2018-2021 Intel Corporation +-* Copyright 2020 FUJITSU LIMITED ++* Copyright 2018-2022 Intel Corporation ++* Copyright 2020-2022 FUJITSU LIMITED + * Copyright 2022 Arm Ltd. and affiliates + * + * Licensed under the Apache License, Version 2.0 (the "License"); +@@ -25,10 +25,21 @@ + #include "common/nstl.hpp" + #include "common/type_helpers.hpp" + #include "common/utils.hpp" +-#include "dnnl_debug.h" ++#include "oneapi/dnnl/dnnl_debug.h" + + #include "cpu/aarch64/jit_uni_reorder.hpp" + ++// #define TR_DEBUG ++#if defined(TR_DEBUG) ++#define DEBUg(...) \ ++ do { \ ++ __VA_ARGS__ \ ++ } while (0) ++#else ++#define DEBUg(...) ++#endif ++#define DEBUG(...) DEBUg(__VA_ARGS__) ++ + using namespace dnnl::impl::types; + using namespace dnnl::impl::status; + +@@ -41,87 +52,45 @@ namespace tr { + + /** ad-hoc structure to describe blocked memory layout */ + struct layout_desc_t { ++ layout_desc_t() ++ : dt(dnnl_data_type_undef) ++ , ndims(0) ++ , id {-1} ++ , dims {0} ++ , tails {0} ++ , is_blk {false} ++ , strides {0} {} + data_type_t dt; + int ndims; + dims_t id; + dims_t dims; ++ dims_t tails; ++ bool is_blk[DNNL_MAX_NDIMS]; + strides_t strides; + }; + +-static status_t compute_blk_and_tail( +- const memory_desc_t &md_, const int idx, int &blk, int &tail) { +- const auto md = memory_desc_wrapper(md_); +- const auto &bd = md.blocking_desc(); +- if (tail == 0) return status::success; +- +- const std::set unique_inner_idxs( +- bd.inner_idxs, bd.inner_idxs + bd.inner_nblks); +- std::set dims_with_multiple_blks; +- for (dim_t dim : unique_inner_idxs) { +- if (std::count(bd.inner_idxs, bd.inner_idxs + bd.inner_nblks, dim) > 1) +- dims_with_multiple_blks.insert(dim); +- } +- +- // Dims that have a tail and have multiple blocks are not supported by the jit kernel yet. +- // For example: +- // src_tag = abcd +- // dst_tag = ABcd16b16a4b +- // 16x15x3x3 +- // In this case, 'b' dim has two blocks and has a tail. It is not a supported case. +- if (dims_with_multiple_blks.find(idx) != dims_with_multiple_blks.end()) +- return status::unimplemented; +- +- // Only supports inconsistent padding in single and double blocks +- // and the total block size <= 256 +- for (int iblk = bd.inner_nblks - 1; iblk > 0; --iblk) { +- if (bd.inner_idxs[iblk] == idx) break; +- blk *= bd.inner_blks[iblk]; +- tail *= bd.inner_blks[iblk]; +- } +- if (unique_inner_idxs.size() > 2 || blk > 256) return status::unimplemented; +- +- return status::success; +-} +- +-static status_t compute_chunk_idx(const prb_t &p, const memory_desc_t &imd_, +- const memory_desc_t &omd_, const int blk_idx, int &chunk_idx) { +- const auto imd = memory_desc_wrapper(imd_); +- const auto omd = memory_desc_wrapper(omd_); +- const auto &ibd = imd.blocking_desc(); +- const auto &obd = omd.blocking_desc(); +- if (p.ip_tail == 0 && p.op_tail == 0) return status::success; +- +- const ptrdiff_t is +- = ibd.strides[blk_idx] * obd.inner_blks[obd.inner_idxs[blk_idx]]; +- const ptrdiff_t os = obd.strides[blk_idx]; +- +- for (int i = blk_idx; i < omd.ndims(); ++i) { +- if (p.nodes[i].os == os && p.nodes[i].is == is) { +- chunk_idx = i; +- return status::success; +- } +- } +- +- return status::invalid_arguments; +-} +- + status_t cvt_mem_desc_to_layout_desc(const memory_desc_t &md_, +- layout_desc_t &ld, const dims_t &blocks, const dims_t &ext_padding) { ++ layout_desc_t &ld, const dims_t &blocks, const dims_t &external_padding, ++ const dims_t &tails) { ++ static constexpr bool it_is_blk = true; ++ + const auto md = memory_desc_wrapper(md_); + +- bool ok = true && md.is_blocking_desc() && md.extra().flags == 0; +- if (!ok) return invalid_arguments; ++ if (!md.is_blocking_desc()) return invalid_arguments; + + const auto &bd = md.blocking_desc(); + + ld.ndims = 0; + ld.dt = md.data_type(); + +- auto P = [&ld](int id, int dim, ptrdiff_t stride) { ++ auto add_dim = [&ld](int id, dim_t dim, dim_t tail, bool is_blk, ++ ptrdiff_t stride) { + assert((size_t)ld.ndims < sizeof(ld.dims) / sizeof(ld.dims[0])); + ld.id[ld.ndims] = id; + ld.dims[ld.ndims] = dim; + ld.strides[ld.ndims] = stride; ++ ld.tails[ld.ndims] = tail; ++ ld.is_blk[ld.ndims] = is_blk; + ++ld.ndims; + }; + +@@ -129,12 +98,27 @@ status_t cvt_mem_desc_to_layout_desc(const memory_desc_t &md_, + const int ld_ndims_start = ld.ndims; + if (blocks[d] != 1) { + stride_t stride = 1; ++ int tail = tails[d]; + for (int iblk = bd.inner_nblks - 1; iblk >= 0; --iblk) { +- if (bd.inner_idxs[iblk] == d) P(d, bd.inner_blks[iblk], stride); ++ if (bd.inner_idxs[iblk] == d) { ++ const dim_t inner_tail = tail % bd.inner_blks[iblk]; ++ add_dim(d, bd.inner_blks[iblk], inner_tail, it_is_blk, ++ stride); ++ tail = utils::div_up(tail, bd.inner_blks[iblk]); ++ } + stride *= bd.inner_blks[iblk]; + } + } +- P(d, (md.padded_dims()[d] + ext_padding[d]) / blocks[d], bd.strides[d]); ++ ++ const dim_t dim_with_external_padding ++ = (md.padded_dims()[d] + external_padding[d]) / blocks[d]; ++ const dim_t padded_dim = md.padded_dims()[d] / blocks[d]; ++ const dim_t tail = dim_with_external_padding != padded_dim ++ ? dim_with_external_padding ++ - (dim_with_external_padding - padded_dim) ++ : 0; ++ ++ add_dim(d, dim_with_external_padding, tail, !it_is_blk, bd.strides[d]); + + // TODO: NOW: revisit, do we need a reverse? + // TODO: NOW: consider using strides instead of block sizes in md +@@ -144,12 +128,70 @@ status_t cvt_mem_desc_to_layout_desc(const memory_desc_t &md_, + const int idx1 = ld.ndims - 1 - ld_d; + nstl::swap(ld.dims[idx0], ld.dims[idx1]); + nstl::swap(ld.strides[idx0], ld.strides[idx1]); ++ nstl::swap(ld.tails[idx0], ld.tails[idx1]); ++ nstl::swap(ld.is_blk[idx0], ld.is_blk[idx1]); + } + } + + return success; + } + ++static bool is_with_groups(const memory_desc_t &dst_md) { ++ using namespace memory_extra_flags; ++ auto dst_d = memory_desc_wrapper(dst_md); ++ const int grp_bit = 1 << 1; ++ auto check_flag_and_mask = [&](int flag, int mask) { ++ return (dst_d.extra().flags & flag) && (mask & grp_bit); ++ }; ++ ++ return check_flag_and_mask( ++ compensation_conv_s8s8, dst_d.extra().compensation_mask) ++ || check_flag_and_mask(compensation_conv_asymmetric_src, ++ dst_d.extra().asymm_compensation_mask); ++} ++ ++static inline int get_next_parent_node(node_t *nodes, int ndims, int cur_node) { ++ const int cur_id = nodes[cur_node].dim_id; ++ for (int d = cur_node + 1; d < ndims; ++d) { ++ if (nodes[d].dim_id == cur_id) return d; ++ } ++ return -1; ++} ++ ++static void prb_set_compensation_strides(prb_t &p) { ++ ++ auto require_n_stride = [&](int cur_node) -> bool { ++ const int parent = get_next_parent_node(p.nodes, p.ndims, cur_node); ++ if (parent < 0) return false; ++ ++ const size_t p_n = p.nodes[parent].n; ++ ++ // if 'parent_node.n' is larger than 1, then cur_node stride ++ // is 'cur_node.n' ++ return p_n > size_t(1); ++ }; ++ ++ const auto compensation_needed = p.req_s8s8_comp || p.req_asymmetric_comp; ++ if (!compensation_needed) return; ++ int mask = p.compensation_mask; ++ ptrdiff_t cs = 1; ++ for (int d = 0; d < p.ndims; ++d) { ++ if (mask & (1 << p.nodes[d].dim_id)) { ++ ++ // correct cases when 'cs' exceeds output stride ++ if (cs > p.nodes[d].os) cs = p.nodes[d].os; ++ ++ p.nodes[d].cs = cs; ++ const bool n_stride = require_n_stride(d); ++ if (p.nodes[d].tail_size > 0 && (!p.nodes[d].is_zero_pad_needed) ++ && (!n_stride)) ++ cs *= p.nodes[d].tail_size; ++ else ++ cs *= p.nodes[d].n; ++ } ++ } ++} ++ + status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd, + const primitive_attr_t *attr) { + auto im_d = memory_desc_wrapper(imd); +@@ -157,8 +199,7 @@ status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd, + + auto check_post_ops = [](const primitive_attr_t *attr) { + const auto &po = attr->post_ops_; +- return po.len() == 0 +- || (po.len() == 1 && po.contain(primitive_kind::sum, 0)); ++ return po.len() == 0 || (po.len() == 1 && po.entry_[0].is_sum(false)); + }; + + bool ok = im_d.is_blocking_desc() && om_d.is_blocking_desc() +@@ -166,81 +207,129 @@ status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd, + && !om_d.has_runtime_dims_or_strides() && !om_d.has_zero_dim() + && attr->has_default_values( + primitive_attr_t::skip_mask_t::oscale_runtime ++ | primitive_attr_t::skip_mask_t::zero_points_runtime + | primitive_attr_t::skip_mask_t::post_ops) + && check_post_ops(attr); + if (!ok) return unimplemented; + +- dims_t iblocks, oblocks, ip_padding, op_padding; ++ bool is_tail_present = false; ++ dims_t iblocks, oblocks, i_tails, o_tails, i_paddings, o_paddings; + im_d.compute_blocks(iblocks); + om_d.compute_blocks(oblocks); +- utils::array_set(ip_padding, 0, im_d.ndims()); +- utils::array_set(op_padding, 0, om_d.ndims()); +- +- /* padding_dim consistency check +- * only supports inconsitent padding for src +- * TODO: Add inconsistent padding support for dst */ +- int ip_tail = 0; +- int op_tail = 0; +- int iblk_w_tail = 1; +- int oblk_w_tail = 1; +- int blk_idx = 0; ++ ++ for (int d = 0; d < om_d.ndims(); ++d) { ++ const auto dim = om_d.dims()[d]; ++ const auto pdim = om_d.padded_dims()[d]; ++ const auto cblock = oblocks[d]; ++ // do not allow excess pdim other than required for rounding-up of dim. ++ if (utils::rnd_up(dim, cblock) != pdim) return unimplemented; ++ } ++ ++ utils::array_set(i_tails, 0, im_d.ndims()); ++ utils::array_set(o_tails, 0, om_d.ndims()); ++ utils::array_set(i_paddings, 0, im_d.ndims()); ++ utils::array_set(o_paddings, 0, om_d.ndims()); + + for (int d = 0; d < im_d.ndims(); ++d) { +- const int ip_tmp_dim = im_d.padded_dims()[d]; +- const int op_tmp_dim = om_d.padded_dims()[d]; +- const int ip_tmp_tail = ip_tmp_dim % oblocks[d]; +- const int op_tmp_tail = op_tmp_dim % iblocks[d]; +- +- const bool pdim_consistent = ip_tmp_dim == op_tmp_dim +- && ip_tmp_tail == 0 && op_tmp_tail == 0; +- const bool pdim_tail = ip_tmp_tail > 0 +- && (ip_tmp_dim + oblocks[d] - ip_tmp_tail) == op_tmp_dim +- && op_tmp_tail == 0 && ip_tail == 0; +- if (!pdim_consistent && !pdim_tail) return status::unimplemented; +- if (pdim_tail) { +- blk_idx = d; +- ip_tail = ip_tmp_tail; +- op_tail = op_tmp_tail; +- iblk_w_tail = iblocks[d]; +- oblk_w_tail = oblocks[d]; +- ip_padding[d] = oblocks[d] - ip_tmp_tail; +- op_padding[d] = iblocks[d] - op_tmp_tail; ++ const dim_t i_dim = im_d.dims()[d]; ++ const dim_t o_dim = om_d.dims()[d]; ++ const dim_t i_tail = i_dim % iblocks[d]; ++ const dim_t o_tail = o_dim % oblocks[d]; ++ ++ if (o_tail > 0) { ++ is_tail_present = true; ++ o_tails[d] = o_tail; ++ o_paddings[d] = oblocks[d] - o_tail; ++ } ++ ++ if (i_tail > 0) { ++ is_tail_present = true; ++ i_tails[d] = i_tail; ++ i_paddings[d] = iblocks[d] - i_tail; + } + } +- CHECK(compute_blk_and_tail(omd, blk_idx, oblk_w_tail, ip_tail)); + ++ // To compute input layout description we need to pass output paddings ++ // which will be used to compute input dims rounded up to multiple of ++ // output dims. Analogous applies to output layout description. ++ // This is demanded by the algorithm of nodes creation. ++ // Example: ++ // input: ++ // format: abc ++ // size: 77, 15, 3 ++ // o_padding: 3, 17, 0 ++ // returns ild: 80, 32, 3 ++ // output: ++ // format: ABc16b16a2b ++ // size: 77, 15, 3 ++ // i_padding: 0, 0, 0 ++ // returns old: 5, 16, 1, 16, 2, 3 + layout_desc_t ild, old; +- status_t status +- = cvt_mem_desc_to_layout_desc(imd, ild, iblocks, ip_padding); +- if (status != success) return status; +- status = cvt_mem_desc_to_layout_desc(omd, old, oblocks, op_padding); +- if (status != success) return status; ++ CHECK(cvt_mem_desc_to_layout_desc(imd, ild, iblocks, o_paddings, i_tails)); ++ CHECK(cvt_mem_desc_to_layout_desc(omd, old, oblocks, i_paddings, o_tails)); + + p.itype = ild.dt; + p.otype = old.dt; +- p.ip_tail = ip_tail; +- p.op_tail = op_tail; +- p.iblock = iblk_w_tail; +- p.oblock = oblk_w_tail; +- ++ p.is_tail_present = is_tail_present; ++ p.req_src_zp = !attr->zero_points_.has_default_values(DNNL_ARG_SRC); ++ p.req_dst_zp = !attr->zero_points_.has_default_values(DNNL_ARG_DST); + p.scale_type = attr->output_scales_.has_default_values() + ? scale_type_t::NONE + : (attr->output_scales_.mask_ == 0 ? scale_type_t::COMMON + : scale_type_t::MANY); ++ p.scale_adjust = (om_d.extra().flags & memory_extra_flags::scale_adjust) ++ ? om_d.extra().scale_adjust ++ : 1.f; ++ p.req_s8s8_comp ++ = om_d.extra().flags & memory_extra_flags::compensation_conv_s8s8; ++ p.req_asymmetric_comp = om_d.extra().flags ++ & memory_extra_flags::compensation_conv_asymmetric_src; ++ ++ const bool with_groups = is_with_groups(omd); ++ ++ auto mask_ok = [&](bool check, int mask) { ++ return IMPLICATION(check, mask == (with_groups ? 0x3 : 0x1)); ++ }; ++ ++ if (!mask_ok(p.req_s8s8_comp, om_d.extra().compensation_mask) ++ || !mask_ok(p.req_asymmetric_comp, ++ om_d.extra().asymm_compensation_mask)) ++ return status::unimplemented; + +- ptrdiff_t ss[max_ndims] = {0}; ++ ptrdiff_t ss[max_ndims] = {0}; // scales strides + if (p.scale_type == scale_type_t::MANY) { +- ptrdiff_t last_ss = 1; ++ const int mask = attr->output_scales_.mask_; ++ ptrdiff_t dense_stride = 1; ++ ptrdiff_t last_stride = 1; + for (int d = old.ndims - 1; d >= 0; --d) { + assert((d == 0 || old.id[d - 1] <= old.id[d]) + && "logical dimensions should be in ascending order"); +- if (attr->output_scales_.mask_ & (1 << old.id[d])) { +- ss[d] = last_ss; +- last_ss *= old.dims[d]; ++ if (mask & (1 << old.id[d])) { ++ if ((d + 1) < old.ndims && old.id[d + 1] != old.id[d] ++ && (mask & (1 << old.id[d + 1]))) { ++ dense_stride = dense_stride * imd.dims[old.id[d + 1]]; ++ last_stride = dense_stride; ++ } ++ ss[d] = last_stride; ++ last_stride *= old.dims[d]; + } + } + } + ++ const auto compensation_needed = p.req_s8s8_comp || p.req_asymmetric_comp; ++ if (compensation_needed) { ++ p.compensation_mask = p.req_s8s8_comp ++ ? om_d.extra().compensation_mask ++ : (p.req_asymmetric_comp ? om_d.extra().asymm_compensation_mask ++ : tr::prb_t::invalid_comp_mask); ++ ++ if (p.compensation_mask == tr::prb_t::asymmetric_comp_mask) ++ return unimplemented; ++ ++ assert(p.compensation_mask == tr::prb_t::standard_comp_mask ++ || p.compensation_mask == tr::prb_t::comp_mask_with_groups); ++ } ++ + int ndims = 0; + + int i_pos = 0; /* state for input -- current dimension */ +@@ -254,6 +343,10 @@ status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd, + + if (ild.dims[i_pos] == old.dims[o_pos]) { + p.nodes[ndims].n = ild.dims[i_pos]; ++ p.nodes[ndims].dim_id = old.id[o_pos]; ++ p.nodes[ndims].tail_size = old.tails[o_pos]; ++ p.nodes[ndims].is_zero_pad_needed ++ = old.is_blk[o_pos] && old.tails[o_pos] > 0; + p.nodes[ndims].is = ild.strides[i_pos]; + p.nodes[ndims].os = old.strides[o_pos]; + p.nodes[ndims].ss = ss[o_pos]; +@@ -261,19 +354,45 @@ status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd, + ++i_pos; + ++o_pos; + } else if (ild.dims[i_pos] < old.dims[o_pos]) { +- assert(old.dims[o_pos] % ild.dims[i_pos] == 0); +- int factor = old.dims[o_pos] / ild.dims[i_pos]; ++ // old must be divisible by ild or we will not be ++ // able to create valid nodes. The problem appears ++ // when stag=Acdb48a and dtag=Acdb32a for example. ++ if (ild.dims[i_pos] == 0 || old.dims[o_pos] % ild.dims[i_pos] != 0) ++ return status::unimplemented; ++ ++ dim_t factor = old.dims[o_pos] / ild.dims[i_pos]; ++ ++ const size_t tail_of_upper_dim ++ = utils::div_up(old.tails[o_pos], factor) == ild.dims[i_pos] ++ ? 0 ++ : utils::div_up(old.tails[o_pos], factor); ++ const size_t tail_of_lower_dim = old.tails[o_pos] % factor; ++ + p.nodes[ndims].n = ild.dims[i_pos]; ++ p.nodes[ndims].dim_id = old.id[o_pos]; ++ p.nodes[ndims].tail_size = tail_of_upper_dim; ++ p.nodes[ndims].is_zero_pad_needed ++ = old.is_blk[o_pos] && tail_of_upper_dim > 0; + p.nodes[ndims].is = ild.strides[i_pos]; + p.nodes[ndims].os = old.strides[o_pos] * factor; + p.nodes[ndims].ss = ss[o_pos] * factor; + ++ndims; + ++i_pos; + old.dims[o_pos] = factor; ++ old.tails[o_pos] = tail_of_lower_dim; + } else if (ild.dims[i_pos] > old.dims[o_pos]) { +- assert(ild.dims[i_pos] % old.dims[o_pos] == 0); +- int factor = ild.dims[i_pos] / old.dims[o_pos]; ++ // ild must be divisible by old or we will not be ++ // able to create valid nodes. The problem appears ++ // when stag=Acdb32a and dtag=Acdb48a for example. ++ if (old.dims[o_pos] == 0 || ild.dims[i_pos] % old.dims[o_pos] != 0) ++ return status::unimplemented; ++ ++ dim_t factor = ild.dims[i_pos] / old.dims[o_pos]; + p.nodes[ndims].n = old.dims[o_pos]; ++ p.nodes[ndims].dim_id = old.id[o_pos]; ++ p.nodes[ndims].tail_size = old.tails[o_pos]; ++ p.nodes[ndims].is_zero_pad_needed ++ = old.is_blk[o_pos] && old.tails[o_pos] > 0; + p.nodes[ndims].is = ild.strides[i_pos] * factor; + p.nodes[ndims].os = old.strides[o_pos]; + p.nodes[ndims].ss = ss[o_pos]; +@@ -282,12 +401,9 @@ status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd, + ild.dims[i_pos] = factor; + } + } +- int blk_chunk_idx = ndims; +- CHECK(compute_chunk_idx(p, imd, omd, blk_idx, blk_chunk_idx)); + + p.ndims = ndims; + p.full_ndims = ndims; +- p.blk_chunk_idx = blk_chunk_idx; + + p.ioff = memory_desc_wrapper(imd).offset0(); + p.ooff = memory_desc_wrapper(omd).offset0(); +@@ -295,6 +411,28 @@ status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd, + const int sum_idx = attr->post_ops_.find(primitive_kind::sum); + p.beta = sum_idx == -1 ? 0.f : attr->post_ops_.entry_[sum_idx].sum.scale; + ++ DEBUG({ ++ printf("init : "); ++ prb_dump(prb); ++ }); ++ // Sort the prb array in increasing sizes of the output stride ++ prb_normalize(p); ++ DEBUG({ ++ printf("norm : "); ++ prb_dump(prb); ++ }); ++ ++ // compensation strides require prb_normalized ++ prb_set_compensation_strides(p); ++ ++ /* Combine the variables, which appear together on both ++ * sides of the reorder */ ++ prb_simplify(p); ++ DEBUG({ ++ printf("smpl : "); ++ prb_dump(prb); ++ }); ++ + return success; + } + +@@ -307,28 +445,23 @@ void prb_normalize(prb_t &p) { + && p.nodes[j].n < p.nodes[min_pos].n); + if (new_min) min_pos = j; + } +- if (min_pos != d) { +- nstl::swap(p.nodes[d], p.nodes[min_pos]); +- if (p.blk_chunk_idx == min_pos || p.blk_chunk_idx == d) +- p.blk_chunk_idx = p.blk_chunk_idx == min_pos ? d : min_pos; +- } ++ if (min_pos != d) { nstl::swap(p.nodes[d], p.nodes[min_pos]); } + } + } + +-status_t prb_check_blk(prb_t &p, const memory_desc_t &md_) { +- const auto md = memory_desc_wrapper(md_); +- const auto &bd = md.blocking_desc(); +- if (p.ip_tail == 0) return status::success; +- +- // Check if the inner blocks and p.nodes[blk].n in the firsti nblks +- // is equivalent in reverse order when has tail in block layout. +- const int nblk = bd.inner_nblks; +- for (int iblk = 0; iblk < nblk; ++iblk) { +- if (bd.inner_blks[nblk - iblk - 1] +- != static_cast(p.nodes[iblk].n)) +- return status::unimplemented; ++void prb_node_dependency(prb_t &prb) { ++ for (int i = 0; i < prb.ndims; i++) { ++ tr::node_t &node = prb.nodes[i]; ++ node.parent_node_id = node_t::empty_field; ++ for (int j = i + 1; j < prb.ndims; j++) { ++ const tr::node_t &potential_parent_node = prb.nodes[j]; ++ if (!potential_parent_node.is_dim_id_empty() ++ && potential_parent_node.dim_id == node.dim_id) { ++ node.parent_node_id = j; ++ break; ++ } ++ } + } +- return status::success; + } + + void prb_simplify(prb_t &p) { +@@ -338,16 +471,25 @@ void prb_simplify(prb_t &p) { + #pragma GCC diagnostic push + #pragma GCC diagnostic ignored "-Warray-bounds" + #endif ++ ++ const auto skip_dim_combining = [&p](const int node_id) -> bool { ++ return (p.is_tail_in_one_of_child_nodes(node_id) ++ && p.nodes[node_id].n > 1) ++ || p.nodes[node_id].tail_size > 0; ++ }; ++ ++ if (p.is_tail_present) prb_node_dependency(p); ++ + for (int d = 0; d < p.ndims - 1; ++d) { + auto &this_node = p.nodes[d + 0]; + auto &next_node = p.nodes[d + 1]; +- const bool skip_blk_idx = (p.ip_tail > 0 || p.op_tail > 0) +- && (p.blk_chunk_idx == d || p.blk_chunk_idx == d + 1); ++ const bool skip_dims_combining ++ = skip_dim_combining(d) || skip_dim_combining(d + 1); + const bool fold = false + || (next_node.n == static_cast(1) +- && !skip_blk_idx) // trivial case, just drop next node ++ && !skip_dims_combining) // trivial case, just drop next node + || (true // or real folding if possible +- && !skip_blk_idx ++ && !skip_dims_combining + && next_node.is + == static_cast( + this_node.n * this_node.is) +@@ -356,15 +498,20 @@ void prb_simplify(prb_t &p) { + this_node.n * this_node.os) + && next_node.ss + == static_cast( +- this_node.n * this_node.ss)); ++ this_node.n * this_node.ss) ++ && next_node.cs ++ == static_cast( ++ this_node.n * this_node.cs)); + if (fold) { + this_node.n *= next_node.n; ++ this_node.dim_id = node_t::empty_field; ++ this_node.is_zero_pad_needed = false; + for (int j = d + 2; j < p.ndims; ++j) + p.nodes[j - 1] = p.nodes[j]; +- if (d < p.blk_chunk_idx) --p.blk_chunk_idx; + --p.ndims; + --p.full_ndims; + --d; // make another try ++ if (p.is_tail_present) prb_node_dependency(p); + } + } + #if defined(__GNUC__) && __GNUC__ >= 4 +@@ -372,24 +519,42 @@ void prb_simplify(prb_t &p) { + #endif + } + +-void prb_node_split(prb_t &p, int dim, size_t n1) { ++void prb_node_split(prb_t &p, int dim, size_t new_node_size) { + assert(dim < p.ndims); + assert(p.ndims < max_ndims); +- assert(p.nodes[dim].n % n1 == 0); ++ assert(p.nodes[dim].n % new_node_size == 0); + + p.ndims += 1; + p.full_ndims += 1; +- if (dim < p.blk_chunk_idx) p.blk_chunk_idx += 1; + + for (int d = p.ndims; d > dim + 1; --d) + p.nodes[d] = p.nodes[d - 1]; + +- p.nodes[dim + 1].n = p.nodes[dim].n / n1; +- p.nodes[dim + 1].is = p.nodes[dim].is * n1; +- p.nodes[dim + 1].os = p.nodes[dim].os * n1; +- p.nodes[dim + 1].ss = p.nodes[dim].ss * n1; +- +- p.nodes[dim].n = n1; ++ const size_t upper_node_size = p.nodes[dim].n / new_node_size; ++ const size_t lower_node_size = new_node_size; ++ p.nodes[dim + 1].n = upper_node_size; ++ p.nodes[dim].n = lower_node_size; ++ ++ const bool is_tail = p.nodes[dim].tail_size > 0; ++ const size_t upper_node_tail ++ = utils::div_up(p.nodes[dim].tail_size, lower_node_size) ++ == upper_node_size ++ ? 0 ++ : utils::div_up(p.nodes[dim].tail_size, lower_node_size); ++ const size_t lower_node_tail = p.nodes[dim].tail_size % lower_node_size; ++ p.nodes[dim].tail_size = is_tail ? lower_node_tail : 0; ++ p.nodes[dim + 1].tail_size = is_tail ? upper_node_tail : 0; ++ ++ p.nodes[dim + 1].is_zero_pad_needed ++ = p.nodes[dim].is_zero_pad_needed && p.nodes[dim + 1].tail_size > 0; ++ p.nodes[dim].is_zero_pad_needed ++ = p.nodes[dim].is_zero_pad_needed && p.nodes[dim].tail_size > 0; ++ ++ p.nodes[dim + 1].dim_id = p.nodes[dim].dim_id; ++ p.nodes[dim + 1].is = p.nodes[dim].is * lower_node_size; ++ p.nodes[dim + 1].os = p.nodes[dim].os * lower_node_size; ++ p.nodes[dim + 1].ss = p.nodes[dim].ss * lower_node_size; ++ p.nodes[dim + 1].cs = p.nodes[dim].cs * lower_node_size; + } + + void prb_node_swap(prb_t &p, int d0, int d1) { +@@ -425,8 +590,11 @@ void prb_dump(const prb_t &p) { + printf("@@@ type:%s:%s ndims:%d ", dnnl_dt2str(p.itype), + dnnl_dt2str(p.otype), p.ndims); + for (int d = 0; d < p.ndims; ++d) +- printf("[%zu:%td:%td:%td]", p.nodes[d].n, p.nodes[d].is, p.nodes[d].os, +- p.nodes[d].ss); ++ printf("[%zu:%zu:%d:%d:%s:%td:%td:%td:%td]", p.nodes[d].n, ++ p.nodes[d].tail_size, p.nodes[d].dim_id, ++ p.nodes[d].parent_node_id, ++ p.nodes[d].is_zero_pad_needed ? "true" : "false", p.nodes[d].is, ++ p.nodes[d].os, p.nodes[d].ss, p.nodes[d].cs); + printf(" off:%zu:%zu\n", p.ioff, p.ooff); + } + +diff --git a/src/cpu/reorder/cpu_reorder_regular_f32_f32.cpp b/src/cpu/reorder/cpu_reorder_regular_f32_f32.cpp +index f51e3c22414..fdefec8a049 100644 +--- a/src/cpu/reorder/cpu_reorder_regular_f32_f32.cpp ++++ b/src/cpu/reorder/cpu_reorder_regular_f32_f32.cpp +@@ -1,5 +1,6 @@ + /******************************************************************************* + * Copyright 2020-2022 Intel Corporation ++* Copyright 2022 FUJITSU LIMITED + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. +@@ -32,6 +33,7 @@ const impl_list_map_t ®ular_f32_f32_impl_list_map() { + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) + ++ DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_blk_reorder_t)) + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t)) + REG_SR(f32, any, f32, any, fmt_order::any, spec::reference) + +@@ -44,6 +46,7 @@ const impl_list_map_t ®ular_f32_f32_impl_list_map() { + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) + ++ DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_blk_reorder_t)) + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t)) + DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, f32, nCw16c)) + DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, f32, nCw8c)) +@@ -75,6 +78,7 @@ const impl_list_map_t ®ular_f32_f32_impl_list_map() { + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) + ++ DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_blk_reorder_t)) + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t)) + + DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, f32, nChw16c)) +@@ -123,6 +127,7 @@ const impl_list_map_t ®ular_f32_f32_impl_list_map() { + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) + ++ DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_blk_reorder_t)) + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t)) + + DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, f32, nCdhw16c)) +@@ -171,6 +176,7 @@ const impl_list_map_t ®ular_f32_f32_impl_list_map() { + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) + ++ DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_blk_reorder_t)) + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t)) + + +diff --git a/src/cpu/reorder/cpu_reorder_regular_f32_s32.cpp b/src/cpu/reorder/cpu_reorder_regular_f32_s32.cpp +index fadbee0ecf8..b1881df80e0 100644 +--- a/src/cpu/reorder/cpu_reorder_regular_f32_s32.cpp ++++ b/src/cpu/reorder/cpu_reorder_regular_f32_s32.cpp +@@ -1,5 +1,6 @@ + /******************************************************************************* + * Copyright 2020-2022 Intel Corporation ++* Copyright 2022 FUJITSU LIMITED + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. +@@ -31,6 +32,7 @@ const impl_list_map_t ®ular_f32_s32_impl_list_map() { + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) + ++ DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_blk_reorder_t)) + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t)) + DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, s32, nChw16c)) + REG_SR(f32, any, s32, any, fmt_order::any, spec::reference) +diff --git a/src/cpu/reorder/cpu_reorder_regular_f32_s8.cpp b/src/cpu/reorder/cpu_reorder_regular_f32_s8.cpp +index b83d47b2d6f..6bd305c7b41 100644 +--- a/src/cpu/reorder/cpu_reorder_regular_f32_s8.cpp ++++ b/src/cpu/reorder/cpu_reorder_regular_f32_s8.cpp +@@ -1,5 +1,6 @@ + /******************************************************************************* + * Copyright 2020-2022 Intel Corporation ++* Copyright 2022 FUJITSU LIMITED + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. +@@ -35,6 +36,7 @@ const impl_list_map_t ®ular_f32_s8_impl_list_map() { + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) + ++ DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_blk_reorder_t)) + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t)) + + DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, s8, nChw16c)) +diff --git a/src/cpu/reorder/cpu_reorder_regular_f32_u8.cpp b/src/cpu/reorder/cpu_reorder_regular_f32_u8.cpp +index 4bae84307e6..d306c3abeb8 100644 +--- a/src/cpu/reorder/cpu_reorder_regular_f32_u8.cpp ++++ b/src/cpu/reorder/cpu_reorder_regular_f32_u8.cpp +@@ -1,5 +1,6 @@ + /******************************************************************************* + * Copyright 2020-2022 Intel Corporation ++* Copyright 2022 FUJITSU LIMITED + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. +@@ -33,6 +34,7 @@ const impl_list_map_t ®ular_f32_u8_impl_list_map() { + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) + ++ DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_blk_reorder_t)) + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t)) + DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, u8, nChw16c)) + REG_SR(f32, any, u8, any, fmt_order::any, spec::reference) +diff --git a/src/cpu/reorder/cpu_reorder_regular_s32.cpp b/src/cpu/reorder/cpu_reorder_regular_s32.cpp +index 54d65661791..a8197402b0a 100644 +--- a/src/cpu/reorder/cpu_reorder_regular_s32.cpp ++++ b/src/cpu/reorder/cpu_reorder_regular_s32.cpp +@@ -1,5 +1,6 @@ + /******************************************************************************* + * Copyright 2020-2022 Intel Corporation ++* Copyright 2022 FUJITSU LIMITED + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. +@@ -34,6 +35,7 @@ const impl_list_map_t ®ular_s32_impl_list_map() { + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) + ++ DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_blk_reorder_t)) + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t)) + + DNNL_NON_X64_ONLY(REG_SR_BIDIR(s32, any, f32, nChw16c)) +diff --git a/src/cpu/reorder/cpu_reorder_regular_s8.cpp b/src/cpu/reorder/cpu_reorder_regular_s8.cpp +index f57d01e2009..ce18dc5caf1 100644 +--- a/src/cpu/reorder/cpu_reorder_regular_s8.cpp ++++ b/src/cpu/reorder/cpu_reorder_regular_s8.cpp +@@ -1,5 +1,6 @@ + /******************************************************************************* + * Copyright 2020-2022 Intel Corporation ++* Copyright 2022 FUJITSU LIMITED + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. +@@ -41,6 +42,7 @@ const impl_list_map_t ®ular_s8_impl_list_map() { + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) + ++ DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_blk_reorder_t)) + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t)) + + DNNL_NON_X64_ONLY(REG_SR_BIDIR(s8, any, f32, nChw16c)) +diff --git a/src/cpu/reorder/cpu_reorder_regular_u8.cpp b/src/cpu/reorder/cpu_reorder_regular_u8.cpp +index 73d731c3b15..87a58872262 100644 +--- a/src/cpu/reorder/cpu_reorder_regular_u8.cpp ++++ b/src/cpu/reorder/cpu_reorder_regular_u8.cpp +@@ -1,5 +1,6 @@ + /******************************************************************************* + * Copyright 2020-2022 Intel Corporation ++* Copyright 2022 FUJITSU LIMITED + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. +@@ -35,6 +36,7 @@ const impl_list_map_t ®ular_u8_impl_list_map() { + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) + ++ DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_blk_reorder_t)) + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t)) + + DNNL_NON_X64_ONLY(REG_SR_BIDIR(u8, any, f32, nChw16c)) From 3e900dd214a14edcaef444513eb4e8fe997477ee Mon Sep 17 00:00:00 2001 From: Tom Allsop Date: Wed, 28 Jun 2023 13:40:23 +0100 Subject: [PATCH 027/376] Legalization for quantized int8 tfl.squared_difference operator * Added legalization for int8 tfl.squared_difference --- .../mlir/tosa/tests/tfl-to-tosa-pipeline.mlir | 23 ++++++ .../mlir/tosa/transforms/legalize_common.cc | 72 +++++++++++++++++++ 2 files changed, 95 insertions(+) diff --git a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir index add1d1bc541259..2801bdf50b4269 100644 --- a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir @@ -2779,3 +2779,26 @@ func.func @test_imag_non_complex(%arg0: tensor<1x8x9xf32>) -> (tensor<1x8x9xf32> %0 = "tfl.imag"(%arg0) {} : (tensor<1x8x9xf32>) -> tensor<1x8x9xf32> return %0 : tensor<1x8x9xf32> } + +// ----- + +// CHECK-LABEL: test_squared_difference_qi8 +// CHECK-DAG: %[[VAR0:.*]] = "tosa.rescale"(%arg0) +// CHECK-DAG: %[[VAR1:.*]] = "tosa.rescale"(%arg1) +// CHECK-DAG: %[[VAR2:.*]] = "tosa.sub"(%[[VAR0]], %[[VAR1]]) +// CHECK-DAG: %[[VAR3:.*]] = "tosa.mul"(%[[VAR2]], %[[VAR2]]) {shift = 0 : i32} +// CHECK: %[[VAR4:.*]] = "tosa.rescale"(%[[VAR3]]) +func.func @test_squared_difference_qi8(%arg0: tensor<1x197x768x!quant.uniform>, %arg1: tensor<1x197x1x!quant.uniform>) -> tensor<1x197x768x!quant.uniform> { + %0 = "tfl.squared_difference"(%arg0, %arg1) : (tensor<1x197x768x!quant.uniform>, tensor<1x197x1x!quant.uniform>) -> tensor<1x197x768x!quant.uniform> + func.return %0 : tensor<1x197x768x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: test_squared_difference_f32 +// CHECK-DAG: %[[VAR0:.*]] = "tosa.sub"(%arg0, %arg1) +// CHECK-DAG: %[[VAR1:.*]] = "tosa.mul"(%[[VAR0]], %[[VAR0]]) +func.func @test_squared_difference_f32(%arg0: tensor<1x197x768xf32>, %arg1: tensor<1x197x1xf32>) -> tensor<1x197x768xf32> { + %0 = "tfl.squared_difference"(%arg0, %arg1) : (tensor<1x197x768xf32>, tensor<1x197x1xf32>) -> tensor<1x197x768xf32> + func.return %0 : tensor<1x197x768xf32> +} diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc index b72a54a64b4058..5ff27ad8750e89 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc @@ -583,6 +583,78 @@ std::optional convertSquaredDifferenceOp(PatternRewriter& rewriter, return std::nullopt; } + bool x_is_qtype = + x_type.getElementType().isa(); + bool y_is_qtype = + y_type.getElementType().isa(); + bool result_is_qtype = + result_type.getElementType().isa(); + + if (x_is_qtype != result_is_qtype || y_is_qtype != result_is_qtype) { + (void)rewriter.notifyMatchFailure( + op, + "input/output tensor should all be in FP32, INT32 or quantized INT8"); + return std::nullopt; + } + + // If the output is I8 then we need to rescale to I32 + // Then scale back to I8 + if (result_is_qtype) { + auto x_qtype = + x_type.getElementType().cast(); + auto y_qtype = + y_type.getElementType().cast(); + auto result_qtype = + result_type.getElementType().cast(); + + uint32_t result_bits = result_qtype.getStorageTypeIntegralWidth(); + + if (result_bits == 8) { + ShapedType rescale_type = result_type.clone(rewriter.getI32Type()); + + // We need to make sure the inputs are rescaled correctly + // Following the behaviour defined here lite/kernels/squared_difference.cc + double in_x_scale = x_qtype.getScale(); + double in_y_scale = y_qtype.getScale(); + double result_scale = result_qtype.getScale(); + + double twice_max_input_scale = 2.0 * std::max(in_x_scale, in_y_scale); + + const int32_t LEFT_SHIFT = 7; + + double x_rescale_scale = in_x_scale / twice_max_input_scale; + double y_rescale_scale = in_y_scale / twice_max_input_scale; + double output_rescale_scale = + (twice_max_input_scale * twice_max_input_scale) / + ((static_cast(1 << LEFT_SHIFT * 2)) * result_scale); + + Value x_scaled = buildRescaleToInt32( + rewriter, op, x, + x_rescale_scale * static_cast(1 << LEFT_SHIFT), + x_qtype.getZeroPoint()); + Value y_scaled = buildRescaleToInt32( + rewriter, op, y, + y_rescale_scale * static_cast(1 << LEFT_SHIFT), + y_qtype.getZeroPoint()); + + auto sub_op = CreateOpAndInfer( + rewriter, op->getLoc(), rescale_type, x_scaled, y_scaled); + auto mul_op = CreateOpAndInfer( + rewriter, op->getLoc(), rescale_type, sub_op.getResult(), + sub_op.getResult(), 0); + + // Convert the operator back to the original type + return buildRescaleFromInt32(rewriter, op, result_type, mul_op, + output_rescale_scale, + result_qtype.getZeroPoint()); + } else { + (void)rewriter.notifyMatchFailure( + op, "Only FP32, INT32 or quantized INT8 is supported"); + return std::nullopt; + } + } + + // This will cover FP32/FP16/INT32 legalization auto sub_op = CreateOpAndInfer(rewriter, op->getLoc(), result_type, x, y); return CreateOpAndInfer(rewriter, op->getLoc(), result_type, From 213c1439b9476ced4ab9c02e112ee215a36b46f1 Mon Sep 17 00:00:00 2001 From: Austin Anderson Date: Wed, 28 Jun 2023 17:18:30 -0700 Subject: [PATCH 028/376] Before trying big redo --- ci/official/any.sh | 45 +++ ci/official/bazelrcs/cpu.bazelrc | 96 ++++++ ci/official/bazelrcs/cpu_gcc.bazelrc | 85 +++++ ci/official/bazelrcs/gpu.bazelrc | 127 ++++++++ ci/official/bazelrcs/nvidia.bazelrc | 127 ++++++++ ci/official/code_check_changed_files.sh | 15 + ci/official/code_check_full.sh | 15 + ci/official/envs/env.local_cpu | 26 ++ ci/official/envs/env.nightly_cpu | 33 ++ ci/official/libtensorflow.sh | 34 ++ ci/official/pycpp.sh | 22 ++ .../utilities/code_check_changed_files.bats | 76 +++++ ci/official/utilities/code_check_full.bats | 307 ++++++++++++++++++ ci/official/utilities/copybara.sh | 30 ++ ci/official/utilities/docker.sh | 20 ++ ci/official/utilities/generate_index_html.sh | 40 +++ .../utilities/rename_and_verify_wheels.sh | 35 ++ ci/official/utilities/repack_libtensorflow.sh | 71 ++++ ci/official/utilities/wheel_verification.bats | 78 +++++ ci/official/wheel.sh | 33 ++ 20 files changed, 1315 insertions(+) create mode 100755 ci/official/any.sh create mode 100644 ci/official/bazelrcs/cpu.bazelrc create mode 100644 ci/official/bazelrcs/cpu_gcc.bazelrc create mode 100644 ci/official/bazelrcs/gpu.bazelrc create mode 100644 ci/official/bazelrcs/nvidia.bazelrc create mode 100755 ci/official/code_check_changed_files.sh create mode 100755 ci/official/code_check_full.sh create mode 100644 ci/official/envs/env.local_cpu create mode 100644 ci/official/envs/env.nightly_cpu create mode 100755 ci/official/libtensorflow.sh create mode 100755 ci/official/pycpp.sh create mode 100644 ci/official/utilities/code_check_changed_files.bats create mode 100644 ci/official/utilities/code_check_full.bats create mode 100755 ci/official/utilities/copybara.sh create mode 100755 ci/official/utilities/docker.sh create mode 100755 ci/official/utilities/generate_index_html.sh create mode 100755 ci/official/utilities/rename_and_verify_wheels.sh create mode 100755 ci/official/utilities/repack_libtensorflow.sh create mode 100644 ci/official/utilities/wheel_verification.bats create mode 100755 ci/official/wheel.sh diff --git a/ci/official/any.sh b/ci/official/any.sh new file mode 100755 index 00000000000000..43bba1a2bec474 --- /dev/null +++ b/ci/official/any.sh @@ -0,0 +1,45 @@ +#!/bin/bash +# -e: abort script if one command fails +# -u: error if undefined variable used +# -o pipefail: entire command fails if pipe fails. watch out for yes | ... +# -o history: record shell history +set -euxo pipefail -o history +set -o allexport && source "$TFCI" && set +o allexport + +_TFCI_HOST_ARTIFACTS_DIR="$TFCI_RUNTIME_ARTIFACTS_DIR" +tfrun() { "$@"; } +[[ "$TFCI_COPYBARA_ENABLE" = 1 ]] && source $TFCI_RUNTIME_USERTOOLS_DIR/copybara.sh +[[ "$TFCI_DOCKER_ENABLE" = 1 ]] && source $TFCI_RUNTIME_USERTOOLS_DIR/docker.sh +"$TFCI_RUNTIME_USERTOOLS_DIR/generate_index_html.sh" "$TFCI_RUNTIME_ARTIFACTS_DIR/index.html" + +# Parse options and build targets into arrays, so that shelllint doesn't yell +# about readability. We can't pipe into 'read -ra' to create an array because +# piped commands run in subshells, which can't store variables outside of the +# subshell environment. +# See https://g3doc.corp.google.com/devtools/staticanalysis/pipeline/analyzers/shell/lint/g3doc/findings/SC2086.md?cl=head +# Ignore grep failures since we're using it for basic filtering +set +e +filtered_build_targets=( $(echo "$BUILD_TARGETS" | tr ' ' '\n' | grep . | tee build_targets.txt) ) +nonpip_targets=( $(echo "$TEST_TARGETS" | tr ' ' '\n' | grep -E "^//tensorflow/" | tee nonpip_targets.txt) ) +config=( $(echo "$CONFIG_OPTIONS" ) ) +test_flags=( $(echo "$TEST_FLAGS" ) ) +set -e + +[[ "$TFCI_NVIDIA_SMI_ENABLE" = 1 ]] && tfrun nvidia-smi + +if [[ -s build_targets.txt ]]; then + tfrun bazel "${TFCI_BAZEL_BAZELRC_ARGS[@]}" "${config[@]}" "${filtered_build_targets[@]}" +fi + +if [[ "${PIP_WHEEL}" -eq "1" ]]; then + # Update the version numbers to build a "nightly" package + [[ "$TFCI_NIGHTLY_UPDATE_VERSION_ENABLE" = 1 ]] && tfrun python3 tensorflow/tools/ci_build/update_version.py --nightly + + tfrun bazel "${TFCI_BAZEL_BAZELRC_ARGS[@]}" build "${TFCI_BAZEL_CACHE_ARGS[@]}" tensorflow/tools/pip_package:build_pip_package + tfrun ./bazel-bin/tensorflow/tools/pip_package/build_pip_package "$TFCI_RUNTIME_ARTIFACTS_DIR" "${TFCI_BUILD_PIP_PACKAGE_ARGS[@]}" + tfrun "$TFCI_RUNTIME_USERTOOLS_DIR/rename_and_verify_wheels.sh" +fi + +if [[ -s nonpip_targets.txt ]]; then + tfrun bazel "${TFCI_BAZEL_BAZELRC_ARGS[@]}" test "${config[@]}" "${test_flags[@]}" "${nonpip_targets[@]}" +fi diff --git a/ci/official/bazelrcs/cpu.bazelrc b/ci/official/bazelrcs/cpu.bazelrc new file mode 100644 index 00000000000000..3a324603bdf0ce --- /dev/null +++ b/ci/official/bazelrcs/cpu.bazelrc @@ -0,0 +1,96 @@ +# This bazelrc can build a CPU-supporting TF package. + +# Convenient cache configurations +# Use a cache directory mounted to /tf/cache. Very useful! +build:sigbuild_local_cache --disk_cache=/tf/cache +# Use the public-access TF DevInfra cache (read only) +build:sigbuild_remote_cache --remote_cache="https://storage.googleapis.com/tensorflow-devinfra-bazel-cache/september2022" --remote_upload_local_results=false +# Write to the TF DevInfra cache (only works for internal TF CI) +build:sigbuild_remote_cache_push --remote_cache="https://storage.googleapis.com/tensorflow-devinfra-bazel-cache/september2022" --google_default_credentials +# Change the value of CACHEBUSTER when upgrading the toolchain, or when testing +# different compilation methods. E.g. for a PR to test a new CUDA version, set +# the CACHEBUSTER to the PR number. +build --action_env=CACHEBUSTER=501872366 + +# Use Python 3.X as installed in container image +build --action_env PYTHON_BIN_PATH="/usr/bin/python3" +build --action_env PYTHON_LIB_PATH="/usr/lib/tf_python" +build --python_path="/usr/bin/python3" + +# Build TensorFlow v2 +build --define=tf_api_version=2 --action_env=TF2_BEHAVIOR=1 + +# Target the AVX instruction set +build --copt=-mavx --host_copt=-mavx + +# Use lld as the linker +build --linkopt="-fuse-ld=lld" +build --linkopt="-lm" + +# Disable clang extention that rejects type definitions within offsetof. +# This was added in clang-16 by https://reviews.llvm.org/D133574. +# Can be removed once upb is updated, since a type definition is used within +# offset of in the current version of ubp. +# See https://github.com/protocolbuffers/upb/blob/9effcbcb27f0a665f9f345030188c0b291e32482/upb/upb.c#L183. +build --copt=-Wno-gnu-offsetof-extensions + +# Store performance profiling log in the mounted artifact directory. +# The profile can be viewed by visiting chrome://tracing in a Chrome browser. +# See https://docs.bazel.build/versions/main/skylark/performance.html#performance-profiling +build --profile=/tf/pkg/profile.json.gz + +# Use the NVCC toolchain to compile for manylinux2014 +build --crosstool_top="@sigbuild-r2.14-clang_config_cuda//crosstool:toolchain" + +# Test-related settings below this point. +test --build_tests_only --keep_going --test_output=errors --verbose_failures=true +test --local_test_jobs=HOST_CPUS +test --test_env=LD_LIBRARY_PATH +# Give only the list of failed tests at the end of the log +test --test_summary=short + +# "nonpip" tests are regular py_test tests. +# Pass --config=nonpip to run the same suite of tests. If you want to run just +# one test for investigation, you don't need --config=nonpip; just run the +# bazel test invocation as normal. +test:nonpip_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:nonpip_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:nonpip_filters --test_lang_filters=py --test_size_filters=small,medium +test:nonpip --config=nonpip_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... + +# For building libtensorflow archives +test:libtensorflow_test -- //tensorflow/tools/lib_package:libtensorflow_test //tensorflow/tools/lib_package:libtensorflow_java_test +build:libtensorflow_build -- //tensorflow/tools/lib_package:libtensorflow.tar.gz //tensorflow/tools/lib_package:libtensorflow_jni.tar.gz //tensorflow/java:libtensorflow.jar //tensorflow/java:libtensorflow-src.jar //tensorflow/tools/lib_package:libtensorflow_proto.zip + +# For outputting Build Event Protocol files +build:build_event_export --build_event_json_file=/tf/pkg/bep.json + +# For Remote Build Execution. +build:rbe --google_default_credentials +build:rbe --bes_backend=buildeventservice.googleapis.com +build:rbe --bes_results_url="https://source.cloud.google.com/results/invocations" +build:rbe --bes_timeout=600s +build:rbe --define=EXECUTOR=remote +build:rbe --jobs=800 +build:rbe --remote_executor=grpcs://remotebuildexecution.googleapis.com +build:rbe --remote_timeout=3600 +build:rbe --spawn_strategy=remote,worker,standalone,local +build:rbe --remote_download_toplevel +build:rbe --action_env=PATH="/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/local/go/bin" +build:rbe --linkopt=-lrt --host_linkopt=-lrt --linkopt=-lm --host_linkopt=-lm # Unclear why this is here +build:rbe --host_crosstool_top="@sigbuild-r2.14-clang_config_cuda//crosstool:toolchain" +build:rbe --crosstool_top="@sigbuild-r2.14-clang_config_cuda//crosstool:toolchain" +build:rbe --extra_toolchains="@sigbuild-r2.14-clang_config_cuda//crosstool:toolchain-linux-x86_64" +build:rbe --extra_execution_platforms="@sigbuild-r2.14-clang_config_platform//:platform" +build:rbe --host_platform="@sigbuild-r2.14-clang_config_platform//:platform" +build:rbe --platforms="@sigbuild-r2.14-clang_config_platform//:platform" +# Python config is the same across all containers because the binary is the same +build:rbe --repo_env=TF_PYTHON_CONFIG_REPO="@sigbuild-r2.14-clang_config_python" +build:rbe --remote_instance_name=projects/tensorflow-testing/instances/default_instance +build:rbe --project_id="tensorflow-testing" + +# For continuous builds +test:pycpp_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only +test:pycpp_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only +test:pycpp_filters --test_lang_filters=cc,py --test_size_filters=small,medium +test:pycpp --config=pycpp_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... diff --git a/ci/official/bazelrcs/cpu_gcc.bazelrc b/ci/official/bazelrcs/cpu_gcc.bazelrc new file mode 100644 index 00000000000000..cc74fd978cfade --- /dev/null +++ b/ci/official/bazelrcs/cpu_gcc.bazelrc @@ -0,0 +1,85 @@ +# This bazelrc can build a CPU-supporting TF package. + +# Convenient cache configurations +# Use a cache directory mounted to /tf/cache. Very useful! +build:sigbuild_local_cache --disk_cache=/tf/cache +# Use the public-access TF DevInfra cache (read only) +build:sigbuild_remote_cache --remote_cache="https://storage.googleapis.com/tensorflow-devinfra-bazel-cache/september2022" --remote_upload_local_results=false +# Write to the TF DevInfra cache (only works for internal TF CI) +build:sigbuild_remote_cache_push --remote_cache="https://storage.googleapis.com/tensorflow-devinfra-bazel-cache/september2022" --google_default_credentials +# Change the value of CACHEBUSTER when upgrading the toolchain, or when testing +# different compilation methods. E.g. for a PR to test a new CUDA version, set +# the CACHEBUSTER to the PR number. +build --action_env=CACHEBUSTER=501872366 + +# Use Python 3.X as installed in container image +build --action_env PYTHON_BIN_PATH="/usr/bin/python3" +build --action_env PYTHON_LIB_PATH="/usr/lib/tf_python" +build --python_path="/usr/bin/python3" + +# Build TensorFlow v2 +build --define=tf_api_version=2 --action_env=TF2_BEHAVIOR=1 + +# Target the AVX instruction set +build --copt=-mavx --host_copt=-mavx + +# Store performance profiling log in the mounted artifact directory. +# The profile can be viewed by visiting chrome://tracing in a Chrome browser. +# See https://docs.bazel.build/versions/main/skylark/performance.html#performance-profiling +build --profile=/tf/pkg/profile.json.gz + +# Use the NVCC toolchain to compile for manylinux2014 +build --crosstool_top="@sigbuild-r2.14_config_cuda//crosstool:toolchain" + +# Test-related settings below this point. +test --build_tests_only --keep_going --test_output=errors --verbose_failures=true +test --local_test_jobs=HOST_CPUS +test --test_env=LD_LIBRARY_PATH +# Give only the list of failed tests at the end of the log +test --test_summary=short + +# "nonpip" tests are regular py_test tests. +# Pass --config=nonpip to run the same suite of tests. If you want to run just +# one test for investigation, you don't need --config=nonpip; just run the +# bazel test invocation as normal. +test:nonpip_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:nonpip_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:nonpip_filters --test_lang_filters=py --test_size_filters=small,medium +test:nonpip --config=nonpip_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... + +# For building libtensorflow archives +test:libtensorflow_test -- //tensorflow/tools/lib_package:libtensorflow_test //tensorflow/tools/lib_package:libtensorflow_java_test +build:libtensorflow_build -- //tensorflow/tools/lib_package:libtensorflow.tar.gz //tensorflow/tools/lib_package:libtensorflow_jni.tar.gz //tensorflow/java:libtensorflow.jar //tensorflow/java:libtensorflow-src.jar //tensorflow/tools/lib_package:libtensorflow_proto.zip + +# For outputting Build Event Protocol files +build:build_event_export --build_event_json_file=/tf/pkg/bep.json + +# For Remote Build Execution. +build:rbe --google_default_credentials +build:rbe --bes_backend=buildeventservice.googleapis.com +build:rbe --bes_results_url="https://source.cloud.google.com/results/invocations" +build:rbe --bes_timeout=600s +build:rbe --define=EXECUTOR=remote +build:rbe --jobs=800 +build:rbe --remote_executor=grpcs://remotebuildexecution.googleapis.com +build:rbe --remote_timeout=3600 +build:rbe --spawn_strategy=remote,worker,standalone,local +build:rbe --remote_download_toplevel +build:rbe --action_env=PATH="/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/local/go/bin" +build:rbe --linkopt=-lrt --host_linkopt=-lrt --linkopt=-lm --host_linkopt=-lm # Unclear why this is here +build:rbe --host_crosstool_top="@sigbuild-r2.14_config_cuda//crosstool:toolchain" +build:rbe --crosstool_top="@sigbuild-r2.14_config_cuda//crosstool:toolchain" +build:rbe --extra_toolchains="@sigbuild-r2.14_config_cuda//crosstool:toolchain-linux-x86_64" +build:rbe --extra_execution_platforms="@sigbuild-r2.14_config_platform//:platform" +build:rbe --host_platform="@sigbuild-r2.14_config_platform//:platform" +build:rbe --platforms="@sigbuild-r2.14_config_platform//:platform" +# Python config is the same across all containers because the binary is the same +build:rbe --repo_env=TF_PYTHON_CONFIG_REPO="@sigbuild-r2.14_config_python" +build:rbe --remote_instance_name=projects/tensorflow-testing/instances/default_instance +build:rbe --project_id="tensorflow-testing" + +# For continuous builds +test:pycpp_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only +test:pycpp_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only +test:pycpp_filters --test_lang_filters=cc,py --test_size_filters=small,medium +test:pycpp --config=pycpp_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... diff --git a/ci/official/bazelrcs/gpu.bazelrc b/ci/official/bazelrcs/gpu.bazelrc new file mode 100644 index 00000000000000..50ea575205967c --- /dev/null +++ b/ci/official/bazelrcs/gpu.bazelrc @@ -0,0 +1,127 @@ +# This bazelrc can build a GPU-supporting TF package. + +# Convenient cache configurations +# Use a cache directory mounted to /tf/cache. Very useful! +build:sigbuild_local_cache --disk_cache=/tf/cache +# Use the public-access TF DevInfra cache (read only) +build:sigbuild_remote_cache --remote_cache="https://storage.googleapis.com/tensorflow-devinfra-bazel-cache/september2022" --remote_upload_local_results=false +# Write to the TF DevInfra cache (only works for internal TF CI) +build:sigbuild_remote_cache_push --remote_cache="https://storage.googleapis.com/tensorflow-devinfra-bazel-cache/september2022" --google_default_credentials +# Change the value of CACHEBUSTER when upgrading the toolchain, or when testing +# different compilation methods. E.g. for a PR to test a new CUDA version, set +# the CACHEBUSTER to the PR number. +build --action_env=CACHEBUSTER=501872366 + +# Use Python 3.X as installed in container image +build --action_env PYTHON_BIN_PATH="/usr/bin/python3" +build --action_env PYTHON_LIB_PATH="/usr/lib/tf_python" +build --python_path="/usr/bin/python3" + +# Build TensorFlow v2 +build --define=tf_api_version=2 --action_env=TF2_BEHAVIOR=1 + +# Target the AVX instruction set +build --copt=-mavx --host_copt=-mavx + +# Disable clang extention that rejects type definitions within offsetof. +# This was added in clang-16 by https://reviews.llvm.org/D133574. +# Can be removed once upb is updated, since a type definition is used within +# offset of in the current version of ubp. +# See https://github.com/protocolbuffers/upb/blob/9effcbcb27f0a665f9f345030188c0b291e32482/upb/upb.c#L183. +build --copt=-Wno-gnu-offsetof-extensions + +# Use lld as the linker +build --linkopt="-fuse-ld=lld" +build --linkopt="-lm" + +# Store performance profiling log in the mounted artifact directory. +# The profile can be viewed by visiting chrome://tracing in a Chrome browser. +# See https://docs.bazel.build/versions/main/skylark/performance.html#performance-profiling +build --profile=/tf/pkg/profile.json.gz + +# CUDA: Set up compilation CUDA version and paths +build --@local_config_cuda//:enable_cuda +build --@local_config_cuda//:cuda_compiler=clang +build --repo_env TF_NEED_CUDA=1 +build --config cuda_clang +build --action_env=TF_CUDA_VERSION="11" +build --action_env=TF_CUDNN_VERSION="8" +build --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-11.8" +build --action_env=GCC_HOST_COMPILER_PATH="/dt9/usr/bin/gcc" +build --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-16/bin/clang" +build --action_env=TF_CUDA_CLANG="1" +build --action_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/tensorrt/lib" +build --crosstool_top="@sigbuild-r2.14-clang_config_cuda//crosstool:toolchain" + +# CUDA: Enable TensorRT optimizations +# https://developer.nvidia.com/tensorrt +build --repo_env TF_NEED_TENSORRT=1 + +# CUDA: Select supported compute capabilities (supported graphics cards). +# This is the same as the official TensorFlow builds. +# See https://developer.nvidia.com/cuda-gpus#compute +# TODO(angerson, perfinion): What does sm_ vs compute_ mean? +# TODO(angerson, perfinion): How can users select a good value for this? +build --repo_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_35,sm_50,sm_60,sm_70,sm_75,compute_80" + +# Test-related settings below this point. +test --build_tests_only --keep_going --test_output=errors --verbose_failures=true +test --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" +# Local test jobs has to be 4 because parallel_gpu_execute is fragile, I think +test --test_timeout=300,450,1200,3600 --local_test_jobs=4 --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute +# Give only the list of failed tests at the end of the log +test --test_summary=short + +# "nonpip" tests are regular py_test tests. +# Pass --config=nonpip to run the same suite of tests. If you want to run just +# one test for investigation, you don't need --config=nonpip; just run the +# bazel test invocation as normal. +test:nonpip_filters --test_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:nonpip_filters --build_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:nonpip_filters --test_lang_filters=py --test_size_filters=small,medium +test:nonpip --config=nonpip_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... + +# For building libtensorflow archives +test:libtensorflow_test -- //tensorflow/tools/lib_package:libtensorflow_test //tensorflow/tools/lib_package:libtensorflow_java_test +build:libtensorflow_build -- //tensorflow/tools/lib_package:libtensorflow.tar.gz //tensorflow/tools/lib_package:libtensorflow_jni.tar.gz //tensorflow/java:libtensorflow.jar //tensorflow/java:libtensorflow-src.jar //tensorflow/tools/lib_package:libtensorflow_proto.zip + +# For outputting Build Event Protocol files +build:build_event_export --build_event_json_file=/tf/pkg/bep.json + +# For Remote Build Execution. +build:rbe --google_default_credentials +build:rbe --bes_backend=buildeventservice.googleapis.com +build:rbe --bes_results_url="https://source.cloud.google.com/results/invocations" +build:rbe --bes_timeout=600s +build:rbe --define=EXECUTOR=remote +build:rbe --jobs=800 +build:rbe --remote_executor=grpcs://remotebuildexecution.googleapis.com +build:rbe --remote_timeout=3600 +build:rbe --spawn_strategy=remote,worker,standalone,local +build:rbe --remote_download_toplevel +build:rbe --action_env=PATH="/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/local/go/bin" +build:rbe --linkopt=-lrt --host_linkopt=-lrt --linkopt=-lm --host_linkopt=-lm # Unclear why this is here +build:rbe --host_crosstool_top="@sigbuild-r2.14-clang_config_cuda//crosstool:toolchain" +build:rbe --crosstool_top="@sigbuild-r2.14-clang_config_cuda//crosstool:toolchain" +build:rbe --extra_toolchains="@sigbuild-r2.14-clang_config_cuda//crosstool:toolchain-linux-x86_64" +build:rbe --extra_execution_platforms="@sigbuild-r2.14-clang_config_platform//:platform" +build:rbe --host_platform="@sigbuild-r2.14-clang_config_platform//:platform" +build:rbe --platforms="@sigbuild-r2.14-clang_config_platform//:platform" +# Python config is the same across all containers because the binary is the same +build:rbe --repo_env=TF_PYTHON_CONFIG_REPO="@sigbuild-r2.14-clang_config_python" +build:rbe --remote_instance_name=projects/tensorflow-testing/instances/default_instance +build:rbe --project_id="tensorflow-testing" + +# For Remote build execution -- GPU configuration +build:rbe --repo_env=REMOTE_GPU_TESTING=1 +test:rbe --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" +build:rbe --repo_env=TF_CUDA_CONFIG_REPO="@sigbuild-r2.14-clang_config_cuda" +build:rbe --repo_env=TF_TENSORRT_CONFIG_REPO="@sigbuild-r2.14-clang_config_tensorrt" +build:rbe --repo_env=TF_NCCL_CONFIG_REPO="@sigbuild-r2.14-clang_config_nccl" +build:rbe --repo_env=TF_PYTHON_CONFIG_REPO="@sigbuild-r2.14-clang_config_python" + +# For continuous builds +test:pycpp_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11 +test:pycpp_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11 +test:pycpp_filters --test_lang_filters=cc,py --test_size_filters=small,medium +test:pycpp --config=pycpp_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... diff --git a/ci/official/bazelrcs/nvidia.bazelrc b/ci/official/bazelrcs/nvidia.bazelrc new file mode 100644 index 00000000000000..50ea575205967c --- /dev/null +++ b/ci/official/bazelrcs/nvidia.bazelrc @@ -0,0 +1,127 @@ +# This bazelrc can build a GPU-supporting TF package. + +# Convenient cache configurations +# Use a cache directory mounted to /tf/cache. Very useful! +build:sigbuild_local_cache --disk_cache=/tf/cache +# Use the public-access TF DevInfra cache (read only) +build:sigbuild_remote_cache --remote_cache="https://storage.googleapis.com/tensorflow-devinfra-bazel-cache/september2022" --remote_upload_local_results=false +# Write to the TF DevInfra cache (only works for internal TF CI) +build:sigbuild_remote_cache_push --remote_cache="https://storage.googleapis.com/tensorflow-devinfra-bazel-cache/september2022" --google_default_credentials +# Change the value of CACHEBUSTER when upgrading the toolchain, or when testing +# different compilation methods. E.g. for a PR to test a new CUDA version, set +# the CACHEBUSTER to the PR number. +build --action_env=CACHEBUSTER=501872366 + +# Use Python 3.X as installed in container image +build --action_env PYTHON_BIN_PATH="/usr/bin/python3" +build --action_env PYTHON_LIB_PATH="/usr/lib/tf_python" +build --python_path="/usr/bin/python3" + +# Build TensorFlow v2 +build --define=tf_api_version=2 --action_env=TF2_BEHAVIOR=1 + +# Target the AVX instruction set +build --copt=-mavx --host_copt=-mavx + +# Disable clang extention that rejects type definitions within offsetof. +# This was added in clang-16 by https://reviews.llvm.org/D133574. +# Can be removed once upb is updated, since a type definition is used within +# offset of in the current version of ubp. +# See https://github.com/protocolbuffers/upb/blob/9effcbcb27f0a665f9f345030188c0b291e32482/upb/upb.c#L183. +build --copt=-Wno-gnu-offsetof-extensions + +# Use lld as the linker +build --linkopt="-fuse-ld=lld" +build --linkopt="-lm" + +# Store performance profiling log in the mounted artifact directory. +# The profile can be viewed by visiting chrome://tracing in a Chrome browser. +# See https://docs.bazel.build/versions/main/skylark/performance.html#performance-profiling +build --profile=/tf/pkg/profile.json.gz + +# CUDA: Set up compilation CUDA version and paths +build --@local_config_cuda//:enable_cuda +build --@local_config_cuda//:cuda_compiler=clang +build --repo_env TF_NEED_CUDA=1 +build --config cuda_clang +build --action_env=TF_CUDA_VERSION="11" +build --action_env=TF_CUDNN_VERSION="8" +build --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-11.8" +build --action_env=GCC_HOST_COMPILER_PATH="/dt9/usr/bin/gcc" +build --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-16/bin/clang" +build --action_env=TF_CUDA_CLANG="1" +build --action_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/tensorrt/lib" +build --crosstool_top="@sigbuild-r2.14-clang_config_cuda//crosstool:toolchain" + +# CUDA: Enable TensorRT optimizations +# https://developer.nvidia.com/tensorrt +build --repo_env TF_NEED_TENSORRT=1 + +# CUDA: Select supported compute capabilities (supported graphics cards). +# This is the same as the official TensorFlow builds. +# See https://developer.nvidia.com/cuda-gpus#compute +# TODO(angerson, perfinion): What does sm_ vs compute_ mean? +# TODO(angerson, perfinion): How can users select a good value for this? +build --repo_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_35,sm_50,sm_60,sm_70,sm_75,compute_80" + +# Test-related settings below this point. +test --build_tests_only --keep_going --test_output=errors --verbose_failures=true +test --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" +# Local test jobs has to be 4 because parallel_gpu_execute is fragile, I think +test --test_timeout=300,450,1200,3600 --local_test_jobs=4 --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute +# Give only the list of failed tests at the end of the log +test --test_summary=short + +# "nonpip" tests are regular py_test tests. +# Pass --config=nonpip to run the same suite of tests. If you want to run just +# one test for investigation, you don't need --config=nonpip; just run the +# bazel test invocation as normal. +test:nonpip_filters --test_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:nonpip_filters --build_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:nonpip_filters --test_lang_filters=py --test_size_filters=small,medium +test:nonpip --config=nonpip_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... + +# For building libtensorflow archives +test:libtensorflow_test -- //tensorflow/tools/lib_package:libtensorflow_test //tensorflow/tools/lib_package:libtensorflow_java_test +build:libtensorflow_build -- //tensorflow/tools/lib_package:libtensorflow.tar.gz //tensorflow/tools/lib_package:libtensorflow_jni.tar.gz //tensorflow/java:libtensorflow.jar //tensorflow/java:libtensorflow-src.jar //tensorflow/tools/lib_package:libtensorflow_proto.zip + +# For outputting Build Event Protocol files +build:build_event_export --build_event_json_file=/tf/pkg/bep.json + +# For Remote Build Execution. +build:rbe --google_default_credentials +build:rbe --bes_backend=buildeventservice.googleapis.com +build:rbe --bes_results_url="https://source.cloud.google.com/results/invocations" +build:rbe --bes_timeout=600s +build:rbe --define=EXECUTOR=remote +build:rbe --jobs=800 +build:rbe --remote_executor=grpcs://remotebuildexecution.googleapis.com +build:rbe --remote_timeout=3600 +build:rbe --spawn_strategy=remote,worker,standalone,local +build:rbe --remote_download_toplevel +build:rbe --action_env=PATH="/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/local/go/bin" +build:rbe --linkopt=-lrt --host_linkopt=-lrt --linkopt=-lm --host_linkopt=-lm # Unclear why this is here +build:rbe --host_crosstool_top="@sigbuild-r2.14-clang_config_cuda//crosstool:toolchain" +build:rbe --crosstool_top="@sigbuild-r2.14-clang_config_cuda//crosstool:toolchain" +build:rbe --extra_toolchains="@sigbuild-r2.14-clang_config_cuda//crosstool:toolchain-linux-x86_64" +build:rbe --extra_execution_platforms="@sigbuild-r2.14-clang_config_platform//:platform" +build:rbe --host_platform="@sigbuild-r2.14-clang_config_platform//:platform" +build:rbe --platforms="@sigbuild-r2.14-clang_config_platform//:platform" +# Python config is the same across all containers because the binary is the same +build:rbe --repo_env=TF_PYTHON_CONFIG_REPO="@sigbuild-r2.14-clang_config_python" +build:rbe --remote_instance_name=projects/tensorflow-testing/instances/default_instance +build:rbe --project_id="tensorflow-testing" + +# For Remote build execution -- GPU configuration +build:rbe --repo_env=REMOTE_GPU_TESTING=1 +test:rbe --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" +build:rbe --repo_env=TF_CUDA_CONFIG_REPO="@sigbuild-r2.14-clang_config_cuda" +build:rbe --repo_env=TF_TENSORRT_CONFIG_REPO="@sigbuild-r2.14-clang_config_tensorrt" +build:rbe --repo_env=TF_NCCL_CONFIG_REPO="@sigbuild-r2.14-clang_config_nccl" +build:rbe --repo_env=TF_PYTHON_CONFIG_REPO="@sigbuild-r2.14-clang_config_python" + +# For continuous builds +test:pycpp_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11 +test:pycpp_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11 +test:pycpp_filters --test_lang_filters=cc,py --test_size_filters=small,medium +test:pycpp --config=pycpp_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... diff --git a/ci/official/code_check_changed_files.sh b/ci/official/code_check_changed_files.sh new file mode 100755 index 00000000000000..49924bc2344eb5 --- /dev/null +++ b/ci/official/code_check_changed_files.sh @@ -0,0 +1,15 @@ +#!/bin/bash +# -e: abort script if one command fails +# -u: error if undefined variable used +# -o pipefail: entire command fails if pipe fails. watch out for yes | ... +# -o history: record shell history +set -euxo pipefail -o history +set -o allexport && source "$TFCI" && set +o allexport + +_TFCI_HOST_ARTIFACTS_DIR="$TFCI_RUNTIME_ARTIFACTS_DIR" +tfrun() { "$@"; } +[[ "$TFCI_COPYBARA_ENABLE" = 1 ]] && source $TFCI_RUNTIME_USERTOOLS_DIR/copybara.sh +[[ "$TFCI_DOCKER_ENABLE" = 1 ]] && source $TFCI_RUNTIME_USERTOOLS_DIR/docker.sh +"$TFCI_RUNTIME_USERTOOLS_DIR/generate_index_html.sh" "$TFCI_RUNTIME_ARTIFACTS_DIR/index.html" + +tfrun bats "$TFCI_RUNTIME_USERTOOLS_DIR"/code_check_changed_files.bats --timing --output "$TFCI_RUNTIME_ARTIFACTS_DIR" diff --git a/ci/official/code_check_full.sh b/ci/official/code_check_full.sh new file mode 100755 index 00000000000000..224b73375285d8 --- /dev/null +++ b/ci/official/code_check_full.sh @@ -0,0 +1,15 @@ +#!/bin/bash +# -e: abort script if one command fails +# -u: error if undefined variable used +# -o pipefail: entire command fails if pipe fails. watch out for yes | ... +# -o history: record shell history +set -euxo pipefail -o history +set -o allexport && source "$TFCI" && set +o allexport + +_TFCI_HOST_ARTIFACTS_DIR="$TFCI_RUNTIME_ARTIFACTS_DIR" +tfrun() { "$@"; } +[[ "$TFCI_COPYBARA_ENABLE" = 1 ]] && source $TFCI_RUNTIME_USERTOOLS_DIR/copybara.sh +[[ "$TFCI_DOCKER_ENABLE" = 1 ]] && source $TFCI_RUNTIME_USERTOOLS_DIR/docker.sh +"$TFCI_RUNTIME_USERTOOLS_DIR/generate_index_html.sh" "$TFCI_RUNTIME_ARTIFACTS_DIR/index.html" + +tfrun bats "$TFCI_RUNTIME_USERTOOLS_DIR"/code_check_full.bats --timing --output "$TFCI_RUNTIME_ARTIFACTS_DIR" diff --git a/ci/official/envs/env.local_cpu b/ci/official/envs/env.local_cpu new file mode 100644 index 00000000000000..0686157de77cee --- /dev/null +++ b/ci/official/envs/env.local_cpu @@ -0,0 +1,26 @@ +TFCI_BAZEL_BAZELRC_ARGS=(--bazelrc /usertools/cpu.bazelrc) +TFCI_BAZEL_CACHE_ARGS=(--config sigbuild_remote_cache) +TFCI_BUILD_PIP_PACKAGE_ARGS=("--cpu") +TFCI_COPYBARA_ENABLE=0 +TFCI_COPYBARA_SCRIPT_PATH= +TFCI_COPYBARA_GIT_DIR= +TFCI_DOCKER_ARTIFACTS_HOST_DIR=/tmp/tf +TFCI_DOCKER_GIT_HOST_DIR=/usr/local/google/home/angerson/repos/tensorflow +TFCI_DOCKER_ENABLE=1 +TFCI_DOCKER_GPU_ARGS=() +TFCI_DOCKER_IMAGE=tensorflow/build:latest-python3.9 +TFCI_DOCKER_PULL_ENABLE= +TFCI_LIB_SUFFIX="-cpu-linux-x86_64" +TFCI_NIGHTLY_UPDATE_VERSION_ENABLE= +TFCI_NVIDIA_SMI_ENABLE= +TFCI_RUNTIME_ARTIFACTS_DIR=/tf/artifacts +TFCI_RUNTIME_USERTOOLS_DIR=/tf/tensorflow/ci/official/utilities +TFCI_RUNTIME_GIT_DIR=/tf/tensorflow +TFCI_UPLOAD_LIB_ENABLE= +TFCI_UPLOAD_LIB_URI= +TFCI_UPLOAD_LIB_LATEST_ENABLE= +TFCI_UPLOAD_LIB_LATEST_URI= +TFCI_UPLOAD_WHL_GCS_ENABLE= +TFCI_UPLOAD_WHL_GCS_URI= +TFCI_UPLOAD_WHL_PYPI_ARGS= +TFCI_UPLOAD_WHL_PYPI_ENABLE= diff --git a/ci/official/envs/env.nightly_cpu b/ci/official/envs/env.nightly_cpu new file mode 100644 index 00000000000000..f0515d2ceff3c5 --- /dev/null +++ b/ci/official/envs/env.nightly_cpu @@ -0,0 +1,33 @@ +TFCI_BAZEL_BAZELRC_ARGS=(--bazelrc /usertools/cpu.bazelrc) +TFCI_BAZEL_CACHE_ARGS=(--config sigbuild_remote_cache_push) +TFCI_BUILD_PIP_PACKAGE_ARGS=("--cpu" "--nightly_flag") +TFCI_COPYBARA_ENABLE=1 +TFCI_COPYBARA_SCRIPT_PATH="$KOKORO_ARTIFACTS_DIR/google3/learning/brain/testing/kokoro/rel/docker/run_copybara_for_presubmit.sh" +TFCI_COPYBARA_GIT_DIR="$KOKORO_ARTIFACTS_DIR/google3/learning/brain/testing/kokoro/rel/docker/run_copybara_for_presubmit.sh" +TFCI_DOCKER_ARTIFACTS_HOST_DIR="$KOKORO_ARTIFACTS_DIR" +TFCI_DOCKER_GIT_HOST_DIR="$KOKORO_ARTIFACTS_DIR/github/tensorflow" +TFCI_DOCKER_ENABLE=1 +TFCI_DOCKER_GPU_ARGS=() +TFCI_DOCKER_IMAGE= +TFCI_DOCKER_PULL_ENABLE=1 +TFCI_LIB_SUFFIX="-cpu-linux-x86_64" +TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 +TFCI_NVIDIA_SMI_ENABLE=1 +TFCI_RUNTIME_ARTIFACTS_DIR=/tf/artifacts +TFCI_RUNTIME_USERTOOLS_DIR=/usertools +TFCI_RUNTIME_GIT_DIR=/tf/tensorflow +TFCI_UPLOAD_LIB_ENABLE=1 +TFCI_UPLOAD_LIB_URI="gs://libtensorflow-nightly/$(date -I)" +#TFCI_UPLOAD_LIB_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" +TFCI_UPLOAD_LIB_LATEST_ENABLE=1 +TFCI_UPLOAD_LIB_LATEST_URI="gs://libtensorflow-nightly/latest" +TFCI_UPLOAD_WHL_GCS_ENABLE= +TFCI_UPLOAD_WHL_GCS_URI= +#TFCI_UPLOAD_WHL_GCS_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" +TFCI_UPLOAD_WHL_PYPI_ARGS=( --config-file "$KOKORO_KEYSTORE_DIR/73361_tensorflow_pypirc_using_global_api_token" --repository testpypi ) +TFCI_UPLOAD_WHL_PYPI_ENABLE=1 + + +TFCI_GIT_DIR=tensorflow + +tensorflow/ARTIFACTS diff --git a/ci/official/libtensorflow.sh b/ci/official/libtensorflow.sh new file mode 100755 index 00000000000000..6b539bf10c64d9 --- /dev/null +++ b/ci/official/libtensorflow.sh @@ -0,0 +1,34 @@ +#!/bin/bash +# -e: abort script if one command fails +# -u: error if undefined variable used +# -o pipefail: entire command fails if pipe fails. watch out for yes | ... +# -o history: record shell history +set -euxo pipefail -o history +set -o allexport && source "$TFCI" && set +o allexport + +# If this is a CL presubmit, then run Copybara on the Piper code and place it +# in the same directory as the GitHub source code would normally be. This lets +# the rest of the script proceed as normal. +_TFCI_HOST_ARTIFACTS_DIR="$TFCI_RUNTIME_ARTIFACTS_DIR" +tfrun() { "$@"; } +[[ "$TFCI_COPYBARA_ENABLE" = 1 ]] && source $TFCI_RUNTIME_USERTOOLS_DIR/copybara.sh +[[ "$TFCI_DOCKER_ENABLE" = 1 ]] && source $TFCI_RUNTIME_USERTOOLS_DIR/docker.sh +"$TFCI_RUNTIME_USERTOOLS_DIR/generate_index_html.sh" "$TFCI_RUNTIME_ARTIFACTS_DIR/index.html" + +# Record GPU count and CUDA version status +[[ "$TFCI_NVIDIA_SMI_ENABLE" = 1 ]] && tfrun nvidia-smi + +# Update the version numbers for Nightly only +[[ "$TFCI_NIGHTLY_UPDATE_VERSION_ENABLE" = 1 ]] && tfrun python3 tensorflow/tools/ci_build/update_version.py --nightly + +tfrun bazel "${TFCI_BAZEL_BAZELRC_ARGS[@]}" test "${TFCI_BAZEL_CACHE_ARGS[@]}" --config=libtensorflow_test +tfrun bazel "${TFCI_BAZEL_BAZELRC_ARGS[@]}" build "${TFCI_BAZEL_CACHE_ARGS[@]}" --config=libtensorflow_build + +tfrun "$TFCI_RUNTIME_USERTOOLS_DIR"/repack_libtensorflow.sh "$TFCI_RUNTIME_ARTIFACTS_DIR" "$TFCI_LIB_SUFFIX" + +if [[ "$TFCI_UPLOAD_LIB_ENABLE" = 1 ]]; then + gsutil cp "$_TFCI_HOST_ARTIFACTS_DIR"/*.tar.gz "$TFCI_UPLOAD_LIB_GCS_URI" + if [[ "$TFCI_UPLOAD_LIB_LATEST_ENABLE" = 1 ]]; then + gsutil cp "$_TFCI_HOST_ARTIFACTS_DIR"/*.tar.gz "$TFCI_UPLOAD_LIB_LATEST_GCS_URI" + fi +fi diff --git a/ci/official/pycpp.sh b/ci/official/pycpp.sh new file mode 100755 index 00000000000000..e9f61a886cae1c --- /dev/null +++ b/ci/official/pycpp.sh @@ -0,0 +1,22 @@ +#!/bin/bash +# -e: abort script if one command fails +# -u: error if undefined variable used +# -o pipefail: entire command fails if pipe fails. watch out for yes | ... +# -o history: record shell history +set -euxo pipefail -o history +set -o allexport && source "$TFCI" && set +o allexport + +# If this is a CL presubmit, then run Copybara on the Piper code and place it +# in the same directory as the GitHub source code would normally be. This lets +# the rest of the script proceed as normal. +_TFCI_HOST_ARTIFACTS_DIR="$TFCI_RUNTIME_ARTIFACTS_DIR" +tfrun() { "$@"; } +[[ "$TFCI_COPYBARA_ENABLE" = 1 ]] && source $TFCI_RUNTIME_USERTOOLS_DIR/copybara.sh +[[ "$TFCI_DOCKER_ENABLE" = 1 ]] && source $TFCI_RUNTIME_USERTOOLS_DIR/docker.sh +"$TFCI_RUNTIME_USERTOOLS_DIR/generate_index_html.sh" "$TFCI_RUNTIME_ARTIFACTS_DIR/index.html" + +# TODO(b/284172313) Revert this difference between presubmits and continuous. RBE serverside behavior is causing flakes, +# so we're temporarily allowing flaky tests again for presubmits. +tfrun bazel "${TFCI_BAZEL_BAZELRC_ARGS[@]}" test "${TFCI_BAZEL_CACHE_ARGS[@]}" --config=rbe --config=pycpp --config=build_event_export + +tfrun bazel analyze-profile $TFCI_RUNTIME_ART/profile.json.gz diff --git a/ci/official/utilities/code_check_changed_files.bats b/ci/official/utilities/code_check_changed_files.bats new file mode 100644 index 00000000000000..773b0847e943e1 --- /dev/null +++ b/ci/official/utilities/code_check_changed_files.bats @@ -0,0 +1,76 @@ +# vim: filetype=bash +# +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +setup_file() { + cd "$TFCI_RUNTIME_GIT_DIR" + bazel version # Start the bazel server + # Without this, git errors if /tf/tensorflow directory owner is different + git config --global --add safe.directory "$TF_RUNTIME_GIT_DIR" + # Note that you could generate a list of all the affected targets with e.g.: + # bazel query $(paste -sd "+" $BATS_FILE_TMPDIR/changed_files) --keep_going + # Only shows Added, Changed, Modified, Renamed, and Type-changed files + if [[ "$(git rev-parse --abbrev-ref HEAD)" = "pull_branch" ]]; then + # TF's CI runs 'git fetch origin "pull/PR#/merge:pull_branch"' + # To get the as-merged branch during the CI tests + git diff --diff-filter ACMRT --name-only pull_branch^ pull_branch > $BATS_FILE_TMPDIR/changed_files + else + # If the branch is not present, then diff against origin/master + git diff --diff-filter ACMRT --name-only origin/master > $BATS_FILE_TMPDIR/changed_files + fi +} + +# Note: this is excluded on the full code base, since any submitted code must +# have passed Google's internal style guidelines. +@test "Check buildifier formatting on BUILD files" { + echo "buildifier formatting is recommended. Here are the suggested fixes:" + echo "=============================" + grep -e 'BUILD' $BATS_FILE_TMPDIR/changed_files \ + | xargs buildifier -v -mode=diff -diff_command="git diff --no-index" +} + +# Note: this is excluded on the full code base, since any submitted code must +# have passed Google's internal style guidelines. +@test "Check formatting for C++ files" { + skip "clang-format doesn't match internal clang-format checker" + echo "clang-format is recommended. Here are the suggested changes:" + echo "=============================" + grep -e '\.h$' -e '\.cc$' $BATS_FILE_TMPDIR/changed_files > $BATS_TEST_TMPDIR/files || true + if [[ ! -s $BATS_TEST_TMPDIR/files ]]; then return 0; fi + xargs -a $BATS_TEST_TMPDIR/files -i -n1 -P $(nproc --all) \ + bash -c 'clang-format-12 --style=Google {} | git diff --no-index {} -' \ + | tee $BATS_TEST_TMPDIR/needs_help.txt + echo "You can use clang-format --style=Google -i to apply changes to a file." + [[ ! -s $BATS_TEST_TMPDIR/needs_help.txt ]] +} + +# Note: this is excluded on the full code base, since any submitted code must +# have passed Google's internal style guidelines. +@test "Check pylint for Python files" { + echo "Python formatting is recommended. Here are the pylint errors:" + echo "=============================" + grep -e "\.py$" $BATS_FILE_TMPDIR/changed_files > $BATS_TEST_TMPDIR/files || true + if [[ ! -s $BATS_TEST_TMPDIR/files ]]; then return 0; fi + xargs -a $BATS_TEST_TMPDIR/files -n1 -P $(nproc --all) \ + python -m pylint --rcfile=tensorflow/tools/ci_build/pylintrc --score false \ + | grep -v "**** Module" \ + | tee $BATS_TEST_TMPDIR/needs_help.txt + [[ ! -s $BATS_TEST_TMPDIR/needs_help.txt ]] +} + +teardown_file() { + bazel shutdown +} diff --git a/ci/official/utilities/code_check_full.bats b/ci/official/utilities/code_check_full.bats new file mode 100644 index 00000000000000..b474b9fcf9e371 --- /dev/null +++ b/ci/official/utilities/code_check_full.bats @@ -0,0 +1,307 @@ +# vim: filetype=bash +# +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +setup_file() { + cd $TFCI_RUNTIME_GIT_DIR + bazel version # Start the bazel server +} + +# Do a bazel query specifically for the licenses checker. It searches for +# targets matching the provided query, which start with // or @ but not +# //tensorflow (so it looks for //third_party, //external, etc.), and then +# gathers the list of all packages (i.e. directories) which contain those +# targets. +license_query() { + bazel cquery --experimental_cc_shared_library "$1" --keep_going \ + | grep -e "^//" -e "^@" \ + | grep -E -v "^//tensorflow" \ + | sed -e 's|:.*||' \ + | sort -u +} + +# Verify that, given a build target and a license-list generator target, all of +# the dependencies of that target which include a license notice file are then +# included when generating that license. Necessary because the license targets +# in TensorFlow are manually enumerated rather than generated automatically. +do_external_licenses_check(){ + BUILD_TARGET="$1" + LICENSES_TARGET="$2" + + # grep patterns for targets which are allowed to be missing from the licenses + cat > $BATS_TEST_TMPDIR/allowed_to_be_missing < $BATS_TEST_TMPDIR/allowed_to_be_extra < $BATS_TEST_TMPDIR/expected_licenses + license_query "deps($LICENSES_TARGET)" > $BATS_TEST_TMPDIR/actual_licenses + + # Column 1 is left only, Column 2 is right only, Column 3 is shared lines + # Select lines unique to actual_licenses, i.e. extra licenses. + comm -1 -3 $BATS_TEST_TMPDIR/expected_licenses $BATS_TEST_TMPDIR/actual_licenses | grep -v -f $BATS_TEST_TMPDIR/allowed_to_be_extra > $BATS_TEST_TMPDIR/actual_extra_licenses || true + # Select lines unique to expected_licenses, i.e. missing licenses + comm -2 -3 $BATS_TEST_TMPDIR/expected_licenses $BATS_TEST_TMPDIR/actual_licenses | grep -v -f $BATS_TEST_TMPDIR/allowed_to_be_missing > $BATS_TEST_TMPDIR/actual_missing_licenses || true + + if [[ -s $BATS_TEST_TMPDIR/actual_extra_licenses ]]; then + echo "Please remove the following extra licenses from $LICENSES_TARGET:" + cat $BATS_TEST_TMPDIR/actual_extra_licenses + fi + + if [[ -s $BATS_TEST_TMPDIR/actual_missing_licenses ]]; then + echo "Please include the missing licenses for the following packages in $LICENSES_TARGET:" + cat $BATS_TEST_TMPDIR/actual_missing_licenses + fi + + # Fail if either of the two "extras" or "missing" lists are present. If so, + # then the user will see the above error messages. + [[ ! -s $BATS_TEST_TMPDIR/actual_extra_licenses ]] && [[ ! -s $BATS_TEST_TMPDIR/actual_missing_licenses ]] +} + +@test "Pip package generated license includes all dependencies' licenses" { + do_external_licenses_check \ + "//tensorflow/tools/pip_package:build_pip_package" \ + "//tensorflow/tools/pip_package:licenses" +} + +@test "Libtensorflow generated license includes all dependencies' licenses" { + do_external_licenses_check \ + "//tensorflow:libtensorflow.so" \ + "//tensorflow/tools/lib_package:clicenses_generate" +} + +@test "Java library generated license includes all dependencies' licenses" { + do_external_licenses_check \ + "//tensorflow/java:libtensorflow_jni.so" \ + "//tensorflow/tools/lib_package:jnilicenses_generate" +} + +# This test ensures that all the targets built into the Python package include +# their dependencies. It's a rewritten version of the "smoke test", an older +# Python script that was very difficult to understand. See +# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/pip_package/pip_smoke_test.py +@test "Pip package includes all required //tensorflow dependencies" { + # grep patterns for packages whose dependencies can be ignored + cat > $BATS_TEST_TMPDIR/ignore_deps_for_these_packages < $BATS_TEST_TMPDIR/ignore_these_deps < $BATS_TEST_TMPDIR/pip_deps + # Find all Python py_test targets not tagged "no_pip" or "manual", excluding + # any targets in ignored packages. Combine this list of targets into a bazel + # query list (e.g. the list becomes "target+target2+target3") + bazel query --keep_going 'kind(py_test, //tensorflow/python/...) - attr("tags", "no_pip|manual", //tensorflow/python/...)' | grep -v -f $BATS_TEST_TMPDIR/ignore_deps_for_these_packages | paste -sd "+" - > $BATS_TEST_TMPDIR/deps + # Find all one-step dependencies of those tests which are from //tensorflow + # (since external deps will come from Python-level pip dependencies), + # excluding dependencies and files that are known to be unneccessary. + # This creates a list of targets under //tensorflow that are required for + # TensorFlow python tests. + bazel query --keep_going "deps($(cat $BATS_TEST_TMPDIR/deps), 1)" | grep "^//tensorflow" | grep -v -f $BATS_TEST_TMPDIR/ignore_these_deps | sort -u > $BATS_TEST_TMPDIR/required_deps + + + # Find if any required dependencies are missing from the list of dependencies + # included in the pip package. + # (comm: Column 1 is left, Column 2 is right, Column 3 is shared lines) + comm -2 -3 $BATS_TEST_TMPDIR/required_deps $BATS_TEST_TMPDIR/pip_deps > $BATS_TEST_TMPDIR/missing_deps || true + + if [[ -s $BATS_TEST_TMPDIR/missing_deps ]]; then + cat < $BATS_TEST_TMPDIR/out + + cat < $BATS_TEST_TMPDIR/out + + cat <> errors.txt + fi + if [[ -e errors.txt ]]; then + echo "Broken links found:" + cat errors.txt + rm errors.txt + false + fi + done +} + +@test "No duplicate files on Windows" { + cat < + +$(basename "$KOKORO_JOB_NAME") + + +

TensorFlow Job Logs and Links

+

Job Details

+
    +
  • Job name: $KOKORO_JOB_NAME
  • +
  • Job pool: $KOKORO_JOB_POOL
  • +
  • Job ID: $KOKORO_BUILD_ID
  • +
  • Current HEAD Piper Changelist, if any: cl/${KOKORO_PIPER_CHANGELIST:-not available}
  • +
  • Pull Request Number, if any: ${KOKORO_GITHUB_PULL_REQUEST_NUMBER_tensorflow:- none}
  • +
  • Pull Request Link, if any: ${KOKORO_GITHUB_PULL_REQUEST_URL_tensorflow:-none}
  • +
  • Commit: $KOKORO_GIT_COMMIT_tensorflow
  • +
+

Googlers-Only Links

+ +

Non-Googler Links

+ + +EOF diff --git a/ci/official/utilities/rename_and_verify_wheels.sh b/ci/official/utilities/rename_and_verify_wheels.sh new file mode 100755 index 00000000000000..94c2870c13e7f2 --- /dev/null +++ b/ci/official/utilities/rename_and_verify_wheels.sh @@ -0,0 +1,35 @@ +#!/usr/bin/env bash +# +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# Check and rename wheels with auditwheel. Inserts the platform tags like +# "manylinux_xyz" into the wheel filename. +set -euxo pipefail + +for wheel in "$TFCI_RUNTIME_ARTIFACTS_DIR"/*.whl; do + echo "Checking and renaming $wheel..." + time python3 -m auditwheel repair --plat manylinux2014_x86_64 "$wheel" --wheel-dir "$TFCI_RUNTIME_ARTIFACTS_DIR" 2>&1 | tee check.txt + + # We don't need the original wheel if it was renamed + new_wheel=$(grep --extended-regexp --only-matching '/tf/pkg/\S+.whl' check.txt) + if [[ "$new_wheel" != "$wheel" ]]; then + rm "$wheel" + wheel="$new_wheel" + fi + rm check.txt + + TF_WHEEL="$wheel" bats "$TFCI_RUNTIME_USERTOOLS_DIR/wheel_verification.bats" --timing +done diff --git a/ci/official/utilities/repack_libtensorflow.sh b/ci/official/utilities/repack_libtensorflow.sh new file mode 100755 index 00000000000000..7d75935d2aa9c7 --- /dev/null +++ b/ci/official/utilities/repack_libtensorflow.sh @@ -0,0 +1,71 @@ +#!/bin/bash +# +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ============================================================================== +# +# Repacks libtensorflow tarballs into $DIR with provided $TARBALL_SUFFIX, +# and also repacks libtensorflow-src.jar into a standardized format. +# +# -e: abort script if one command fails +# -u: error if undefined variable used +# -o pipefail: entire command fails if pipe fails. watch out for yes | ... +# -o history: record shell history +set -euxo pipefail -o history +set -o allexport && source "$TFCI" && set +o allexport + +# Helper function to copy a srcjar after moving any source files +# directly under the root to the "maven-style" src/main/java layout +# +# Source files generated by annotation processors appear directly +# under the root of srcjars jars created by bazel, rather than under +# the maven-style src/main/java subdirectory. +# +# Bazel manages annotation generated source as follows: First, it +# calls javac with options that create generated files under a +# bazel-out directory. Next, it archives the generated source files +# into a srcjar directly under the root. There doesn't appear to be a +# simple way to parameterize this from bazel, hence this helper to +# "normalize" the srcjar layout. +# +# Arguments: +# src_jar - path to the original srcjar +# dest_jar - path to the destination +# Returns: +# None +function cp_normalized_srcjar() { + src_jar="$1" + dest_jar="$2" + tmp_dir=$(mktemp -d) + cp "${src_jar}" "${tmp_dir}/orig.jar" + pushd "${tmp_dir}" + # Extract any src/ files + jar -xf "${tmp_dir}/orig.jar" src/ + # Extract any org/ files under src/main/java + (mkdir -p src/main/java && cd src/main/java && jar -xf "${tmp_dir}/orig.jar" org/) + # Repackage src/ + jar -cMf "${tmp_dir}/new.jar" src + popd + cp "${tmp_dir}/new.jar" "${dest_jar}" + rm -rf "${tmp_dir}" +} +DIR=$1 +TARBALL_SUFFIX=$2 +mkdir -p "$DIR" +cp bazel-bin/tensorflow/tools/lib_package/libtensorflow.tar.gz "${DIR}/libtensorflow${TARBALL_SUFFIX}.tar.gz" +cp bazel-bin/tensorflow/tools/lib_package/libtensorflow_jni.tar.gz "${DIR}/libtensorflow_jni${TARBALL_SUFFIX}.tar.gz" +cp bazel-bin/tensorflow/java/libtensorflow.jar "${DIR}" +cp_normalized_srcjar bazel-bin/tensorflow/java/libtensorflow-src.jar "${DIR}/libtensorflow-src.jar" +cp bazel-bin/tensorflow/tools/lib_package/libtensorflow_proto.zip "${DIR}" diff --git a/ci/official/utilities/wheel_verification.bats b/ci/official/utilities/wheel_verification.bats new file mode 100644 index 00000000000000..1786163a737cb8 --- /dev/null +++ b/ci/official/utilities/wheel_verification.bats @@ -0,0 +1,78 @@ +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Suite of verification tests for the SINGLE TensorFlow wheel in /tf/pkg +# or whatever path is set as $TF_WHEEL. + +setup_file() { + cd "$TFCI_RUNTIME_ARTIFACTS_DIR" + if [[ -z "$TF_WHEEL" ]]; then + export TF_WHEEL=$(find "$TFCI_RUNTIME_ARTIFACTS_DIR" -iname "*.whl") + fi +} + +teardown_file() { + rm -rf /tmp/venv +} + +@test "Wheel is manylinux2014 (manylinux_2_17) compliant" { + python3 -m auditwheel show "$TF_WHEEL" > audit.txt + grep --quiet 'This constrains the platform tag to "manylinux_2_17_x86_64"' audit.txt +} + +@test "Wheel conforms to upstream size limitations" { + WHEEL_MEGABYTES=$(stat --format %s "$TF_WHEEL" | awk '{print int($1/(1024*1024))}') + # Googlers: search for "test_tf_whl_size" + case "$TF_WHEEL" in + # CPU: + *cpu*manylinux*) LARGEST_OK_SIZE=240 ;; + # GPU: + *manylinux*) LARGEST_OK_SIZE=580 ;; + # Unknown: + *) + echo "The wheel's name is in an unknown format." + exit 1 + ;; + esac + # >&3 forces output in bats even if the test passes. See + # https://bats-core.readthedocs.io/en/stable/writing-tests.html#printing-to-the-terminal + echo "# Size of $TF_WHEEL is $WHEEL_MEGABYTES / $LARGEST_OK_SIZE megabytes." >&3 + test "$WHEEL_MEGABYTES" -le "$LARGEST_OK_SIZE" +} + +# Note: this runs before the tests further down the file, so TF is installed in +# the venv and the venv is active when those tests run. The venv gets cleaned +# up in teardown_file() above. +@test "Wheel is installable" { + python3 -m venv "$BATS_FILE_TMPDIR/venv" + source "$BATS_FILE_TMPDIR/bin/activate" + python3 -m pip install "$TF_WHEEL" +} + +@test "TensorFlow is importable" { + source "$BATS_FILE_TMPDIR/bin/activate" + python3 -c 'import tensorflow as tf; t1=tf.constant([1,2,3,4]); t2=tf.constant([5,6,7,8]); print(tf.add(t1,t2).shape)' +} + +# Is this still useful? +@test "TensorFlow has Keras" { + source "$BATS_FILE_TMPDIR/bin/activate" + python3 -c 'import sys; import tensorflow as tf; sys.exit(0 if "_v2.keras" in tf.keras.__name__ else 1)' +} + +# Is this still useful? +@test "TensorFlow has Estimator" { + source "$BATS_FILE_TMPDIR/bin/activate" + python3 -c 'import sys; import tensorflow as tf; sys.exit(0 if "_v2.estimator" in tf.estimator.__name__ else 1)' +} diff --git a/ci/official/wheel.sh b/ci/official/wheel.sh new file mode 100755 index 00000000000000..d72fc174bda80a --- /dev/null +++ b/ci/official/wheel.sh @@ -0,0 +1,33 @@ +#!/bin/bash +# -e: abort script if one command fails +# -u: error if undefined variable used +# -o pipefail: entire command fails if pipe fails. watch out for yes | ... +# -o history: record shell history +set -euxo pipefail -o history +set -o allexport && source "$TFCI" && set +o allexport + +# If this is a CL presubmit, then run Copybara on the Piper code and place it +# in the same directory as the GitHub source code would normally be. This lets +# the rest of the script proceed as normal. +_TFCI_HOST_ARTIFACTS_DIR="$TFCI_RUNTIME_ARTIFACTS_DIR" +tfrun() { "$@"; } +[[ "$TFCI_COPYBARA_ENABLE" = 1 ]] && source $TFCI_RUNTIME_USERTOOLS_DIR/copybara.sh +[[ "$TFCI_DOCKER_ENABLE" = 1 ]] && source $TFCI_RUNTIME_USERTOOLS_DIR/docker.sh +"$TFCI_RUNTIME_USERTOOLS_DIR/generate_index_html.sh" "$TFCI_RUNTIME_ARTIFACTS_DIR/index.html" + +# Record GPU count and CUDA version status +[[ "$TFCI_NVIDIA_SMI_ENABLE" = 1 ]] && tfrun nvidia-smi + +# Update the version numbers for Nightly only +[[ "$TFCI_NIGHTLY_UPDATE_VERSION_ENABLE" = 1 ]] && tfrun python3 tensorflow/tools/ci_build/update_version.py --nightly + +tfrun bazel "${TFCI_BAZEL_BAZELRC_ARGS[@]}" build "${TFCI_BAZEL_CACHE_ARGS[@]}" tensorflow/tools/pip_package:build_pip_package +tfrun ./bazel-bin/tensorflow/tools/pip_package/build_pip_package "$TFCI_RUNTIME_ARTIFACTS_DIR" "${TFCI_BUILD_PIP_PACKAGE_ARGS[@]}" +tfrun "$TFCI_RUNTIME_USERTOOLS_DIR/rename_and_verify_wheels.sh" + +if [[ "$TFCI_UPLOAD_ENABLE" = 1 ]]; then + twine upload "${TFCI_UPLOAD_PYPI_ARGS[@]}" "$_TFCI_HOST_ARTIFACTS_DIR"/*.whl + gsutil cp "$_TFCI_HOST_ARTIFACTS_DIR"/*.whl "$TFCI_UPLOAD_GCS_DESTINATION" +fi + +tfrun bazel "${TFCI_BAZEL_BAZELRC_ARGS[@]}" test "${TFCI_BAZEL_CACHE_ARGS[@]}" --config=nonpip From adbcd77908a77dbf8a016ba5238ce7f4dc3f8dd3 Mon Sep 17 00:00:00 2001 From: Austin Anderson Date: Wed, 28 Jun 2023 23:12:57 -0700 Subject: [PATCH 029/376] Update cleaner scripts --- ARTIFACTS/index.html | 32 +++++++++++++++++++ ci/official/any.sh | 12 +++---- ci/official/code_check_changed_files.sh | 10 +++--- ci/official/code_check_full.sh | 10 +++--- ci/official/envs/{env.local_cpu => local_cpu} | 12 ++----- .../envs/{env.nightly_cpu => nightly_cpu} | 19 +++-------- ci/official/libtensorflow.sh | 17 ++++------ ci/official/pycpp.sh | 13 +++----- .../utilities/code_check_changed_files.bats | 4 +-- ci/official/utilities/code_check_full.bats | 2 +- ci/official/utilities/docker.sh | 7 ++-- .../utilities/rename_and_verify_wheels.sh | 7 ++-- ci/official/utilities/repack_libtensorflow.sh | 1 + ci/official/utilities/wheel_verification.bats | 7 ++-- ci/official/wheel.sh | 21 ++++++------ 15 files changed, 90 insertions(+), 84 deletions(-) create mode 100644 ARTIFACTS/index.html rename ci/official/envs/{env.local_cpu => local_cpu} (60%) rename ci/official/envs/{env.nightly_cpu => nightly_cpu} (60%) diff --git a/ARTIFACTS/index.html b/ARTIFACTS/index.html new file mode 100644 index 00000000000000..f8aa701e2a96f6 --- /dev/null +++ b/ARTIFACTS/index.html @@ -0,0 +1,32 @@ + + + + + +

TensorFlow Job Logs and Links

+

Job Details

+
    +
  • Job name:
  • +
  • Job pool:
  • +
  • Job ID:
  • +
  • Current HEAD Piper Changelist, if any: cl/not available
  • +
  • Pull Request Number, if any: none
  • +
  • Pull Request Link, if any: none
  • +
  • Commit:
  • +
+

Googlers-Only Links

+ +

Non-Googler Links

+ + diff --git a/ci/official/any.sh b/ci/official/any.sh index 43bba1a2bec474..f5a6278c99ad2f 100755 --- a/ci/official/any.sh +++ b/ci/official/any.sh @@ -6,11 +6,11 @@ set -euxo pipefail -o history set -o allexport && source "$TFCI" && set +o allexport -_TFCI_HOST_ARTIFACTS_DIR="$TFCI_RUNTIME_ARTIFACTS_DIR" +cd "$TFCI_GIT_DIR" && mkdir -p build tfrun() { "$@"; } -[[ "$TFCI_COPYBARA_ENABLE" = 1 ]] && source $TFCI_RUNTIME_USERTOOLS_DIR/copybara.sh -[[ "$TFCI_DOCKER_ENABLE" = 1 ]] && source $TFCI_RUNTIME_USERTOOLS_DIR/docker.sh -"$TFCI_RUNTIME_USERTOOLS_DIR/generate_index_html.sh" "$TFCI_RUNTIME_ARTIFACTS_DIR/index.html" +[[ "$TFCI_COPYBARA_ENABLE" = 1 ]] && source ./ci/official/utilities/copybara.sh +[[ "$TFCI_DOCKER_ENABLE" = 1 ]] && source ./ci/official/utilities/docker.sh +./ci/official/utilities/generate_index_html.sh build/index.html # Parse options and build targets into arrays, so that shelllint doesn't yell # about readability. We can't pipe into 'read -ra' to create an array because @@ -36,8 +36,8 @@ if [[ "${PIP_WHEEL}" -eq "1" ]]; then [[ "$TFCI_NIGHTLY_UPDATE_VERSION_ENABLE" = 1 ]] && tfrun python3 tensorflow/tools/ci_build/update_version.py --nightly tfrun bazel "${TFCI_BAZEL_BAZELRC_ARGS[@]}" build "${TFCI_BAZEL_CACHE_ARGS[@]}" tensorflow/tools/pip_package:build_pip_package - tfrun ./bazel-bin/tensorflow/tools/pip_package/build_pip_package "$TFCI_RUNTIME_ARTIFACTS_DIR" "${TFCI_BUILD_PIP_PACKAGE_ARGS[@]}" - tfrun "$TFCI_RUNTIME_USERTOOLS_DIR/rename_and_verify_wheels.sh" + tfrun ./bazel-bin/tensorflow/tools/pip_package/build_pip_package build "${TFCI_BUILD_PIP_PACKAGE_ARGS[@]}" + tfrun ./ci/official/utilities/rename_and_verify_wheels.sh fi if [[ -s nonpip_targets.txt ]]; then diff --git a/ci/official/code_check_changed_files.sh b/ci/official/code_check_changed_files.sh index 49924bc2344eb5..48e8b4920c0766 100755 --- a/ci/official/code_check_changed_files.sh +++ b/ci/official/code_check_changed_files.sh @@ -6,10 +6,10 @@ set -euxo pipefail -o history set -o allexport && source "$TFCI" && set +o allexport -_TFCI_HOST_ARTIFACTS_DIR="$TFCI_RUNTIME_ARTIFACTS_DIR" +cd "$TFCI_GIT_DIR" && mkdir -p build tfrun() { "$@"; } -[[ "$TFCI_COPYBARA_ENABLE" = 1 ]] && source $TFCI_RUNTIME_USERTOOLS_DIR/copybara.sh -[[ "$TFCI_DOCKER_ENABLE" = 1 ]] && source $TFCI_RUNTIME_USERTOOLS_DIR/docker.sh -"$TFCI_RUNTIME_USERTOOLS_DIR/generate_index_html.sh" "$TFCI_RUNTIME_ARTIFACTS_DIR/index.html" +[[ "$TFCI_COPYBARA_ENABLE" = 1 ]] && source ./ci/official/utilities/copybara.sh +[[ "$TFCI_DOCKER_ENABLE" = 1 ]] && source ./ci/official/utilities/docker.sh +./ci/official/utilities/generate_index_html.sh build/index.html -tfrun bats "$TFCI_RUNTIME_USERTOOLS_DIR"/code_check_changed_files.bats --timing --output "$TFCI_RUNTIME_ARTIFACTS_DIR" +tfrun bats ./ci/official/utilities/code_check_changed_files.bats --timing --output build diff --git a/ci/official/code_check_full.sh b/ci/official/code_check_full.sh index 224b73375285d8..d2f7ef4b4ecdf1 100755 --- a/ci/official/code_check_full.sh +++ b/ci/official/code_check_full.sh @@ -6,10 +6,10 @@ set -euxo pipefail -o history set -o allexport && source "$TFCI" && set +o allexport -_TFCI_HOST_ARTIFACTS_DIR="$TFCI_RUNTIME_ARTIFACTS_DIR" +cd "$TFCI_GIT_DIR" && mkdir -p build tfrun() { "$@"; } -[[ "$TFCI_COPYBARA_ENABLE" = 1 ]] && source $TFCI_RUNTIME_USERTOOLS_DIR/copybara.sh -[[ "$TFCI_DOCKER_ENABLE" = 1 ]] && source $TFCI_RUNTIME_USERTOOLS_DIR/docker.sh -"$TFCI_RUNTIME_USERTOOLS_DIR/generate_index_html.sh" "$TFCI_RUNTIME_ARTIFACTS_DIR/index.html" +[[ "$TFCI_COPYBARA_ENABLE" = 1 ]] && source ./ci/official/utilities/copybara.sh +[[ "$TFCI_DOCKER_ENABLE" = 1 ]] && source ./ci/official/utilities/docker.sh +./ci/official/utilities/generate_index_html.sh build/index.html -tfrun bats "$TFCI_RUNTIME_USERTOOLS_DIR"/code_check_full.bats --timing --output "$TFCI_RUNTIME_ARTIFACTS_DIR" +tfrun bats ./ci/official/utilities/code_check_full.bats --timing --output build diff --git a/ci/official/envs/env.local_cpu b/ci/official/envs/local_cpu similarity index 60% rename from ci/official/envs/env.local_cpu rename to ci/official/envs/local_cpu index 0686157de77cee..cf3d1137e77fb8 100644 --- a/ci/official/envs/env.local_cpu +++ b/ci/official/envs/local_cpu @@ -1,25 +1,19 @@ -TFCI_BAZEL_BAZELRC_ARGS=(--bazelrc /usertools/cpu.bazelrc) +TFCI_BAZEL_BAZELRC_ARGS=(--bazelrc ./ci/official/utilities/bazelrcs/cpu.bazelrc) TFCI_BAZEL_CACHE_ARGS=(--config sigbuild_remote_cache) TFCI_BUILD_PIP_PACKAGE_ARGS=("--cpu") TFCI_COPYBARA_ENABLE=0 -TFCI_COPYBARA_SCRIPT_PATH= -TFCI_COPYBARA_GIT_DIR= -TFCI_DOCKER_ARTIFACTS_HOST_DIR=/tmp/tf -TFCI_DOCKER_GIT_HOST_DIR=/usr/local/google/home/angerson/repos/tensorflow TFCI_DOCKER_ENABLE=1 TFCI_DOCKER_GPU_ARGS=() TFCI_DOCKER_IMAGE=tensorflow/build:latest-python3.9 TFCI_DOCKER_PULL_ENABLE= +TFCI_GIT_DIR=/usr/local/google/home/angerson/repos/tensorflow TFCI_LIB_SUFFIX="-cpu-linux-x86_64" TFCI_NIGHTLY_UPDATE_VERSION_ENABLE= TFCI_NVIDIA_SMI_ENABLE= -TFCI_RUNTIME_ARTIFACTS_DIR=/tf/artifacts -TFCI_RUNTIME_USERTOOLS_DIR=/tf/tensorflow/ci/official/utilities -TFCI_RUNTIME_GIT_DIR=/tf/tensorflow TFCI_UPLOAD_LIB_ENABLE= -TFCI_UPLOAD_LIB_URI= TFCI_UPLOAD_LIB_LATEST_ENABLE= TFCI_UPLOAD_LIB_LATEST_URI= +TFCI_UPLOAD_LIB_URI= TFCI_UPLOAD_WHL_GCS_ENABLE= TFCI_UPLOAD_WHL_GCS_URI= TFCI_UPLOAD_WHL_PYPI_ARGS= diff --git a/ci/official/envs/env.nightly_cpu b/ci/official/envs/nightly_cpu similarity index 60% rename from ci/official/envs/env.nightly_cpu rename to ci/official/envs/nightly_cpu index f0515d2ceff3c5..02ec5f0a33621f 100644 --- a/ci/official/envs/env.nightly_cpu +++ b/ci/official/envs/nightly_cpu @@ -1,33 +1,22 @@ -TFCI_BAZEL_BAZELRC_ARGS=(--bazelrc /usertools/cpu.bazelrc) +TFCI_BAZEL_BAZELRC_ARGS=(--bazelrc ./ci/official/utilities/bazelrcs/cpu.bazelrc) TFCI_BAZEL_CACHE_ARGS=(--config sigbuild_remote_cache_push) TFCI_BUILD_PIP_PACKAGE_ARGS=("--cpu" "--nightly_flag") TFCI_COPYBARA_ENABLE=1 -TFCI_COPYBARA_SCRIPT_PATH="$KOKORO_ARTIFACTS_DIR/google3/learning/brain/testing/kokoro/rel/docker/run_copybara_for_presubmit.sh" -TFCI_COPYBARA_GIT_DIR="$KOKORO_ARTIFACTS_DIR/google3/learning/brain/testing/kokoro/rel/docker/run_copybara_for_presubmit.sh" -TFCI_DOCKER_ARTIFACTS_HOST_DIR="$KOKORO_ARTIFACTS_DIR" -TFCI_DOCKER_GIT_HOST_DIR="$KOKORO_ARTIFACTS_DIR/github/tensorflow" TFCI_DOCKER_ENABLE=1 TFCI_DOCKER_GPU_ARGS=() TFCI_DOCKER_IMAGE= TFCI_DOCKER_PULL_ENABLE=1 +TFCI_GIT_DIR=/tf/tensorflow TFCI_LIB_SUFFIX="-cpu-linux-x86_64" TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 TFCI_NVIDIA_SMI_ENABLE=1 -TFCI_RUNTIME_ARTIFACTS_DIR=/tf/artifacts -TFCI_RUNTIME_USERTOOLS_DIR=/usertools -TFCI_RUNTIME_GIT_DIR=/tf/tensorflow TFCI_UPLOAD_LIB_ENABLE=1 -TFCI_UPLOAD_LIB_URI="gs://libtensorflow-nightly/$(date -I)" -#TFCI_UPLOAD_LIB_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" TFCI_UPLOAD_LIB_LATEST_ENABLE=1 TFCI_UPLOAD_LIB_LATEST_URI="gs://libtensorflow-nightly/latest" +TFCI_UPLOAD_LIB_URI="gs://libtensorflow-nightly/$(date -I)" +#TFCI_UPLOAD_LIB_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" TFCI_UPLOAD_WHL_GCS_ENABLE= TFCI_UPLOAD_WHL_GCS_URI= #TFCI_UPLOAD_WHL_GCS_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" TFCI_UPLOAD_WHL_PYPI_ARGS=( --config-file "$KOKORO_KEYSTORE_DIR/73361_tensorflow_pypirc_using_global_api_token" --repository testpypi ) TFCI_UPLOAD_WHL_PYPI_ENABLE=1 - - -TFCI_GIT_DIR=tensorflow - -tensorflow/ARTIFACTS diff --git a/ci/official/libtensorflow.sh b/ci/official/libtensorflow.sh index 6b539bf10c64d9..bb5472d8233998 100755 --- a/ci/official/libtensorflow.sh +++ b/ci/official/libtensorflow.sh @@ -6,14 +6,11 @@ set -euxo pipefail -o history set -o allexport && source "$TFCI" && set +o allexport -# If this is a CL presubmit, then run Copybara on the Piper code and place it -# in the same directory as the GitHub source code would normally be. This lets -# the rest of the script proceed as normal. -_TFCI_HOST_ARTIFACTS_DIR="$TFCI_RUNTIME_ARTIFACTS_DIR" +cd "$TFCI_GIT_DIR" && mkdir -p build tfrun() { "$@"; } -[[ "$TFCI_COPYBARA_ENABLE" = 1 ]] && source $TFCI_RUNTIME_USERTOOLS_DIR/copybara.sh -[[ "$TFCI_DOCKER_ENABLE" = 1 ]] && source $TFCI_RUNTIME_USERTOOLS_DIR/docker.sh -"$TFCI_RUNTIME_USERTOOLS_DIR/generate_index_html.sh" "$TFCI_RUNTIME_ARTIFACTS_DIR/index.html" +[[ "$TFCI_COPYBARA_ENABLE" = 1 ]] && source ./ci/official/utilities/copybara.sh +[[ "$TFCI_DOCKER_ENABLE" = 1 ]] && source ./ci/official/utilities/docker.sh +./ci/official/utilities/generate_index_html.sh build/index.html # Record GPU count and CUDA version status [[ "$TFCI_NVIDIA_SMI_ENABLE" = 1 ]] && tfrun nvidia-smi @@ -24,11 +21,11 @@ tfrun() { "$@"; } tfrun bazel "${TFCI_BAZEL_BAZELRC_ARGS[@]}" test "${TFCI_BAZEL_CACHE_ARGS[@]}" --config=libtensorflow_test tfrun bazel "${TFCI_BAZEL_BAZELRC_ARGS[@]}" build "${TFCI_BAZEL_CACHE_ARGS[@]}" --config=libtensorflow_build -tfrun "$TFCI_RUNTIME_USERTOOLS_DIR"/repack_libtensorflow.sh "$TFCI_RUNTIME_ARTIFACTS_DIR" "$TFCI_LIB_SUFFIX" +tfrun ./ci/official/utilities/repack_libtensorflow.sh build "$TFCI_LIB_SUFFIX" if [[ "$TFCI_UPLOAD_LIB_ENABLE" = 1 ]]; then - gsutil cp "$_TFCI_HOST_ARTIFACTS_DIR"/*.tar.gz "$TFCI_UPLOAD_LIB_GCS_URI" + gsutil cp build/*.tar.gz "$TFCI_UPLOAD_LIB_GCS_URI" if [[ "$TFCI_UPLOAD_LIB_LATEST_ENABLE" = 1 ]]; then - gsutil cp "$_TFCI_HOST_ARTIFACTS_DIR"/*.tar.gz "$TFCI_UPLOAD_LIB_LATEST_GCS_URI" + gsutil cp build/*.tar.gz "$TFCI_UPLOAD_LIB_LATEST_GCS_URI" fi fi diff --git a/ci/official/pycpp.sh b/ci/official/pycpp.sh index e9f61a886cae1c..59d3e0ddc3b74e 100755 --- a/ci/official/pycpp.sh +++ b/ci/official/pycpp.sh @@ -6,17 +6,14 @@ set -euxo pipefail -o history set -o allexport && source "$TFCI" && set +o allexport -# If this is a CL presubmit, then run Copybara on the Piper code and place it -# in the same directory as the GitHub source code would normally be. This lets -# the rest of the script proceed as normal. -_TFCI_HOST_ARTIFACTS_DIR="$TFCI_RUNTIME_ARTIFACTS_DIR" +cd "$TFCI_GIT_DIR" && mkdir -p build tfrun() { "$@"; } -[[ "$TFCI_COPYBARA_ENABLE" = 1 ]] && source $TFCI_RUNTIME_USERTOOLS_DIR/copybara.sh -[[ "$TFCI_DOCKER_ENABLE" = 1 ]] && source $TFCI_RUNTIME_USERTOOLS_DIR/docker.sh -"$TFCI_RUNTIME_USERTOOLS_DIR/generate_index_html.sh" "$TFCI_RUNTIME_ARTIFACTS_DIR/index.html" +[[ "$TFCI_COPYBARA_ENABLE" = 1 ]] && source ./ci/official/utilities/copybara.sh +[[ "$TFCI_DOCKER_ENABLE" = 1 ]] && source ./ci/official/utilities/docker.sh +./ci/official/utilities/generate_index_html.sh build/index.html # TODO(b/284172313) Revert this difference between presubmits and continuous. RBE serverside behavior is causing flakes, # so we're temporarily allowing flaky tests again for presubmits. tfrun bazel "${TFCI_BAZEL_BAZELRC_ARGS[@]}" test "${TFCI_BAZEL_CACHE_ARGS[@]}" --config=rbe --config=pycpp --config=build_event_export -tfrun bazel analyze-profile $TFCI_RUNTIME_ART/profile.json.gz +tfrun bazel analyze-profile build/profile.json.gz diff --git a/ci/official/utilities/code_check_changed_files.bats b/ci/official/utilities/code_check_changed_files.bats index 773b0847e943e1..d912242a2b17c8 100644 --- a/ci/official/utilities/code_check_changed_files.bats +++ b/ci/official/utilities/code_check_changed_files.bats @@ -16,10 +16,10 @@ # ============================================================================== setup_file() { - cd "$TFCI_RUNTIME_GIT_DIR" + cd "$TFCI_GIT_DIR" bazel version # Start the bazel server # Without this, git errors if /tf/tensorflow directory owner is different - git config --global --add safe.directory "$TF_RUNTIME_GIT_DIR" + git config --global --add safe.directory "$TFCI_GIT_DIR" # Note that you could generate a list of all the affected targets with e.g.: # bazel query $(paste -sd "+" $BATS_FILE_TMPDIR/changed_files) --keep_going # Only shows Added, Changed, Modified, Renamed, and Type-changed files diff --git a/ci/official/utilities/code_check_full.bats b/ci/official/utilities/code_check_full.bats index b474b9fcf9e371..c963fd850fc34f 100644 --- a/ci/official/utilities/code_check_full.bats +++ b/ci/official/utilities/code_check_full.bats @@ -15,7 +15,7 @@ # limitations under the License. # ============================================================================== setup_file() { - cd $TFCI_RUNTIME_GIT_DIR + cd $TFCI_GIT_DIR bazel version # Start the bazel server } diff --git a/ci/official/utilities/docker.sh b/ci/official/utilities/docker.sh index aa57abf740e204..4d69e5a5cf6602 100755 --- a/ci/official/utilities/docker.sh +++ b/ci/official/utilities/docker.sh @@ -10,11 +10,8 @@ trap "docker rm -f tf" EXIT if [[ "$TFCI_DOCKER_PULL_ENABLE" = 1 ]]; then docker pull "$TFCI_DOCKER_IMAGE" fi -docker run "${TFCI_DOCKER_GPU_ARGS[@]}" --name tf -w /tf/tensorflow -itd --rm \ - -v "$TFCI_DOCKER_ARTIFACTS_HOST_DIR:$TFCI_RUNTIME_ARTIFACTS_DIR" \ - -v "$TFCI_DOCKER_GIT_HOST_DIR:$TFCI_RUNTIME_GIT_DIR \ +docker run "${TFCI_DOCKER_GPU_ARGS[@]}" --name tf -w "$TFCI_GIT_DIR" -itd --rm \ + -v "$TFCI_GIT_DIR:$TFCI_GIT_DIR" \ "$TFCI_DOCKER_IMAGE" \ bash tfrun() { docker exec tf "$@"; } -export _TFCI_HOST_ARTIFACTS_DIR="$TFCI_DOCKER_ARTIFACTS_HOST_DIR" -export _TFCI_HOST_GIT_DIR="$TFCI_DOCKER_GIT_HOST_DIR" diff --git a/ci/official/utilities/rename_and_verify_wheels.sh b/ci/official/utilities/rename_and_verify_wheels.sh index 94c2870c13e7f2..500d0d9478dd1a 100755 --- a/ci/official/utilities/rename_and_verify_wheels.sh +++ b/ci/official/utilities/rename_and_verify_wheels.sh @@ -19,9 +19,10 @@ # "manylinux_xyz" into the wheel filename. set -euxo pipefail -for wheel in "$TFCI_RUNTIME_ARTIFACTS_DIR"/*.whl; do +cd $TFCI_GIT_DIR +for wheel in build/*.whl; do echo "Checking and renaming $wheel..." - time python3 -m auditwheel repair --plat manylinux2014_x86_64 "$wheel" --wheel-dir "$TFCI_RUNTIME_ARTIFACTS_DIR" 2>&1 | tee check.txt + time python3 -m auditwheel repair --plat manylinux2014_x86_64 "$wheel" --wheel-dir build 2>&1 | tee check.txt # We don't need the original wheel if it was renamed new_wheel=$(grep --extended-regexp --only-matching '/tf/pkg/\S+.whl' check.txt) @@ -31,5 +32,5 @@ for wheel in "$TFCI_RUNTIME_ARTIFACTS_DIR"/*.whl; do fi rm check.txt - TF_WHEEL="$wheel" bats "$TFCI_RUNTIME_USERTOOLS_DIR/wheel_verification.bats" --timing + TF_WHEEL="$wheel" bats ./ci/official/utilities/wheel_verification.bats --timing done diff --git a/ci/official/utilities/repack_libtensorflow.sh b/ci/official/utilities/repack_libtensorflow.sh index 7d75935d2aa9c7..fefce92f747ce3 100755 --- a/ci/official/utilities/repack_libtensorflow.sh +++ b/ci/official/utilities/repack_libtensorflow.sh @@ -1,4 +1,5 @@ #!/bin/bash + # # Copyright 2022 The TensorFlow Authors. All Rights Reserved. # diff --git a/ci/official/utilities/wheel_verification.bats b/ci/official/utilities/wheel_verification.bats index 1786163a737cb8..6a35adc0f05748 100644 --- a/ci/official/utilities/wheel_verification.bats +++ b/ci/official/utilities/wheel_verification.bats @@ -16,14 +16,15 @@ # or whatever path is set as $TF_WHEEL. setup_file() { - cd "$TFCI_RUNTIME_ARTIFACTS_DIR" + cd "$TFCI_GIT_DIR/build" if [[ -z "$TF_WHEEL" ]]; then - export TF_WHEEL=$(find "$TFCI_RUNTIME_ARTIFACTS_DIR" -iname "*.whl") + export TF_WHEEL=$(find "$TFCI_GIT_DIR/build" -iname "*.whl") fi } teardown_file() { - rm -rf /tmp/venv + rm -rf "$BATS_FILE_TMPDIR/venv" + python3 -m venv } @test "Wheel is manylinux2014 (manylinux_2_17) compliant" { diff --git a/ci/official/wheel.sh b/ci/official/wheel.sh index d72fc174bda80a..bcc426364c39bf 100755 --- a/ci/official/wheel.sh +++ b/ci/official/wheel.sh @@ -6,14 +6,11 @@ set -euxo pipefail -o history set -o allexport && source "$TFCI" && set +o allexport -# If this is a CL presubmit, then run Copybara on the Piper code and place it -# in the same directory as the GitHub source code would normally be. This lets -# the rest of the script proceed as normal. -_TFCI_HOST_ARTIFACTS_DIR="$TFCI_RUNTIME_ARTIFACTS_DIR" +cd "$TFCI_GIT_DIR" && mkdir -p build tfrun() { "$@"; } -[[ "$TFCI_COPYBARA_ENABLE" = 1 ]] && source $TFCI_RUNTIME_USERTOOLS_DIR/copybara.sh -[[ "$TFCI_DOCKER_ENABLE" = 1 ]] && source $TFCI_RUNTIME_USERTOOLS_DIR/docker.sh -"$TFCI_RUNTIME_USERTOOLS_DIR/generate_index_html.sh" "$TFCI_RUNTIME_ARTIFACTS_DIR/index.html" +[[ "$TFCI_COPYBARA_ENABLE" = 1 ]] && source ./ci/official/utilities/copybara.sh +[[ "$TFCI_DOCKER_ENABLE" = 1 ]] && source ./ci/official/utilities/docker.sh +./ci/official/utilities/generate_index_html.sh build/index.html # Record GPU count and CUDA version status [[ "$TFCI_NVIDIA_SMI_ENABLE" = 1 ]] && tfrun nvidia-smi @@ -21,13 +18,13 @@ tfrun() { "$@"; } # Update the version numbers for Nightly only [[ "$TFCI_NIGHTLY_UPDATE_VERSION_ENABLE" = 1 ]] && tfrun python3 tensorflow/tools/ci_build/update_version.py --nightly -tfrun bazel "${TFCI_BAZEL_BAZELRC_ARGS[@]}" build "${TFCI_BAZEL_CACHE_ARGS[@]}" tensorflow/tools/pip_package:build_pip_package -tfrun ./bazel-bin/tensorflow/tools/pip_package/build_pip_package "$TFCI_RUNTIME_ARTIFACTS_DIR" "${TFCI_BUILD_PIP_PACKAGE_ARGS[@]}" -tfrun "$TFCI_RUNTIME_USERTOOLS_DIR/rename_and_verify_wheels.sh" +tfrun bazel "${TFCI_BAZEL_BAZELRC_ARGS[@]}" build "${TFCI_BAZEL_CACHE_ARGS[@]}" //tensorflow/tools/pip_package:build_pip_package +tfrun ./bazel-bin/tensorflow/tools/pip_package/build_pip_package build "${TFCI_BUILD_PIP_PACKAGE_ARGS[@]}" +tfrun ./ci/official/utilities/rename_and_verify_wheels.sh build if [[ "$TFCI_UPLOAD_ENABLE" = 1 ]]; then - twine upload "${TFCI_UPLOAD_PYPI_ARGS[@]}" "$_TFCI_HOST_ARTIFACTS_DIR"/*.whl - gsutil cp "$_TFCI_HOST_ARTIFACTS_DIR"/*.whl "$TFCI_UPLOAD_GCS_DESTINATION" + twine upload "${TFCI_UPLOAD_PYPI_ARGS[@]}" build/*.whl + gsutil cp build/*.whl "$TFCI_UPLOAD_GCS_DESTINATION" fi tfrun bazel "${TFCI_BAZEL_BAZELRC_ARGS[@]}" test "${TFCI_BAZEL_CACHE_ARGS[@]}" --config=nonpip From 66d542e14caaa60935efb48f238f385a00355026 Mon Sep 17 00:00:00 2001 From: David Svantesson Date: Wed, 28 Jun 2023 10:03:35 +0000 Subject: [PATCH 030/376] Update to ACL 23.05, add ACL reorder --- tensorflow/tensorflow.bzl | 4 +- tensorflow/workspace2.bzl | 15 +- third_party/compute_library/BUILD | 189 --- .../compute_library/acl_acl_reorder.patch | 42 + third_party/compute_library/build_defs.bzl | 4 +- .../compute_library/compute_library.patch | 77 +- third_party/mkl_dnn/mkldnn_acl.BUILD | 2 +- .../onednn_acl_depthwise_convolution.patch | 312 ++-- .../onednn_acl_fixed_format_kernels.patch | 1370 ++++++++++++----- .../mkl_dnn/onednn_acl_remove_winograd.patch | 326 ++++ third_party/mkl_dnn/onednn_acl_reorder.patch | 349 +++++ .../onednn_acl_threadpool_scheduler.patch | 17 + 12 files changed, 1970 insertions(+), 737 deletions(-) create mode 100644 third_party/compute_library/acl_acl_reorder.patch create mode 100644 third_party/mkl_dnn/onednn_acl_remove_winograd.patch create mode 100644 third_party/mkl_dnn/onednn_acl_reorder.patch diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index 36092326eb3b34..2fe9c936529542 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -1498,7 +1498,7 @@ def tf_cc_test( "-lpthread", "-lm", ], - clean_dep("//third_party/compute_library:build_with_acl"): [ + clean_dep("@compute_library//:build_with_acl"): [ "-fopenmp", "-lm", ], @@ -1541,7 +1541,7 @@ def tf_cc_shared_test( "-lpthread", "-lm", ], - clean_dep("//third_party/compute_library:build_with_acl"): [ + clean_dep("@compute_library//:build_with_acl"): [ "-fopenmp", "-lm", ], diff --git a/tensorflow/workspace2.bzl b/tensorflow/workspace2.bzl index f740842a64869b..551fb36a288c9a 100644 --- a/tensorflow/workspace2.bzl +++ b/tensorflow/workspace2.bzl @@ -204,11 +204,13 @@ def _tf_repositories(): build_file = "//third_party/mkl_dnn:mkldnn_acl.BUILD", patch_file = [ "//third_party/mkl_dnn:onednn_acl_threadcap.patch", + "//third_party/mkl_dnn:onednn_acl_remove_winograd.patch", "//third_party/mkl_dnn:onednn_acl_fixed_format_kernels.patch", "//third_party/mkl_dnn:onednn_acl_depthwise_convolution.patch", "//third_party/mkl_dnn:onednn_acl_threadpool_scheduler.patch", "//third_party/mkl_dnn:onednn_reorder_padded.patch", "//third_party/mkl_dnn:onednn_acl_reorder_update.patch", + "//third_party/mkl_dnn:onednn_acl_reorder.patch", ], sha256 = "a50993aa6265b799b040fe745e0010502f9f7103cc53a9525d59646aef006633", strip_prefix = "oneDNN-2.7.3", @@ -217,15 +219,10 @@ def _tf_repositories(): tf_http_archive( name = "compute_library", - sha256 = "e20a060d3c4f803889d96c2f0b865004ba3ef4e228299a44339ea1c1ba827c85", - strip_prefix = "ComputeLibrary-22.11", - build_file = "//third_party/compute_library:BUILD", - patch_file = [ - "//third_party/compute_library:compute_library.patch", - "//third_party/compute_library:acl_fixed_format_kernels_striding.patch", - "//third_party/compute_library:acl_openmp_fix.patch", - ], - urls = tf_mirror_urls("https://github.com/ARM-software/ComputeLibrary/archive/v22.11.tar.gz"), + sha256 = "4c22983f08cbc26a7b66c695ee6850d39ea1346a6c76a902323dd10217df4606", + strip_prefix = "ComputeLibrary-23.05", + patch_file = ["//third_party/compute_library:compute_library.patch", "//third_party/compute_library:acl_acl_reorder.patch"], + urls = tf_mirror_urls("https://github.com/ARM-software/ComputeLibrary/archive/v23.05.tar.gz"), ) tf_http_archive( diff --git a/third_party/compute_library/BUILD b/third_party/compute_library/BUILD index 14bde5ac345c80..e69de29bb2d1d6 100644 --- a/third_party/compute_library/BUILD +++ b/third_party/compute_library/BUILD @@ -1,189 +0,0 @@ -load("@bazel_skylib//:bzl_library.bzl", "bzl_library") - -exports_files(["LICENSE"]) - -cc_library( - name = "include", - hdrs = glob([ - "include/**/*.h", - "include/**/*.hpp", - ]), - includes = ["include"], - strip_include_prefix = "include", -) - -_COMPUTE_LIBRARY_DEFINES = [ - "ARM_COMPUTE_OPENMP_SCHEDULER", - "ARM_COMPUTE_CPU_ENABLED", - "ENABLE_NEON", - "ARM_COMPUTE_ENABLE_NEON", - "ENABLE_SVE", - "ARM_COMPUTE_ENABLE_SVE", - "ARM_COMPUTE_ENABLE_BF16", - "ARM_COMPUTE_ENABLE_I8MM", - "ARM_COMPUTE_ENABLE_SVEF32MM", - "ENABLE_FP32_KERNELS", - "ENABLE_QASYMM8_KERNELS", - "ENABLE_QASYMM8_SIGNED_KERNELS", - "ENABLE_QSYMM16_KERNELS", - "ENABLE_INTEGER_KERNELS", - "ENABLE_NHWC_KERNELS", - "ENABLE_NCHW_KERNELS", - "ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS", -] - -cc_library( - name = "arm_compute_sve2", - srcs = glob( - [ - "src/cpu/kernels/**/sve2/*.cpp", - "**/*.h", - "**/*.hpp", - "**/*.inl", - ], - ), - copts = [ - "-march=armv8.6-a+sve2", - "-fopenmp", - ], - defines = _COMPUTE_LIBRARY_DEFINES + ["ARM_COMPUTE_ENABLE_SVE2"], - includes = [ - "src/core/NEON/kernels/arm_conv", - "src/core/NEON/kernels/arm_gemm", - "src/core/NEON/kernels/assembly", - "src/core/cpu/kernels/assembly", - "src/cpu/kernels/assembly", - ], - linkopts = ["-fopenmp"], - deps = ["include"], -) - -cc_library( - name = "arm_compute_sve", - srcs = glob( - [ - "src/core/NEON/kernels/arm_gemm/kernels/sve_*/*.cpp", - "src/core/NEON/kernels/arm_conv/**/kernels/sve_*/*.cpp", - "src/core/NEON/kernels/arm_conv/depthwise/interleaves/sve_*.cpp", - "src/core/NEON/kernels/batchnormalization/impl/SVE/*.cpp", - "src/core/NEON/kernels/convolution/winograd/input_transforms/sve_fp32_6x6.cpp", - "src/cpu/kernels/**/sve/*.cpp", - "**/*.h", - "**/*.hpp", - "**/*.inl", - ], - ) + [ - "src/core/NEON/kernels/arm_gemm/mergeresults-sve.cpp", - "src/core/NEON/kernels/arm_gemm/transform-sve.cpp", - ], - copts = [ - "-march=armv8.2-a+sve", - "-fopenmp", - ], - defines = _COMPUTE_LIBRARY_DEFINES, - includes = [ - "src/core/NEON/kernels/arm_conv", - "src/core/NEON/kernels/arm_gemm", - "src/core/NEON/kernels/assembly", - "src/core/cpu/kernels/assembly", - "src/cpu/kernels/assembly", - ], - linkopts = ["-fopenmp"], - deps = ["include"], -) - -cc_library( - name = "arm_compute", - srcs = glob( - [ - "src/common/**/*.cpp", - "src/core/*.cpp", - "src/core/CPP/kernels/*.cpp", - "src/core/helpers/*.cpp", - "src/core/utils/**/*.cpp", - "src/runtime/**/*.cpp", - "src/c/*.cpp", - "src/core/NEON/kernels/*.cpp", - "src/core/NEON/kernels/convolution/**/*.cpp", - "src/core/NEON/kernels/arm_gemm/kernels/a64_*/*.cpp", - "src/core/NEON/kernels/arm_conv/pooling/*.cpp", - "src/core/NEON/kernels/arm_conv/**/kernels/a64_*/*.cpp", - "src/core/NEON/kernels/arm_conv/depthwise/*.cpp", - "src/core/NEON/kernels/arm_conv/depthwise/interleaves/a64_*.cpp", - "src/core/NEON/kernels/arm_conv/depthwise/interleaves/generic*.cpp", - "src/core/NEON/kernels/batchnormalization/impl/NEON/*.cpp", - "src/cpu/*.cpp", - "src/cpu/kernels/*.cpp", - "src/cpu/kernels/fuse_batch_normalization/**/*.cpp", - "src/cpu/kernels/*/generic/*.cpp", - "src/cpu/operators/**/*.cpp", - "src/cpu/utils/*.cpp", - "src/cpu/kernels/internal/*.cpp", - "src/cpu/kernels/**/neon/*.cpp", - "src/cpu/kernels/**/nchw/*.cpp", - "src/core/NEON/kernels/arm_gemm/*.cpp", - "**/*.h", - "**/*.hpp", - "**/*.inl", - ], - exclude = [ - "src/core/utils/logging/**", - "src/core/TracePoint.cpp", - "src/core/NEON/kernels/arm_gemm/mergeresults-sve.cpp", - "src/core/NEON/kernels/arm_gemm/transform-sve.cpp", - "src/core/NEON/kernels/convolution/winograd/input_transforms/sve_fp32_6x6.cpp", - "src/runtime/CL/**", - "src/gpu/**", - ], - ) + [ - "src/c/operators/AclActivation.cpp", - "src/core/CPP/CPPTypes.cpp", - "src/core/NEON/kernels/arm_conv/addressing.cpp", - "src/core/NEON/kernels/arm_conv/depthwise/interleaves/8b_mla.cpp", - "src/core/NEON/kernels/arm_conv/pooling/kernels/cpp_nhwc_1x1_stride_any_depthfirst/generic.cpp", - ], - hdrs = glob([ - "src/core/NEON/kernels/**/*.h", - "src/core/NEON/kernels/**/*.hpp", - "arm_compute/runtime/**/*.h", - "arm_compute/runtime/*.h", - "arm_compute/core/**/*.h", - "**/*.inl", - ]) + [ - "arm_compute_version.embed", - ], - copts = [ - "-march=armv8-a", - "-fopenmp", - ], - defines = _COMPUTE_LIBRARY_DEFINES, - includes = [ - "arm_compute/runtime", - "src/core/NEON/kernels/assembly", - "src/core/NEON/kernels/convolution/common", - "src/core/NEON/kernels/convolution/winograd", - "src/core/cpu/kernels/assembly", - "src/cpu/kernels/assembly", - ], - linkopts = ["-fopenmp"], - visibility = ["//visibility:public"], - deps = [ - "arm_compute_sve", - "arm_compute_sve2", - "include", - ], -) - -config_setting( - name = "build_with_acl", - define_values = { - "build_with_acl": "true", - }, - visibility = ["//visibility:public"], -) - -bzl_library( - name = "build_defs_bzl", - srcs = ["build_defs.bzl"], - visibility = ["//visibility:public"], -) diff --git a/third_party/compute_library/acl_acl_reorder.patch b/third_party/compute_library/acl_acl_reorder.patch new file mode 100644 index 00000000000000..7f4a7d9f4f8d68 --- /dev/null +++ b/third_party/compute_library/acl_acl_reorder.patch @@ -0,0 +1,42 @@ + ******************************************************************************* + Copyright 2023 Arm Limited and affiliates. + SPDX-License-Identifier: Apache-2.0 + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ******************************************************************************* + +diff --git a/arm_compute/runtime/NEON/functions/NEReorderLayer.h b/arm_compute/runtime/NEON/functions/NEReorderLayer.h +index a9ce8e3e6..eb777f192 100644 +--- a/arm_compute/runtime/NEON/functions/NEReorderLayer.h ++++ b/arm_compute/runtime/NEON/functions/NEReorderLayer.h +@@ -49,7 +49,7 @@ public: + /** Prevent instances of this class from being moved (As this class contains non movable objects) */ + NEReorderLayer &operator=(NEReorderLayer &&) = delete; + /** Default destructor */ +- ~NEReorderLayer() = default; ++ ~NEReorderLayer(); + /** Set the input and output tensors. + * + * Valid data layouts: +diff --git a/src/runtime/NEON/functions/NEReorderLayer.cpp b/src/runtime/NEON/functions/NEReorderLayer.cpp +index 2ab1029f0..427bf8c50 100644 +--- a/src/runtime/NEON/functions/NEReorderLayer.cpp ++++ b/src/runtime/NEON/functions/NEReorderLayer.cpp +@@ -29,6 +29,7 @@ + + namespace arm_compute + { ++NEReorderLayer::~NEReorderLayer() = default; + + NEReorderLayer::NEReorderLayer() + : _reorder_kernel(std::make_unique()) diff --git a/third_party/compute_library/build_defs.bzl b/third_party/compute_library/build_defs.bzl index 74102fd3e6d051..3898798a42d6de 100644 --- a/third_party/compute_library/build_defs.bzl +++ b/third_party/compute_library/build_defs.bzl @@ -1,6 +1,6 @@ def if_enable_acl(if_true, if_false = []): return select({ - "@org_tensorflow//third_party/compute_library:build_with_acl": if_true, + "@compute_library//:build_with_acl": if_true, "//conditions:default": if_false, }) @@ -15,6 +15,6 @@ def acl_deps(): inclusion in the deps attribute of rules. """ return select({ - "@org_tensorflow//third_party/compute_library:build_with_acl": ["@compute_library//:arm_compute"], + "@compute_library//:build_with_acl": ["@compute_library//:arm_compute_core"], "//conditions:default": [], }) diff --git a/third_party/compute_library/compute_library.patch b/third_party/compute_library/compute_library.patch index 2b9619dd03503f..a35bdbfb552a71 100644 --- a/third_party/compute_library/compute_library.patch +++ b/third_party/compute_library/compute_library.patch @@ -1,8 +1,77 @@ + ******************************************************************************* + Copyright 2023 Arm Limited and affiliates. + SPDX-License-Identifier: Apache-2.0 + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ******************************************************************************* +diff --git a/BUILD.bazel b/BUILD.bazel +index f897a1a6a..e27c5a99b 100644 +--- a/BUILD.bazel ++++ b/BUILD.bazel +@@ -138,9 +138,7 @@ cc_library( + "ENABLE_NEON", + "ARM_COMPUTE_CPU_ENABLED", + "ARM_COMPUTE_ENABLE_NEON", +- "ARM_COMPUTE_ENABLE_FP16", + "ARM_COMPUTE_ENABLE_I8MM", +- "ENABLE_FP16_KERNELS", + "ENABLE_FP32_KERNELS", + "ENABLE_QASYMM8_KERNELS", + "ENABLE_QASYMM8_SIGNED_KERNELS", +@@ -174,17 +172,6 @@ cc_library( + visibility = ["//visibility:public"], + ) + +-#--------------------------------------------------------------------- +-# Rule for creating file "arm_compute_version.embed" +-genrule( +- name = "create_version_file", +- srcs = [".git/HEAD"], +- outs = ["arm_compute_version.embed"], +- cmd = "$(location //scripts:print_version_file) bazel-build-options `cat $(location :.git/HEAD)` > $@", +- tools = ["//scripts:print_version_file"], +- visibility = ["//visibility:public"], +-) +- + #--------------------------------------------------------------------- + # Graph library + +@@ -192,7 +179,7 @@ cc_library( + name = "arm_compute_graph", + srcs = ["//src:arm_compute_graph_srcs"], + copts = [ +- "-march=armv8.2-a+fp16", ++ "-march=armv8-a", + ] + select({ + "//:debug_flag": [ + "-O0", +@@ -330,10 +317,10 @@ cc_library( + "core/NEON/kernels/**/*.hpp", + "**/*.inl", + ]) + [ +- "//:create_version_file", ++ "arm_compute_version.embed" + ], + copts = [ +- "-march=armv8.2-a+fp16", ++ "-march=armv8-a", + ] + select({ + "//:debug_flag": [ + "-O0", diff --git a/arm_compute_version.embed b/arm_compute_version.embed new file mode 100644 -index 000000000..c986ad52a +index 000000000..3b3c7d838 --- /dev/null +++ b/arm_compute_version.embed -@@ -0,0 +1,1 @@ -+"arm_compute_version=v22.11 Build options: {} Git hash=b'1b3192e8a23513031163dc14d248f47671986121'" -\ No newline at end of file +@@ -0,0 +1 @@ ++"arm_compute_version=v23.05 Build options: {} Git hash=b'N/A'" diff --git a/third_party/mkl_dnn/mkldnn_acl.BUILD b/third_party/mkl_dnn/mkldnn_acl.BUILD index a1085427ec08da..cfbd515d7815c2 100644 --- a/third_party/mkl_dnn/mkldnn_acl.BUILD +++ b/third_party/mkl_dnn/mkldnn_acl.BUILD @@ -173,6 +173,6 @@ cc_library( ], visibility = ["//visibility:public"], deps = [ - "@compute_library//:arm_compute", + "@compute_library//:arm_compute_core", ], ) diff --git a/third_party/mkl_dnn/onednn_acl_depthwise_convolution.patch b/third_party/mkl_dnn/onednn_acl_depthwise_convolution.patch index 95f0374ec4ddd3..950077665fb4b7 100644 --- a/third_party/mkl_dnn/onednn_acl_depthwise_convolution.patch +++ b/third_party/mkl_dnn/onednn_acl_depthwise_convolution.patch @@ -1,5 +1,5 @@ ******************************************************************************* - Copyright 2022 Arm Limited and affiliates. + Copyright 2023 Arm Limited and affiliates. SPDX-License-Identifier: Apache-2.0 Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,87 +14,93 @@ See the License for the specific language governing permissions and limitations under the License. ******************************************************************************* - diff --git a/src/cpu/aarch64/acl_convolution_utils.cpp b/src/cpu/aarch64/acl_convolution_utils.cpp -index fc93d2aa9..6ebac0d17 100644 +index 6b57374643..85e45ace9d 100644 --- a/src/cpu/aarch64/acl_convolution_utils.cpp +++ b/src/cpu/aarch64/acl_convolution_utils.cpp -@@ -54,10 +54,12 @@ status_t acl_init_conf(acl_conv_conf_t &acp, memory_desc_t &src_md, +@@ -48,11 +48,14 @@ status_t acl_init_conf(acl_conv_conf_t &acp, memory_desc_t &src_md, + if (!is_fwd) return status::unimplemented; + const int ndims = src_d.ndims(); - const bool is_1d = ndims == 3; - const bool is_3d = ndims == 5; -+ const bool is_depthwise = wei_d.ndims() == 5 && wei_d.dims()[1] == 1 && wei_d.dims()[2] == 1; -+ - bool is_nspc; ++ const bool is_depthwise = wei_d.ndims() == 5 && wei_d.dims()[1] == 1 ++ && wei_d.dims()[2] == 1; - // Compute Library unsupported shape scenarios -- if (one_of(true, is_3d, is_1d, with_groups)) { -+ if (one_of(true, is_3d, is_1d, (with_groups && !is_depthwise))) { - return status::unimplemented; - } +- ACL_CHECK_SUPPORT(ndims != 4, " only supports 2 spatial dimensions"); ++ ACL_CHECK_SUPPORT( ++ ndims != 4 && !is_depthwise, " only supports 2 spatial dimensions"); -@@ -135,11 +137,11 @@ status_t acl_init_conf(acl_conv_conf_t &acp, memory_desc_t &src_md, - is_nspc = utils::one_of(src_tag, nhwc); + const int with_groups = wei_d.ndims() == src_d.ndims() + 1; +- ACL_CHECK_SUPPORT(with_groups, " does not support groups"); ++ ACL_CHECK_SUPPORT(with_groups && !is_depthwise, " does not support groups"); - memory_desc_t want_wei_md = weights_md; -- auto wei_tag = is_nspc ? ohwi : oihw; -+ auto wei_tag = is_depthwise ? hwigo : (is_nspc ? ohwi : oihw); - CHECK(memory_desc_init_by_tag(want_wei_md, wei_tag)); + ACL_CHECK_SUPPORT(src_d.data_type() != data_type::f32 + || wei_d.data_type() != data_type::f32 +@@ -108,7 +111,8 @@ status_t acl_init_conf(acl_conv_conf_t &acp, memory_desc_t &src_md, - // Compute Library does not support mismatching layouts -- if ((src_tag != wei_tag) || (src_tag != dst_tag)) -+ if (!is_depthwise && ((src_tag != wei_tag) || (src_tag != dst_tag))) - return status::unimplemented; + acp.with_bias = cd.bias_desc.format_kind != format_kind::undef; - if (weights_md.format_kind == format_kind::any) { -@@ -187,6 +189,12 @@ status_t acl_init_conf(acl_conv_conf_t &acp, memory_desc_t &src_md, - acl_wei_data_t, - acl_layout); +- if (wei_d.format_kind() != format_kind::any) return status::unimplemented; ++ if (wei_d.format_kind() != format_kind::any && !is_depthwise) ++ return status::unimplemented; -+ if(is_depthwise) { -+ // We need to set that values are not constant so that we -+ // we can update them in-place in ACL -+ acp.wei_info.set_are_values_constant(false); + auto src_tag = memory_desc_matches_one_of_tag( + src_md, format_tag::nhwc, format_tag::nchw); +@@ -138,8 +142,12 @@ status_t acl_init_conf(acl_conv_conf_t &acp, memory_desc_t &src_md, + || src_tag != dst_tag) + return status::unimplemented; + +- // Set weights to initially be the same as src +- CHECK(memory_desc_init_by_tag(weights_md, src_tag)); ++ if (is_depthwise) { ++ CHECK(memory_desc_init_by_tag(weights_md, format_tag::hwigo)); ++ } else { ++ // Set weights to initially be the same as src ++ CHECK(memory_desc_init_by_tag(weights_md, src_tag)); + } -+ - acp.dst_info = arm_compute::TensorInfo( - is_nspc ? arm_compute::TensorShape(oc, ow, oh, mb) : - arm_compute::TensorShape(ow, oh, oc, mb), -@@ -212,6 +220,12 @@ status_t acl_init_conf(acl_conv_conf_t &acp, memory_desc_t &src_md, - arm_compute::QuantizationInfo(1.0f / scales[0], 0)); - } + // Bias is just 1D, set to be the obvious format + if (acp.with_bias && bias_md.format_kind == format_kind::any) +@@ -166,6 +174,11 @@ status_t acl_init_conf(acl_conv_conf_t &acp, memory_desc_t &src_md, + 1, + acl_data_type, + acl_layout); + if(is_depthwise) { ++ // We need to set that values are not constant so that we ++ // we can update them in-place in ACL ++ acp.wei_tensor_info.set_are_values_constant(false); ++ } + + acp.dst_tensor_info = arm_compute::TensorInfo( + is_nhwc ? arm_compute::TensorShape(oc, ow, oh, mb) : +@@ -185,6 +198,11 @@ status_t acl_init_conf(acl_conv_conf_t &acp, memory_desc_t &src_md, + // Are we allowed to cast down to bf16 or not? + acp.fast_math + = one_of(attr.fpmath_mode_, fpmath_mode::bf16, fpmath_mode::any); ++ if (is_depthwise) { + // There is no support for fixed format kernels for depthwise convolution + // in ACL so we are going to use weight format that we set up earlier + return status::success; + } -+ + + // WeightFormat::ANY tells ACL we can handle any format acp.weights_info = arm_compute::WeightsInfo( - false, - kw, -@@ -302,6 +316,10 @@ status_t init_conf_gemm(acl_conv_conf_t &acp, memory_desc_t &src_md, +@@ -252,6 +270,7 @@ status_t init_conf_gemm(acl_conv_conf_t &acp, memory_desc_t &src_md, + memory_desc_t &weights_md, memory_desc_t &dst_md, + memory_desc_t &bias_md, const convolution_desc_t &cd, const primitive_attr_t &attr) { - acp.is_indirect = false; ++ if (weights_md.ndims != 4) return status::unimplemented; -+ if(weights_md.ndims != 4) { -+ return status::unimplemented; -+ } -+ // General Compute Library checks, memory tags are also set there CHECK(acl_init_conf(acp, src_md, weights_md, dst_md, bias_md, cd, attr)); +@@ -277,6 +296,7 @@ status_t init_conf_indirect_gemm(acl_conv_conf_t &acp, memory_desc_t &src_md, + memory_desc_t &weights_md, memory_desc_t &dst_md, + memory_desc_t &bias_md, const convolution_desc_t &cd, + const primitive_attr_t &attr) { ++ if (weights_md.ndims != 4) return status::unimplemented; -@@ -330,7 +348,8 @@ status_t init_conf_indirect_gemm(acl_conv_conf_t &acp, memory_desc_t &src_md, - auto math_mode = get_fpmath_mode(); - // Indirect convolution results in slowdown for low thread count or 1x1 - // kernels, so fall back to GEMM-based convolution in these cases -- if (one_of(true, weights_md.dims[2] == 1, // kh -+ if (one_of(true, weights_md.ndims != 4, -+ weights_md.dims[2] == 1, // kh - weights_md.dims[3] == 1, // kw - (!math_mode && dnnl_get_max_threads() < 28))) { - return status::unimplemented; -@@ -355,6 +374,27 @@ status_t init_conf_indirect_gemm(acl_conv_conf_t &acp, memory_desc_t &src_md, + // Indirect is slower for small convolution kernels + if (weights_md.dims[2] == 1 && weights_md.dims[3] == 1) +@@ -314,6 +334,22 @@ status_t init_conf_indirect_gemm(acl_conv_conf_t &acp, memory_desc_t &src_md, return status::success; } @@ -102,41 +108,26 @@ index fc93d2aa9..6ebac0d17 100644 + memory_desc_t &weights_md, memory_desc_t &dst_md, + memory_desc_t &bias_md, const convolution_desc_t &cd, + const primitive_attr_t &attr) { -+ acp.is_indirect = false; -+ // We need to make sure that number of dimensions for weights is either 5 or 3 -+ if(weights_md.ndims != 5) -+ return status::unimplemented; ++ if (weights_md.ndims != 5) return status::unimplemented; + + CHECK(acl_init_conf(acp, src_md, weights_md, dst_md, bias_md, cd, attr)); + + ACL_CHECK_VALID(arm_compute::NEDepthwiseConvolutionLayer::validate( -+ &acp.src_info, -+ &acp.wei_info, -+ acp.with_bias ? &acp.bia_info : nullptr, -+ &acp.dst_info, -+ acp.padstride_info)); ++ &acp.src_tensor_info, &acp.wei_tensor_info, ++ acp.with_bias ? &acp.bia_tensor_info : nullptr, ++ &acp.dst_tensor_info, acp.padstride_info)); + + return status::success; +} + - status_t init_conf_wino(acl_conv_conf_t &acp, memory_desc_t &src_md, - memory_desc_t &weights_md, memory_desc_t &dst_md, - memory_desc_t &bias_md, const convolution_desc_t &cd, -@@ -364,7 +404,8 @@ status_t init_conf_wino(acl_conv_conf_t &acp, memory_desc_t &src_md, - // Under these conditions, fallback to faster GEMM-based convolution - // unless the user explicitly specifies Winograd algorithm - // clang-format off -- if (one_of(true, src_md.dims[2] > 112, // ih -+ if (one_of(true, weights_md.ndims != 4, -+ src_md.dims[2] > 112, // ih - src_md.dims[3] > 112, // iw - src_md.dims[1] < 64, // ic - dst_md.dims[1] < 64, // oc + } // namespace acl_convolution_utils + + } // namespace aarch64 diff --git a/src/cpu/aarch64/acl_convolution_utils.hpp b/src/cpu/aarch64/acl_convolution_utils.hpp -index 44dc8eecb..7eae5cbb1 100644 +index e3d40a5e75..1ded5826c4 100644 --- a/src/cpu/aarch64/acl_convolution_utils.hpp +++ b/src/cpu/aarch64/acl_convolution_utils.hpp -@@ -67,6 +67,11 @@ status_t init_conf_indirect_gemm(acl_conv_conf_t &acp, memory_desc_t &src_md, +@@ -66,6 +66,11 @@ status_t init_conf_indirect_gemm(acl_conv_conf_t &acp, memory_desc_t &src_md, memory_desc_t &bias_md, const convolution_desc_t &cd, const primitive_attr_t &attr); @@ -145,37 +136,17 @@ index 44dc8eecb..7eae5cbb1 100644 + memory_desc_t &bias_md, const convolution_desc_t &cd, + const primitive_attr_t &attr); + - status_t init_conf_wino(acl_conv_conf_t &acp, memory_desc_t &src_md, - memory_desc_t &weights_md, memory_desc_t &dst_md, - memory_desc_t &bias_md, const convolution_desc_t &cd, -diff --git a/src/cpu/cpu_convolution_list.cpp b/src/cpu/cpu_convolution_list.cpp -index 4142dbc7e..1800aaf58 100644 ---- a/src/cpu/cpu_convolution_list.cpp -+++ b/src/cpu/cpu_convolution_list.cpp -@@ -65,6 +65,7 @@ using namespace dnnl::impl::cpu::x64; - #if DNNL_AARCH64 && DNNL_AARCH64_USE_ACL - #include "cpu/aarch64/acl_gemm_convolution.hpp" - #include "cpu/aarch64/acl_indirect_gemm_convolution.hpp" -+#include "cpu/aarch64/acl_depthwise_convolution.hpp" - #include "cpu/aarch64/acl_winograd_convolution.hpp" - #endif - using namespace dnnl::impl::cpu::aarch64; -@@ -104,6 +105,7 @@ const std::map> &impl_list_map() - CPU_INSTANCE_AARCH64(jit_sve_512_dw_convolution_fwd_t) - CPU_INSTANCE_AARCH64(jit_sve_512_1x1_convolution_fwd_f32_t) - CPU_INSTANCE_AARCH64(jit_sve_512_convolution_fwd_t) -+ CPU_INSTANCE_AARCH64_ACL(acl_depthwise_convolution_fwd_t) - CPU_INSTANCE_AARCH64_ACL(acl_indirect_gemm_convolution_fwd_t) - CPU_INSTANCE_AARCH64_ACL(acl_gemm_convolution_fwd_t) - CPU_INSTANCE(gemm_convolution_fwd_t) + } // namespace acl_convolution_utils + + template _lock {this->mtx}; -+ -+ auto *acl_resource -+ = ctx.get_resource_mapper()->get( -+ this); -+ acl_obj_t &acl_depthwise_obj -+ = acl_resource->get_acl_obj(); -+ -+ return execute_forward_conv_acl, pd_t, -+ data_t>(ctx, acl_depthwise_obj, pd()); -+ } -+ -+} -+} -+} ++ const exec_ctx_t &ctx) const { ++ std::lock_guard _lock {this->mtx}; ++ ++ auto *acl_resource ++ = ctx.get_resource_mapper() ++ ->get(this); ++ acl_obj_t &acl_depthwise_obj ++ = acl_resource->get_acl_obj(); ++ ++ return execute_forward_conv_acl< ++ acl_obj_t, pd_t, data_t>( ++ ctx, acl_depthwise_obj, pd()); +} ++ ++} // namespace aarch64 ++} // namespace cpu ++} // namespace impl ++} // namespace dnnl diff --git a/src/cpu/aarch64/acl_depthwise_convolution.hpp b/src/cpu/aarch64/acl_depthwise_convolution.hpp new file mode 100644 -index 000000000..d84fc4fb5 +index 0000000000..3e3d02cf41 --- /dev/null +++ b/src/cpu/aarch64/acl_depthwise_convolution.hpp -@@ -0,0 +1,139 @@ +@@ -0,0 +1,141 @@ +/******************************************************************************* -+* Copyright 2022 Arm Ltd. and affiliates ++* Copyright 2023 Arm Ltd. and affiliates +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. @@ -240,8 +212,8 @@ index 000000000..d84fc4fb5 +#ifndef CPU_AARCH64_ACL_DEPTHWISE_CONVOLUTION_HPP +#define CPU_AARCH64_ACL_DEPTHWISE_CONVOLUTION_HPP + -+#include "cpu/cpu_convolution_pd.hpp" +#include "cpu/aarch64/acl_convolution_utils.hpp" ++#include "cpu/cpu_convolution_pd.hpp" + +namespace dnnl { +namespace impl { @@ -250,15 +222,16 @@ index 000000000..d84fc4fb5 + +struct acl_depthwise_convolution_resource_t : public resource_t { + acl_depthwise_convolution_resource_t() -+ : acl_obj_(utils::make_unique>()) {} ++ : acl_obj_(utils::make_unique< ++ acl_obj_t>()) {} + + status_t configure(const acl_conv_conf_t &acp) { -+ if(!acl_obj_) return status::out_of_memory; ++ if (!acl_obj_) return status::out_of_memory; + -+ acl_obj_->src_tensor.allocator()->init(acp.src_info); -+ acl_obj_->wei_tensor.allocator()->init(acp.wei_info); -+ acl_obj_->dst_tensor.allocator()->init(acp.dst_info); -+ acl_obj_->bia_tensor.allocator()->init(acp.bia_info); ++ acl_obj_->src_tensor.allocator()->init(acp.src_tensor_info); ++ acl_obj_->wei_tensor.allocator()->init(acp.wei_tensor_info); ++ acl_obj_->dst_tensor.allocator()->init(acp.dst_tensor_info); ++ acl_obj_->bia_tensor.allocator()->init(acp.bia_tensor_info); + + // clang-format off + acl_obj_->conv.configure( @@ -281,14 +254,14 @@ index 000000000..d84fc4fb5 + DNNL_DISALLOW_COPY_AND_ASSIGN(acl_depthwise_convolution_resource_t); + +private: -+ std::unique_ptr> acl_obj_; -+ ++ std::unique_ptr> ++ acl_obj_; +}; + +struct acl_depthwise_convolution_fwd_t : public primitive_t { + + struct pd_t : public cpu_convolution_fwd_pd_t { -+ pd_t(const convolution_desc_t* adesc, const primitive_attr_t *attr, ++ pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd), acp_() {} + @@ -297,16 +270,18 @@ index 000000000..d84fc4fb5 + + status_t init(engine_t *engine) { + using namespace data_type; -+ using smask_t = primitive_attr_t::skip_mask_t; + ++ const bool is_fp16_ok = expect_data_types(f16, f16, f16, f16, undef) ++ && attr()->has_default_values( ++ primitive_attr_t::skip_mask_t::post_ops, f16); ++ const bool is_fp32_ok = expect_data_types(f32, f32, f32, f32, undef) ++ && attr()->has_default_values( ++ primitive_attr_t::skip_mask_t::post_ops, f32); + bool ok = is_fwd() + && set_default_alg_kind(alg_kind::convolution_direct) -+ && expect_data_types(data_type::f32, data_type::f32, -+ data_type::f32, data_type::f32, undef) -+ && !has_zero_dim_memory() -+ && attr()->has_default_values( -+ smask_t::post_ops, data_type::f32); -+ if(!ok) return status::unimplemented; ++ && utils::one_of(true, is_fp16_ok, is_fp32_ok) ++ && !has_zero_dim_memory(); ++ if (!ok) return status::unimplemented; + + CHECK(acl_convolution_utils::init_conf_depthwise(acp_, src_md_, + weights_md_, dst_md_, bias_md_, *desc(), *attr())); @@ -326,32 +301,31 @@ index 000000000..d84fc4fb5 + acl_depthwise_convolution_fwd_t(const pd_t *apd) : primitive_t(apd) {} + + status_t create_resource( -+ engine_t *engine, resource_mapper_t &mapper) const override { -+ if(mapper.has_resource(this)) return status::success; ++ engine_t *engine, resource_mapper_t &mapper) const override { ++ if (mapper.has_resource(this)) return status::success; + -+ auto r = utils::make_unique(); -+ if(!r) return status::out_of_memory; ++ auto r = utils::make_unique(); ++ if (!r) return status::out_of_memory; + -+ CHECK(r->configure(pd()->acp_)); -+ mapper.add(this, std::move(r)); ++ CHECK(r->configure(pd()->acp_)); ++ mapper.add(this, std::move(r)); + -+ CHECK(pd()->post_ops.create_resource(engine, mapper)); ++ CHECK(pd()->post_ops.create_resource(engine, mapper)); + -+ return status::success; -+ } ++ return status::success; ++ } + -+ typedef typename prec_traits::type data_t; ++ typedef typename prec_traits::type data_t; + -+ status_t execute(const exec_ctx_t &ctx) const override { -+ return execute_forward(ctx); -+ } ++ status_t execute(const exec_ctx_t &ctx) const override { ++ return execute_forward(ctx); ++ } + +private: + mutable std::mutex mtx; + status_t execute_forward(const exec_ctx_t &ctx) const; + + const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } -+ +}; + +} // namespace aarch64 @@ -360,3 +334,23 @@ index 000000000..d84fc4fb5 +} // namespace dnnl + +#endif // CPU_AARCH64_ACL_DEPTHWISE_CONVOLUTION_HPP +diff --git a/src/cpu/cpu_convolution_list.cpp b/src/cpu/cpu_convolution_list.cpp +index 094c73aa36..80385432d8 100644 +--- a/src/cpu/cpu_convolution_list.cpp ++++ b/src/cpu/cpu_convolution_list.cpp +@@ -63,6 +63,7 @@ using namespace dnnl::impl::cpu::x64; + #include "cpu/aarch64/jit_sve_512_x8s8s32x_convolution.hpp" + #include "cpu/aarch64/jit_uni_dw_convolution.hpp" + #if DNNL_AARCH64 && DNNL_AARCH64_USE_ACL ++#include "cpu/aarch64/acl_depthwise_convolution.hpp" + #include "cpu/aarch64/acl_gemm_convolution.hpp" + #include "cpu/aarch64/acl_indirect_gemm_convolution.hpp" + #endif +@@ -102,6 +103,7 @@ const std::map> &impl_list_map() + CPU_INSTANCE_AARCH64(jit_sve_512_dw_convolution_fwd_t) + CPU_INSTANCE_AARCH64(jit_sve_512_1x1_convolution_fwd_f32_t) + CPU_INSTANCE_AARCH64(jit_sve_512_convolution_fwd_t) ++ CPU_INSTANCE_AARCH64_ACL(acl_depthwise_convolution_fwd_t) + CPU_INSTANCE_AARCH64_ACL(acl_indirect_gemm_convolution_fwd_t) + CPU_INSTANCE_AARCH64_ACL(acl_gemm_convolution_fwd_t) + CPU_INSTANCE(gemm_convolution_fwd_t) diff --git a/third_party/mkl_dnn/onednn_acl_fixed_format_kernels.patch b/third_party/mkl_dnn/onednn_acl_fixed_format_kernels.patch index 2c8af08ab8a4ff..282e839bf1eb36 100644 --- a/third_party/mkl_dnn/onednn_acl_fixed_format_kernels.patch +++ b/third_party/mkl_dnn/onednn_acl_fixed_format_kernels.patch @@ -1,5 +1,5 @@ ******************************************************************************* - Copyright 2022 Arm Limited and affiliates. + Copyright 2023 Arm Limited and affiliates. SPDX-License-Identifier: Apache-2.0 Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,178 +14,479 @@ See the License for the specific language governing permissions and limitations under the License. ******************************************************************************* - +diff --git a/src/common/matmul_pd.hpp b/src/common/matmul_pd.hpp +index 4330ad938b..df16c5fcca 100644 +--- a/src/common/matmul_pd.hpp ++++ b/src/common/matmul_pd.hpp +@@ -159,6 +159,19 @@ protected: + + return true; + } ++ ++ // All implementations that do not support sparse inputs/outputs should ++ // call this function. ++ bool is_dense_data() { ++#ifdef DNNL_EXPERIMENTAL_SPARSE ++ for (auto md : {&src_md_, &weights_md_, &bias_md_, &dst_md_}) { ++ if (memory_desc_wrapper(md).format_kind() == format_kind::sparse) ++ return false; ++ } ++#endif ++ return true; ++ } ++ + }; + + } // namespace impl diff --git a/src/cpu/aarch64/acl_convolution_utils.cpp b/src/cpu/aarch64/acl_convolution_utils.cpp -index c46d69757..fc93d2aa9 100644 +index 37f8ecbc06..6b57374643 100644 --- a/src/cpu/aarch64/acl_convolution_utils.cpp +++ b/src/cpu/aarch64/acl_convolution_utils.cpp -@@ -212,6 +212,87 @@ status_t acl_init_conf(acl_conv_conf_t &acp, memory_desc_t &src_md, - arm_compute::QuantizationInfo(1.0f / scales[0], 0)); - } +@@ -41,25 +41,23 @@ status_t acl_init_conf(acl_conv_conf_t &acp, memory_desc_t &src_md, + const memory_desc_wrapper dst_d(&dst_md); + const memory_desc_wrapper bia_d(&bias_md); + +- auto math_mode = get_fpmath_mode(); +- acp.fast_math = one_of(math_mode, fpmath_mode::bf16, fpmath_mode::any); +- + // Compute Library currently supports forward propagation only + const prop_kind_t prop_kind = cd.prop_kind; + const bool is_fwd = (prop_kind == dnnl_forward_training) + || (prop_kind == dnnl_forward_inference); + if (!is_fwd) return status::unimplemented; + +- const int with_groups = wei_d.ndims() == src_d.ndims() + 1; + const int ndims = src_d.ndims(); +- const bool is_1d = ndims == 3; +- const bool is_3d = ndims == 5; +- bool is_nspc; + +- // Compute Library unsupported shape scenarios +- if (one_of(true, is_3d, is_1d, with_groups)) { +- return status::unimplemented; +- } ++ ACL_CHECK_SUPPORT(ndims != 4, " only supports 2 spatial dimensions"); ++ ++ const int with_groups = wei_d.ndims() == src_d.ndims() + 1; ++ ACL_CHECK_SUPPORT(with_groups, " does not support groups"); ++ ++ ACL_CHECK_SUPPORT(src_d.data_type() != data_type::f32 ++ || wei_d.data_type() != data_type::f32 ++ || dst_d.data_type() != data_type::f32, ++ " src, dst and wei must be fp32"); + + // batch size + const int mb = src_d.dims()[0]; +@@ -110,108 +108,143 @@ status_t acl_init_conf(acl_conv_conf_t &acp, memory_desc_t &src_md, + + acp.with_bias = cd.bias_desc.format_kind != format_kind::undef; + +- auto set_or_check_tags = [&](format_tag_t desired_src_tag, +- format_tag_t desired_dst_tag) -> status_t { +- using namespace format_tag; +- auto src_tag = any, dst_tag = any; +- +- if (src_d.format_kind() == format_kind::any) { +- CHECK(memory_desc_init_by_tag(src_md, desired_src_tag)); +- src_tag = desired_src_tag; +- } else { +- src_tag = memory_desc_matches_one_of_tag(src_md, nhwc, nchw); +- } +- +- if (dst_d.format_kind() == format_kind::any) { +- CHECK(memory_desc_init_by_tag(dst_md, desired_dst_tag)); +- dst_tag = desired_dst_tag; +- } else { +- dst_tag = memory_desc_matches_one_of_tag(dst_md, nhwc, nchw); +- } +- +- if (acp.with_bias && bias_md.format_kind == format_kind::any) +- CHECK(memory_desc_init_by_tag(bias_md, x)); +- +- is_nspc = utils::one_of(src_tag, nhwc); +- +- memory_desc_t want_wei_md = weights_md; +- auto wei_tag = is_nspc ? ohwi : oihw; +- CHECK(memory_desc_init_by_tag(want_wei_md, wei_tag)); +- +- // Compute Library does not support mismatching layouts +- if ((src_tag != wei_tag) || (src_tag != dst_tag)) +- return status::unimplemented; ++ if (wei_d.format_kind() != format_kind::any) return status::unimplemented; ++ ++ auto src_tag = memory_desc_matches_one_of_tag( ++ src_md, format_tag::nhwc, format_tag::nchw); ++ auto dst_tag = memory_desc_matches_one_of_tag( ++ dst_md, format_tag::nhwc, format_tag::nchw); ++ ++ // We want src and dst to match, preferrably both to be NHWC ++ if (src_d.format_kind() == format_kind::any ++ && dst_d.format_kind() == format_kind::any) { ++ CHECK(memory_desc_init_by_tag(src_md, format_tag::nhwc)); ++ CHECK(memory_desc_init_by_tag(dst_md, format_tag::nhwc)); ++ } else if (src_d.format_kind() == format_kind::any ++ && dst_tag != format_tag::undef) { ++ CHECK(memory_desc_init_by_tag(src_md, dst_tag)); ++ } else if (dst_d.format_kind() == format_kind::any ++ && src_tag != format_tag::undef) { ++ CHECK(memory_desc_init_by_tag(dst_md, src_tag)); ++ } + +- if (weights_md.format_kind == format_kind::any) { +- weights_md = want_wei_md; +- } +- return (want_wei_md == weights_md) ? status::success +- : status::unimplemented; +- }; ++ // Recompute tags after potentially running memory desc init ++ src_tag = memory_desc_matches_one_of_tag( ++ src_md, format_tag::nhwc, format_tag::nchw); ++ dst_tag = memory_desc_matches_one_of_tag( ++ dst_md, format_tag::nhwc, format_tag::nchw); + +- auto default_dat_tag = format_tag::nhwc; +- if (set_or_check_tags(default_dat_tag, default_dat_tag) != status::success) ++ if (src_tag == format_tag::undef || dst_tag == format_tag::undef ++ || src_tag != dst_tag) + return status::unimplemented; + +- const auto acl_layout = is_nspc ? arm_compute::DataLayout::NHWC +- : arm_compute::DataLayout::NCHW; ++ // Set weights to initially be the same as src ++ CHECK(memory_desc_init_by_tag(weights_md, src_tag)); + +- // For convolutions, int8 datatypes imply quantized types in ACL +- acp.is_int8 = utils::one_of(src_d.data_type(), s8, u8) +- && wei_d.data_type() == s8; ++ // Bias is just 1D, set to be the obvious format ++ if (acp.with_bias && bias_md.format_kind == format_kind::any) ++ CHECK(memory_desc_init_by_tag(bias_md, format_tag::x)); + +- auto acl_src_data_t +- = acl_utils::get_acl_data_t(src_d.data_type(), acp.is_int8); +- auto acl_wei_data_t +- = acl_utils::get_acl_data_t(wei_d.data_type(), acp.is_int8); +- auto acl_dst_data_t +- = acl_utils::get_acl_data_t(dst_d.data_type(), acp.is_int8); +- auto acl_bia_data_t +- = acl_utils::get_acl_data_t(bia_d.data_type(), acp.is_int8); ++ bool is_nhwc = src_tag == format_tag::nhwc; ++ // The layouts have to match (although we may later modify the weights) ++ const auto acl_layout = is_nhwc ? arm_compute::DataLayout::NHWC ++ : arm_compute::DataLayout::NCHW; + +- if (acl_bia_data_t == arm_compute::DataType::UNKNOWN) +- acl_bia_data_t = arm_compute::DataType::F32; ++ auto acl_data_type = arm_compute::DataType::F32; + // clang-format off +- acp.src_info = arm_compute::TensorInfo( +- is_nspc ? arm_compute::TensorShape(ic, iw, ih, mb) : ++ acp.src_tensor_info = arm_compute::TensorInfo( ++ is_nhwc ? arm_compute::TensorShape(ic, iw, ih, mb) : + arm_compute::TensorShape(iw, ih, ic, mb), + 1, +- acl_src_data_t, ++ acl_data_type, + acl_layout); + +- acp.wei_info = arm_compute::TensorInfo( +- is_nspc ? arm_compute::TensorShape(ic, kw, kh, oc) : ++ acp.wei_tensor_info = arm_compute::TensorInfo( ++ is_nhwc ? arm_compute::TensorShape(ic, kw, kh, oc) : + arm_compute::TensorShape(kw, kh, ic, oc), + 1, +- acl_wei_data_t, ++ acl_data_type, + acl_layout); + +- acp.dst_info = arm_compute::TensorInfo( +- is_nspc ? arm_compute::TensorShape(oc, ow, oh, mb) : ++ acp.dst_tensor_info = arm_compute::TensorInfo( ++ is_nhwc ? arm_compute::TensorShape(oc, ow, oh, mb) : + arm_compute::TensorShape(ow, oh, oc, mb), + 1, +- acl_dst_data_t, ++ acl_data_type, + acl_layout); + +- acp.bia_info = arm_compute::TensorInfo( ++ acp.bia_tensor_info = arm_compute::TensorInfo( + acp.with_bias ? arm_compute::TensorShape(oc) + : arm_compute::TensorShape(), + 1, +- acl_bia_data_t, ++ acl_data_type, + acl_layout); + // clang-format on + +- // Add quantization info to tensors +- if (acp.is_int8) { +- const float *scales = attr.output_scales_.scales_; +- acp.src_info.set_quantization_info(arm_compute::QuantizationInfo(1, 0)); +- acp.bia_info.set_quantization_info(arm_compute::QuantizationInfo(1, 0)); +- acp.wei_info.set_quantization_info(arm_compute::QuantizationInfo(1, 0)); +- acp.dst_info.set_quantization_info( +- arm_compute::QuantizationInfo(1.0f / scales[0], 0)); ++ // Are we allowed to cast down to bf16 or not? ++ acp.fast_math ++ = one_of(attr.fpmath_mode_, fpmath_mode::bf16, fpmath_mode::any); ++ ++ // WeightFormat::ANY tells ACL we can handle any format + acp.weights_info = arm_compute::WeightsInfo( -+ false, -+ kw, -+ kh, -+ oc, -+ false, -+ arm_compute::WeightFormat::ANY); ++ false, kw, kh, oc, false, arm_compute::WeightFormat::ANY); ++ ++ // Get the format that the ACL kernel will expect the weights to be ++ // in (if a kernel exists). Note that these are referred to as fixed format ++ // kernels, because they require one specific weights format + arm_compute::WeightFormat expected_weight_format; -+ auto acl_st = arm_compute::NEGEMMConvolutionLayer::has_opt_impl( -+ expected_weight_format, -+ &acp.src_info, -+ &acp.wei_info, -+ acp.with_bias ? &acp.bia_info : nullptr, -+ &acp.dst_info, -+ acp.padstride_info, -+ acp.weights_info, -+ acp.dilation_info, -+ acp.act_info, -+ acp.fast_math); -+ if(acl_st.error_code() != arm_compute::ErrorCode::OK) { -+ return status::unimplemented; -+ } ++ ACL_CHECK_VALID(arm_compute::NEGEMMConvolutionLayer::has_opt_impl( ++ expected_weight_format, &acp.src_tensor_info, &acp.wei_tensor_info, ++ acp.with_bias ? &acp.bia_tensor_info : nullptr, ++ &acp.dst_tensor_info, acp.padstride_info, acp.weights_info, ++ acp.dilation_info, acp.act_info, acp.fast_math)); ++ ++ // Set weights info to the one returned by has_opt_impl + acp.weights_info.set_weight_format(expected_weight_format); + -+ int interleaved_by = arm_compute::interleave_by(expected_weight_format); -+ int block_by = arm_compute::block_by(expected_weight_format); ++ // has_opt_impl may return a non fast math kernel, even if we requested one ++ acp.fast_math ++ = arm_compute::is_fixed_format_fast_math(expected_weight_format); + -+ bool is_fast_math_kernel = arm_compute::is_fixed_format_fast_math(expected_weight_format); -+ if(!is_fast_math_kernel) { -+ // FP32 kernel is faster then BF16 -+ acp.fast_math = false; -+ } ++ // Map OIHW used in ACL WeightFormat to the logical dimensions of the memory descriptor ++ dim_t O_dim = 0; ++ dim_t I_dim = 1; ++ dim_t H_dim = 2; ++ dim_t W_dim = 3; + -+ memory_desc_t want_wei_md = weights_md; -+ -+ int ic_multiply = ic; -+ if(ic % block_by != 0) { -+ ic_multiply = utils::div_up(ic, block_by) * block_by; -+ // Also we need to set padded dimensions as well -+ want_wei_md.padded_dims[1] = ic_multiply; -+ } else { -+ // If we do not need to pad input channels for fast math mode -+ // then it would be faster to run convolution with im2row -+ // instead of using indirect buffer -+ if(acp.fast_math && acp.is_indirect) { ++ if (!is_nhwc) { ++ // We can try to support NCHW by swapping IHW around, note that this ++ // requires weights_md.dims[I_dim] % block_by != 0 (see next block) ++ O_dim = 0; ++ I_dim = 3; ++ H_dim = 1; ++ W_dim = 2; + } + ++ // We can't currently support nchw and block_by != 1. If this is the case, ++ // try a non fast math kernel, which currently have no blocking ++ int block_by = arm_compute::block_by(acp.weights_info.weight_format()); ++ if (!is_nhwc && weights_md.dims[I_dim] % block_by != 0 && acp.fast_math) { ++ acp.fast_math = false; ++ acp.weights_info.set_weight_format(arm_compute::WeightFormat::ANY); ++ ACL_CHECK_VALID(arm_compute::NEGEMMConvolutionLayer::has_opt_impl( ++ expected_weight_format, &acp.src_tensor_info, ++ &acp.wei_tensor_info, ++ acp.with_bias ? &acp.bia_tensor_info : nullptr, ++ &acp.dst_tensor_info, acp.padstride_info, acp.weights_info, ++ acp.dilation_info, acp.act_info, acp.fast_math)); ++ acp.weights_info.set_weight_format(expected_weight_format); ++ block_by = arm_compute::block_by(expected_weight_format); ++ // This shouldn't happen, because non-fastmath have no blocking, but ++ // guard against it because it would silently return incorrect results ++ if (weights_md.dims[I_dim] % block_by != 0) + return status::unimplemented; -+ } -+ } -+ if(oc % interleaved_by != 0) { -+ int padded_dim = utils::div_up(oc, interleaved_by) * interleaved_by; -+ want_wei_md.padded_dims[0] = padded_dim; -+ } -+ -+ // Set strides based on blocking information -+ want_wei_md.format_desc.blocking.strides[0] = interleaved_by*ic_multiply*kw*kh; -+ want_wei_md.format_desc.blocking.strides[1] = interleaved_by*block_by; -+ want_wei_md.format_desc.blocking.strides[2] = interleaved_by*ic_multiply*kw; -+ want_wei_md.format_desc.blocking.strides[3] = interleaved_by*ic_multiply; -+ -+ acl_utils::update_strides_y_and_z( -+ acp.wei_info, -+ want_wei_md.format_desc.blocking.strides[0] * wei_d.data_type_size(), -+ acp.wei_info.strides_in_bytes().z()); -+ -+ // Set blocking -+ want_wei_md.format_desc.blocking.inner_nblks = (block_by > 1) + 1; -+ want_wei_md.format_desc.blocking.inner_idxs[0] = 0; // second to last dimension in abcd format -+ want_wei_md.format_desc.blocking.inner_blks[0] = interleaved_by; -+ -+ if(block_by > 1) { -+ want_wei_md.format_desc.blocking.inner_idxs[1] = 1; // second to last dimension in abcd format -+ want_wei_md.format_desc.blocking.inner_blks[1] = block_by; -+ } -+ -+ if(is_fast_math_kernel) { -+ // If it is fast math mode we need weights in BFloat16 -+ want_wei_md.data_type = dnnl_bf16; + } + -+ weights_md = want_wei_md; ++ acl_utils::reorder_to_weight_format(acp.wei_tensor_info, weights_md, ++ expected_weight_format, I_dim, O_dim, {W_dim, H_dim}, {}); + return status::success; } -@@ -219,6 +300,7 @@ status_t init_conf_gemm(acl_conv_conf_t &acp, memory_desc_t &src_md, +@@ -226,10 +259,10 @@ status_t init_conf_gemm(acl_conv_conf_t &acp, memory_desc_t &src_md, + // clang-format off + // Validate convolution manually to check for return status + ACL_CHECK_VALID(arm_compute::NEGEMMConvolutionLayer::validate( +- &acp.src_info, +- &acp.wei_info, +- acp.with_bias ? &acp.bia_info : nullptr, +- &acp.dst_info, ++ &acp.src_tensor_info, ++ &acp.wei_tensor_info, ++ acp.with_bias ? &acp.bia_tensor_info : nullptr, ++ &acp.dst_tensor_info, + acp.padstride_info, + acp.weights_info, + acp.dilation_info, +@@ -244,28 +277,38 @@ status_t init_conf_indirect_gemm(acl_conv_conf_t &acp, memory_desc_t &src_md, memory_desc_t &weights_md, memory_desc_t &dst_md, memory_desc_t &bias_md, const convolution_desc_t &cd, const primitive_attr_t &attr) { -+ acp.is_indirect = false; - - // General Compute Library checks, memory tags are also set there - CHECK(acl_init_conf(acp, src_md, weights_md, dst_md, bias_md, cd, attr)); -@@ -244,11 +326,13 @@ status_t init_conf_indirect_gemm(acl_conv_conf_t &acp, memory_desc_t &src_md, - memory_desc_t &weights_md, memory_desc_t &dst_md, - memory_desc_t &bias_md, const convolution_desc_t &cd, - const primitive_attr_t &attr) { -+ acp.is_indirect = true; -+ auto math_mode = get_fpmath_mode(); - // Indirect convolution results in slowdown for low thread count or 1x1 - // kernels, so fall back to GEMM-based convolution in these cases - if (one_of(true, weights_md.dims[2] == 1, // kh - weights_md.dims[3] == 1, // kw +- // Indirect convolution results in slowdown for low thread count or 1x1 +- // kernels, so fall back to GEMM-based convolution in these cases +- if (one_of(true, weights_md.dims[2] == 1, // kh +- weights_md.dims[3] == 1, // kw - dnnl_get_max_threads() < 28)) { -+ (!math_mode && dnnl_get_max_threads() < 28))) { ++ ++ // Indirect is slower for small convolution kernels ++ if (weights_md.dims[2] == 1 && weights_md.dims[3] == 1) return status::unimplemented; - } +- } -@@ -275,6 +359,7 @@ status_t init_conf_wino(acl_conv_conf_t &acp, memory_desc_t &src_md, - memory_desc_t &weights_md, memory_desc_t &dst_md, - memory_desc_t &bias_md, const convolution_desc_t &cd, - const primitive_attr_t &attr) { -+ acp.is_indirect = false; + CHECK(acl_init_conf(acp, src_md, weights_md, dst_md, bias_md, cd, attr)); + ++ // Indirect is slower than gemm for low thread counts, except for fast math ++ if (dnnl_get_max_threads() < 28 && !acp.fast_math) ++ return status::unimplemented; ++ ++ // If we do not need to pad input channels for fast math mode then it would ++ // be faster to run convolution with im2row instead of using indirect kernel ++ int block_by = arm_compute::block_by(acp.weights_info.weight_format()); ++ int ic = src_md.dims[1]; ++ if (acp.fast_math && ic % block_by == 0) return status::unimplemented; ++ ++ // TODO: remove this once NEGEMMConv2d::validate allows src and weights to mismatch ++ acp.wei_tensor_info.set_data_layout(arm_compute::DataLayout::NHWC); ++ + // clang-format off + // NOTE: indirect convolution method supports only nhwc layout. + ACL_CHECK_VALID(arm_compute::NEGEMMConv2d::validate( +- &acp.src_info, +- &acp.wei_info, +- acp.with_bias ? &acp.bia_info : nullptr, +- &acp.dst_info, ++ &acp.src_tensor_info, ++ &acp.wei_tensor_info, ++ acp.with_bias ? &acp.bia_tensor_info : nullptr, ++ &acp.dst_tensor_info, + arm_compute::Conv2dInfo(acp.padstride_info, + acp.dilation_info, + acp.act_info, + acp.fast_math, +- 1))); ++ 1, {}, acp.weights_info))); + // clang-format on - // Under these conditions, fallback to faster GEMM-based convolution - // unless the user explicitly specifies Winograd algorithm + return status::success; diff --git a/src/cpu/aarch64/acl_convolution_utils.hpp b/src/cpu/aarch64/acl_convolution_utils.hpp -index 3e56245fa..44dc8eecb 100644 +index 0398ab06b9..e3d40a5e75 100644 --- a/src/cpu/aarch64/acl_convolution_utils.hpp +++ b/src/cpu/aarch64/acl_convolution_utils.hpp -@@ -43,6 +43,7 @@ struct acl_conv_conf_t { +@@ -38,17 +38,17 @@ struct acl_obj_t { + + struct acl_conv_conf_t { + bool with_bias; +- bool is_int8; + bool fast_math; // If this is true, the result of the convolution goes into a temporarily // allocated ACL tensor to be accumulated into the oneDNN dst during postops bool use_dst_acc; -+ bool is_indirect; - arm_compute::TensorInfo src_info; - arm_compute::TensorInfo wei_info; - arm_compute::TensorInfo bia_info; +- arm_compute::TensorInfo src_info; +- arm_compute::TensorInfo wei_info; +- arm_compute::TensorInfo bia_info; +- arm_compute::TensorInfo dst_info; ++ arm_compute::TensorInfo src_tensor_info; ++ arm_compute::TensorInfo wei_tensor_info; ++ arm_compute::TensorInfo bia_tensor_info; ++ arm_compute::TensorInfo dst_tensor_info; + arm_compute::PadStrideInfo padstride_info; + arm_compute::Size2D dilation_info; ++ // Additional information about the weights not included in wei_tensor_info + arm_compute::WeightsInfo weights_info; + // Note: this will default to not enabled, and will do nothing + arm_compute::ActivationLayerInfo act_info; +diff --git a/src/cpu/aarch64/acl_gemm_convolution.hpp b/src/cpu/aarch64/acl_gemm_convolution.hpp +index 485db954ea..da58e4f610 100644 +--- a/src/cpu/aarch64/acl_gemm_convolution.hpp ++++ b/src/cpu/aarch64/acl_gemm_convolution.hpp +@@ -1,5 +1,5 @@ + /******************************************************************************* +-* Copyright 2020-2022 Arm Ltd. and affiliates ++* Copyright 2020-2023 Arm Ltd. and affiliates + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. +@@ -36,10 +36,10 @@ struct acl_resource_t : public resource_t { + if (!acl_obj_) return status::out_of_memory; + + // Init Compute Library tensors based on info from descriptor +- acl_obj_->src_tensor.allocator()->init(acp.src_info); +- acl_obj_->wei_tensor.allocator()->init(acp.wei_info); +- acl_obj_->dst_tensor.allocator()->init(acp.dst_info); +- acl_obj_->bia_tensor.allocator()->init(acp.bia_info); ++ acl_obj_->src_tensor.allocator()->init(acp.src_tensor_info); ++ acl_obj_->wei_tensor.allocator()->init(acp.wei_tensor_info); ++ acl_obj_->dst_tensor.allocator()->init(acp.dst_tensor_info); ++ acl_obj_->bia_tensor.allocator()->init(acp.bia_tensor_info); + + acl_obj_->conv.configure(&acl_obj_->src_tensor, &acl_obj_->wei_tensor, + acp.with_bias ? &acl_obj_->bia_tensor : nullptr, diff --git a/src/cpu/aarch64/acl_indirect_gemm_convolution.hpp b/src/cpu/aarch64/acl_indirect_gemm_convolution.hpp -index bcf031a77..4ddc8cf91 100644 +index bcf031a771..b7c8dce894 100644 --- a/src/cpu/aarch64/acl_indirect_gemm_convolution.hpp +++ b/src/cpu/aarch64/acl_indirect_gemm_convolution.hpp -@@ -41,6 +41,7 @@ struct acl_indirect_gemm_resource_t : public resource_t { - acl_obj_->bia_tensor.allocator()->init(acp.bia_info); +@@ -1,5 +1,5 @@ + /******************************************************************************* +-* Copyright 2021-2022 Arm Ltd. and affiliates ++* Copyright 2021-2023 Arm Ltd. and affiliates + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. +@@ -35,10 +35,10 @@ struct acl_indirect_gemm_resource_t : public resource_t { + if (!acl_obj_) return status::out_of_memory; + + // Init Compute Library tensors based on info from descriptor +- acl_obj_->src_tensor.allocator()->init(acp.src_info); +- acl_obj_->wei_tensor.allocator()->init(acp.wei_info); +- acl_obj_->dst_tensor.allocator()->init(acp.dst_info); +- acl_obj_->bia_tensor.allocator()->init(acp.bia_info); ++ acl_obj_->src_tensor.allocator()->init(acp.src_tensor_info); ++ acl_obj_->wei_tensor.allocator()->init(acp.wei_tensor_info); ++ acl_obj_->dst_tensor.allocator()->init(acp.dst_tensor_info); ++ acl_obj_->bia_tensor.allocator()->init(acp.bia_tensor_info); // clang-format off -+ arm_compute::experimental::PostOpList empty_post_ops = arm_compute::experimental::PostOpList {}; acl_obj_->conv.configure( - &acl_obj_->src_tensor, - &acl_obj_->wei_tensor, -@@ -50,7 +51,9 @@ struct acl_indirect_gemm_resource_t : public resource_t { +@@ -50,7 +50,9 @@ struct acl_indirect_gemm_resource_t : public resource_t { acp.dilation_info, acp.act_info, acp.fast_math, - 1)); + 1, -+ empty_post_ops, ++ {}, + acp.weights_info)); // clang-format on return status::success; diff --git a/src/cpu/aarch64/acl_inner_product.hpp b/src/cpu/aarch64/acl_inner_product.hpp -index c5e507085..163ff066e 100644 +index c5e507085f..a27df640fb 100644 --- a/src/cpu/aarch64/acl_inner_product.hpp +++ b/src/cpu/aarch64/acl_inner_product.hpp -@@ -45,6 +45,7 @@ struct acl_ip_conf_t { - arm_compute::TensorInfo bia_info; - arm_compute::TensorInfo dst_info; +@@ -40,11 +40,13 @@ struct acl_ip_conf_t { + // If this is true, the result of the inner product goes into a temporarily + // allocated ACL tensor to be accumulated into the oneDNN dst during postops + bool use_dst_acc; +- arm_compute::TensorInfo src_info; +- arm_compute::TensorInfo wei_info; +- arm_compute::TensorInfo bia_info; +- arm_compute::TensorInfo dst_info; ++ arm_compute::TensorInfo src_tensor_info; ++ arm_compute::TensorInfo wei_tensor_info; ++ arm_compute::TensorInfo bia_tensor_info; ++ arm_compute::TensorInfo dst_tensor_info; arm_compute::FullyConnectedLayerInfo fc_info; ++ // Additional information about the weights not included in wei_tensor_info + arm_compute::WeightsInfo weights_info; }; struct acl_ip_resource_t : public resource_t { acl_ip_resource_t() : acl_ip_obj_(utils::make_unique()) {} -@@ -64,7 +65,8 @@ struct acl_ip_resource_t : public resource_t { +@@ -53,10 +55,10 @@ struct acl_ip_resource_t : public resource_t { + if (!acl_ip_obj_) return status::out_of_memory; + + // Init Compute Library tensors based on info from descriptor +- acl_ip_obj_->src_tensor.allocator()->init(aip.src_info); +- acl_ip_obj_->wei_tensor.allocator()->init(aip.wei_info); +- acl_ip_obj_->dst_tensor.allocator()->init(aip.dst_info); +- acl_ip_obj_->bia_tensor.allocator()->init(aip.bia_info); ++ acl_ip_obj_->src_tensor.allocator()->init(aip.src_tensor_info); ++ acl_ip_obj_->wei_tensor.allocator()->init(aip.wei_tensor_info); ++ acl_ip_obj_->dst_tensor.allocator()->init(aip.dst_tensor_info); ++ acl_ip_obj_->bia_tensor.allocator()->init(aip.bia_tensor_info); + + // clang-format off + acl_ip_obj_->fc.configure( +@@ -64,7 +66,8 @@ struct acl_ip_resource_t : public resource_t { &acl_ip_obj_->wei_tensor, aip.with_bias ? &acl_ip_obj_->bia_tensor : nullptr, &acl_ip_obj_->dst_tensor, @@ -195,41 +496,126 @@ index c5e507085..163ff066e 100644 // clang-format on return status::success; -@@ -156,8 +158,8 @@ struct acl_inner_product_fwd_t : public primitive_t { - src_shape = (src_tag == nc) ? arm_compute::TensorShape(ic, n) - : arm_compute::TensorShape(n, ic); +@@ -89,12 +92,16 @@ struct acl_inner_product_fwd_t : public primitive_t { + DECLARE_COMMON_PD_T("acl", acl_inner_product_fwd_t); + status_t init(engine_t *engine) { +- const bool ok = is_fwd() && !has_zero_dim_memory() +- && expect_data_types(data_type::f32, data_type::f32, +- data_type::f32, data_type::f32, data_type::f32) ++ using namespace data_type; ++ const bool is_fp16_ok = expect_data_types(f16, f16, f16, f16, undef) ++ && attr()->has_default_values( ++ primitive_attr_t::skip_mask_t::post_ops, f16); ++ const bool is_fp32_ok = expect_data_types(f32, f32, f32, f32, undef) + && attr()->has_default_values( +- primitive_attr_t::skip_mask_t::post_ops, +- data_type::f32) ++ primitive_attr_t::skip_mask_t::post_ops, f32); ++ const bool ok = is_fwd() && !has_zero_dim_memory() ++ && utils::one_of(true, is_fp16_ok, is_fp32_ok) + && set_default_params() == status::success; + + if (!ok) return status::unimplemented; +@@ -121,88 +128,46 @@ struct acl_inner_product_fwd_t : public primitive_t { + ACL_CHECK_SUPPORT( + !(is_2d || is_4d), "ACL supports only 2d or 4d cases"); + +- // batch size +- const int n = src_md()->dims[0]; +- +- // input and output channels +- const int ic = src_md()->dims[1]; +- const int oc = dst_md()->dims[1]; +- +- // source spatial dimensions +- const int ih = is_4d ? src_md()->dims[ndims - 2] : 0; +- const int iw = is_4d ? src_md()->dims[ndims - 1] : 0; +- +- // weights spatial dimensions +- const int kh = is_4d ? weights_md()->dims[ndims - 2] : 0; +- const int kw = is_4d ? weights_md()->dims[ndims - 1] : 0; +- +- // Only NCHW or NHWC derivatives supported by ACL kernels + using namespace format_tag; +- auto src_tag = memory_desc_matches_one_of_tag( +- src_md_, nhwc, nchw, nc, cn); +- auto wei_tag = memory_desc_matches_one_of_tag( +- weights_md_, ohwi, oihw, oi, io); +- auto dst_tag = memory_desc_matches_one_of_tag(dst_md_, nc, cn); ++ auto src_tag ++ = memory_desc_matches_one_of_tag(src_md_, nhwc, nchw, nc); ++ auto dst_tag = memory_desc_matches_one_of_tag(dst_md_, nc); + + ACL_CHECK_SUPPORT( +- utils::one_of(format_tag::undef, src_tag, wei_tag, dst_tag), ++ utils::one_of(format_tag::undef, src_tag, dst_tag), + "unsupported memory layout"); + + ACL_CHECK_SUPPORT(is_2d && src_tag != dst_tag, + "for src and dst layouts must match"); + +- arm_compute::TensorShape src_shape, wei_shape; +- if (is_2d) { +- src_shape = (src_tag == nc) ? arm_compute::TensorShape(ic, n) +- : arm_compute::TensorShape(n, ic); +- - wei_shape = (wei_tag == io) ? arm_compute::TensorShape(oc, ic) - : arm_compute::TensorShape(ic, oc); -+ // For fixed format kernels weight shape is always io -+ wei_shape = arm_compute::TensorShape(oc, ic); - } - if (is_4d) { - src_shape = (src_tag == nhwc) -@@ -166,7 +168,8 @@ struct acl_inner_product_fwd_t : public primitive_t { - - // ACL requires the weights to be in 2D flattened shape - const int flattened_ic = is_4d ? ic * kh * kw : ic; +- } +- if (is_4d) { +- src_shape = (src_tag == nhwc) +- ? arm_compute::TensorShape(ic, iw, ih, n) +- : arm_compute::TensorShape(iw, ih, ic, n); +- +- // ACL requires the weights to be in 2D flattened shape +- const int flattened_ic = is_4d ? ic * kh * kw : ic; - wei_shape = arm_compute::TensorShape(flattened_ic, oc); -+ // For fixed format kernels weights shape is always io -+ wei_shape = arm_compute::TensorShape(oc, flattened_ic); - } - - arm_compute::DataLayout src_layout = (src_tag == nhwc) -@@ -183,6 +186,9 @@ struct acl_inner_product_fwd_t : public primitive_t { - aip.wei_info = arm_compute::TensorInfo( - wei_shape, 1, arm_compute::DataType::F32, wei_layout); - -+ aip.weights_info = arm_compute::WeightsInfo( -+ false, 1, 1, is_4d ? ic * kh *kw : ic, false, arm_compute::WeightFormat::ANY); -+ - aip.dst_info - = arm_compute::TensorInfo(arm_compute::TensorShape(oc, n), - 1, arm_compute::DataType::F32); -@@ -194,15 +200,7 @@ struct acl_inner_product_fwd_t : public primitive_t { +- } +- +- arm_compute::DataLayout src_layout = (src_tag == nhwc) +- ? arm_compute::DataLayout::NHWC +- : arm_compute::DataLayout::NCHW; ++ const dim_t ic_total = IC_total(); ++ const dim_t n = MB(); ++ const dim_t oc = OC(); + +- arm_compute::DataLayout wei_layout = (wei_tag == ohwi) +- ? arm_compute::DataLayout::NHWC +- : arm_compute::DataLayout::NCHW; ++ aip.src_tensor_info = arm_compute::TensorInfo( ++ arm_compute::TensorShape(ic_total, n), 1, ++ acl_utils::get_acl_data_t(src_md()->data_type)); + +- aip.src_info = arm_compute::TensorInfo( +- src_shape, 1, arm_compute::DataType::F32, src_layout); ++ // ACL requires the weights to be in 2D flattened shape ++ aip.wei_tensor_info = arm_compute::TensorInfo( ++ arm_compute::TensorShape(oc, ic_total), 1, ++ acl_utils::get_acl_data_t(weights_md(0)->data_type)); + +- aip.wei_info = arm_compute::TensorInfo( +- wei_shape, 1, arm_compute::DataType::F32, wei_layout); +- +- aip.dst_info +- = arm_compute::TensorInfo(arm_compute::TensorShape(oc, n), +- 1, arm_compute::DataType::F32); ++ auto acl_dst_data_t ++ = acl_utils::get_acl_data_t(dst_md()->data_type); ++ aip.dst_tensor_info = arm_compute::TensorInfo( ++ arm_compute::TensorShape(oc, n), 1, acl_dst_data_t); + + aip.with_bias = desc()->bias_desc.format_kind != format_kind::undef; +- aip.bia_info = arm_compute::TensorInfo(aip.with_bias ++ auto acl_bia_data_t = aip.with_bias ++ ? acl_utils::get_acl_data_t(weights_md(1)->data_type) ++ : acl_dst_data_t; ++ aip.bia_tensor_info = arm_compute::TensorInfo(aip.with_bias + ? arm_compute::TensorShape(oc) + : arm_compute::TensorShape(), 1, arm_compute::DataType::F32); - aip.fc_info.weights_trained_layout = wei_layout; +- aip.fc_info.weights_trained_layout = wei_layout; - if (is_2d && wei_tag != src_tag) { - // weights are already transposed - aip.fc_info.transpose_weights = false; @@ -243,294 +629,536 @@ index c5e507085..163ff066e 100644 // Fast math mode auto math_mode = get_fpmath_mode(); -@@ -214,6 +212,80 @@ struct acl_inner_product_fwd_t : public primitive_t { +@@ -214,15 +179,103 @@ struct acl_inner_product_fwd_t : public primitive_t { aip.fc_info.activation_info)); aip.use_dst_acc = post_ops.has_sum(); ++ // WeightFormat::ANY tells ACL we can handle any format ++ aip.weights_info = arm_compute::WeightsInfo(false, 1, 1, ic_total, ++ false, arm_compute::WeightFormat::ANY); ++ ++ // Get the format that the ACL kernel will expect the weights to be ++ // in (if a kernel exists) Note that these are referred to as fixed ++ // format kernels, because they require one specific weights format + arm_compute::WeightFormat expected_weight_format; -+ auto acl_st = arm_compute::NEFullyConnectedLayer::has_opt_impl( -+ expected_weight_format, -+ &aip.src_info, -+ &aip.wei_info, -+ aip.with_bias ? &aip.bia_info : nullptr, -+ &aip.dst_info, -+ aip.fc_info, -+ aip.weights_info); -+ if(acl_st.error_code() != arm_compute::ErrorCode::OK) { -+ return status::unimplemented; -+ } ++ ACL_CHECK_VALID(arm_compute::NEFullyConnectedLayer::has_opt_impl( ++ expected_weight_format, &aip.src_tensor_info, ++ &aip.wei_tensor_info, ++ aip.with_bias ? &aip.bia_tensor_info : nullptr, ++ &aip.dst_tensor_info, aip.fc_info, aip.weights_info)); + ++ // Set weights info to the one returned by has_opt_impl + aip.weights_info.set_weight_format(expected_weight_format); + -+ int interleaved_by = arm_compute::interleave_by(expected_weight_format); -+ int block_by = arm_compute::block_by(expected_weight_format); -+ bool is_fast_math_kernel = arm_compute::is_fixed_format_fast_math(expected_weight_format); ++ // has_opt_impl may return a non fast math kernel, even if requested ++ aip.fc_info.enable_fast_math ++ = arm_compute::is_fixed_format_fast_math( ++ expected_weight_format); + -+ if(!is_fast_math_kernel) { -+ // FP32 kernel might be faster for some cases then BF16 -+ aip.fc_info.enable_fast_math = false; -+ } -+ -+ memory_desc_t want_wei_md = weights_md_; ++ // Inner product is the same as the matmul n x (chw) * (ihw) x o ++ // (note that the src c and weights i both correspond to the input ++ // channel). ACL FullyConnectedLayer assumes the chw dimensions of ++ // src and ihw dimensions of weights are collapsed, so we need to ++ // make sure that they have the same layout. Given that weights are ++ // more often fixed, (so reorders can be hoisted) it makes sense to ++ // reorder the weights to fit the src. + -+ int ic_multiply = ic; -+ if(is_4d) { -+ ic_multiply = ic * kh * kw; ++ // For 4D tensors we need to: ++ // - reorder the ihw of the weights to match the src chw ++ // - collapse ihw ++ // - pad the collapsed ihw ++ // But there is not yet a way to express this collapse+pad as a ++ // reorder. So we try to reorder the weights to match the src, ++ // implicitly collapse ihw in our definition of the weights ++ // TensorInfo and hope that the inner_dim has zero padding ++ // (weights_md_.dims[inner_dim] % block_by == 0). If it does, we ++ // fall back to a kernel without blocking (currently this is ++ // equivalent to non-fastmath). + -+ // Since we are flattening dimensions the memory descriptor -+ // should also be for 2D -+ want_wei_md.ndims = 2; ++ // 2D just works because we just pad the only dimension. + -+ want_wei_md.dims[1] = ic_multiply; -+ want_wei_md.padded_dims[1] = ic_multiply; -+ want_wei_md.format_desc.blocking.strides[1] = 1; ++ // o_dim is always the first logical dimension (oihw, ohwi, oi) ++ dim_t o_dim = 0; ++ dim_t inner_dim; ++ // Rest of logical dimensions in order of innermost to outermost ++ std::vector remaining_dims = {}; + -+ want_wei_md.dims[0] = oc; -+ want_wei_md.padded_dims[0] = want_wei_md.padded_dims[1]; -+ want_wei_md.padded_dims[0] = oc; ++ if (src_tag == nchw) { ++ inner_dim = 3; // w ++ remaining_dims = {2, 1}; // h, i ++ } else if (src_tag == nhwc) { ++ inner_dim = 1; // i ++ remaining_dims = {3, 2}; // w, h ++ } else { // Only remaining case is 2D (nc) ++ inner_dim = 1; // i ++ remaining_dims = {}; // No other dimensions for 2D + } + -+ want_wei_md.format_desc.blocking.strides[1] = interleaved_by * block_by; -+ if(want_wei_md.dims[1] % block_by != 0) { -+ want_wei_md.padded_dims[1] = utils::div_up(want_wei_md.dims[1], block_by) * block_by; -+ } -+ want_wei_md.format_desc.blocking.strides[0] = interleaved_by * want_wei_md.padded_dims[1]; -+ -+ if(oc % interleaved_by != 0) { -+ int padded_dim = utils::div_up(oc, interleaved_by) * interleaved_by; -+ want_wei_md.padded_dims[0] = padded_dim; -+ } -+ -+ int data_type_size = memory_desc_wrapper(want_wei_md).data_type_size(); -+ acl_utils::update_strides_y_and_z( -+ aip.wei_info, -+ want_wei_md.format_desc.blocking.strides[0] * data_type_size, -+ want_wei_md.format_desc.blocking.strides[1] * data_type_size); -+ -+ want_wei_md.format_desc.blocking.inner_nblks = (block_by > 1) + 1; -+ want_wei_md.format_desc.blocking.inner_idxs[0] = 0; -+ want_wei_md.format_desc.blocking.inner_blks[0] = interleaved_by; -+ if(block_by > 1) { -+ want_wei_md.format_desc.blocking.inner_idxs[1] = 1; -+ want_wei_md.format_desc.blocking.inner_blks[1] = block_by; -+ } -+ -+ if(is_fast_math_kernel) { -+ want_wei_md.data_type = dnnl_bf16; ++ // Fallback ++ int block_by = arm_compute::block_by(expected_weight_format); ++ if (is_4d && weights_md_.dims[inner_dim] % block_by != 0 ++ && aip.fc_info.enable_fast_math) { ++ aip.fc_info.enable_fast_math = false; ++ aip.weights_info.set_weight_format( ++ arm_compute::WeightFormat::ANY); ++ ACL_CHECK_VALID( ++ arm_compute::NEFullyConnectedLayer::has_opt_impl( ++ expected_weight_format, &aip.src_tensor_info, ++ &aip.wei_tensor_info, ++ aip.with_bias ? &aip.bia_tensor_info : nullptr, ++ &aip.dst_tensor_info, aip.fc_info, ++ aip.weights_info)); ++ aip.weights_info.set_weight_format(expected_weight_format); ++ block_by = arm_compute::block_by(expected_weight_format); ++ if (weights_md_.dims[inner_dim] % block_by != 0) ++ return status::unimplemented; + } + -+ weights_md_ = want_wei_md; ++ acl_utils::reorder_to_weight_format(aip.wei_tensor_info, ++ weights_md_, expected_weight_format, inner_dim, o_dim, ++ remaining_dims, {}); + // clang-format off ++ // Validate fully connected layer manually to check for return status ACL_CHECK_VALID(arm_compute::NEFullyConnectedLayer::validate( +- &aip.src_info, +- &aip.wei_info, +- aip.with_bias ? &aip.bia_info : nullptr, +- &aip.dst_info, +- aip.fc_info)); ++ &aip.src_tensor_info, ++ &aip.wei_tensor_info, ++ aip.with_bias ? &aip.bia_tensor_info : nullptr, ++ &aip.dst_tensor_info, ++ aip.fc_info, ++ aip.weights_info)); + // clang-format on ++ + return status::success; + } + }; // pd_t diff --git a/src/cpu/aarch64/acl_utils.cpp b/src/cpu/aarch64/acl_utils.cpp -index 79ea775d6..7ee4c7398 100644 +index 79ea775d6d..5792fd4911 100644 --- a/src/cpu/aarch64/acl_utils.cpp +++ b/src/cpu/aarch64/acl_utils.cpp -@@ -157,6 +157,28 @@ status_t tensor_info( - return status::success; +@@ -1,5 +1,5 @@ + /******************************************************************************* +-* Copyright 2021-2022 Arm Ltd. and affiliates ++* Copyright 2021-2023 Arm Ltd. and affiliates + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. +@@ -261,6 +261,75 @@ int reorder_dimensions_by_stride(std::vector permuted_mds, + return reordered_dims; } -+status_t update_strides_y_and_z( -+ arm_compute::TensorInfo &info, const int y, const int z) { ++void reorder_to_weight_format(arm_compute::TensorInfo &info, memory_desc_t &md, ++ arm_compute::WeightFormat wf, dim_t I_dim, dim_t O_dim, ++ std::vector spatial_dims, std::vector batch_dims) { ++ ++ md.format_kind = format_kind::blocked; ++ md.format_desc.blocking = blocking_desc_t {}; ++ const int interleaved_by = arm_compute::interleave_by(wf); ++ const int block_by = arm_compute::block_by(wf); + -+ arm_compute::TensorShape shape = info.tensor_shape(); -+ arm_compute::Strides old_strides_in_bytes = info.strides_in_bytes(); ++ // I dimension becomes densest (apart from blocking) ++ md.format_desc.blocking.strides[I_dim] = interleaved_by * block_by; ++ md.padded_dims[I_dim] = utils::rnd_up(md.dims[I_dim], block_by); + -+ arm_compute::Strides new_strides_in_bytes; -+ for(size_t i = 0; i < shape.num_dimensions(); ++i) { -+ new_strides_in_bytes.set(i, old_strides_in_bytes[i]); ++ // Then any spatial dimensions (e.g. HW) ++ dim_t ldb = interleaved_by * md.padded_dims[I_dim]; ++ for (dim_t sd : spatial_dims) { ++ md.format_desc.blocking.strides[sd] = ldb; ++ ldb *= md.padded_dims[sd]; + } + -+ // set y -+ new_strides_in_bytes.set(1, y); -+ // set z -+ new_strides_in_bytes.set(2, z); ++ // O dim (which was the innermost) becomes the outermost (apart from batching) ++ md.format_desc.blocking.strides[O_dim] = ldb; ++ md.padded_dims[O_dim] = utils::rnd_up(md.dims[O_dim], interleaved_by); + -+ info.init(info.tensor_shape(), info.num_channels(), info.data_type(), -+ new_strides_in_bytes, info.offset_first_element_in_bytes(), info.total_size()); ++ // Update the batch dimensions, starting with stride of the innermost batch ++ const dim_t innermost_batch_stride ++ = md.padded_dims[I_dim] * md.padded_dims[O_dim]; ++ dim_t batch_stride = innermost_batch_stride; ++ for (dim_t bd : batch_dims) { ++ md.format_desc.blocking.strides[bd] = batch_stride; ++ batch_stride *= md.padded_dims[bd]; ++ } ++ ++ // Weights can only be blocked if they are also interleaved ++ if (interleaved_by > 1) { ++ md.format_desc.blocking.inner_nblks = 1 + (block_by > 1); ++ ++ md.format_desc.blocking.inner_idxs[0] = O_dim; ++ md.format_desc.blocking.inner_blks[0] = interleaved_by; ++ if (block_by > 1) { ++ md.format_desc.blocking.inner_idxs[1] = I_dim; ++ md.format_desc.blocking.inner_blks[1] = block_by; ++ } ++ } ++ ++ if (arm_compute::is_fixed_format_fast_math(wf)) { ++ md.data_type = dnnl_bf16; ++ info.set_data_type(arm_compute::DataType::BFLOAT16); ++ } ++ ++ // The data layout is now determined by the manually set strides ++ info.set_data_layout(arm_compute::DataLayout::UNKNOWN); + -+ return status::success; ++ // x is ignored in fixed format kernels ++ // y is the leading dimension of b (ldb) in the GEMM d = a*b + c ++ // This is the stride of O_dim in the md ++ // z is the batch dimension (not strictly needed if there's only 1 batch) ++ // i.e. how much do I need to stride to get to the next matmul (ignoring ++ // the interleaving). Note that we use the innermost_batch_stride ++ // because all the batched dimensions are collapsed (as required by ACL). ++ arm_compute::Strides new_strides_in_bytes = info.strides_in_bytes(); ++ new_strides_in_bytes.set(1, ldb * info.element_size()); ++ new_strides_in_bytes.set(2, innermost_batch_stride * info.element_size()); ++ ++ info.init(info.tensor_shape(), info.num_channels(), info.data_type(), ++ new_strides_in_bytes, info.offset_first_element_in_bytes(), ++ memory_desc_wrapper(md).size()); +} + - status_t insert_singleton_dimension(arm_compute::TensorInfo &ti, size_t dim_i) { + } // namespace acl_utils - // Max 6 dims in ACL, so we can't insert another + } // namespace aarch64 diff --git a/src/cpu/aarch64/acl_utils.hpp b/src/cpu/aarch64/acl_utils.hpp -index 28693bb16..c7c9e1278 100644 +index 28693bb167..d9affe1c8f 100644 --- a/src/cpu/aarch64/acl_utils.hpp +++ b/src/cpu/aarch64/acl_utils.hpp -@@ -62,6 +62,9 @@ status_t tensor_info(arm_compute::TensorInfo &info, const memory_desc_t &md); - status_t tensor_info( - arm_compute::TensorInfo &info, const memory_desc_wrapper &md); +@@ -1,5 +1,5 @@ + /******************************************************************************* +-* Copyright 2021-2022 Arm Ltd. and affiliates ++* Copyright 2021-2023 Arm Ltd. and affiliates + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. +@@ -74,6 +74,28 @@ status_t insert_singleton_dimension(arm_compute::TensorInfo &ti, size_t dim_i); + int reorder_dimensions_by_stride(std::vector permuted_mds, + std::vector mds); -+// Update y and z strides in arm_compute::TensorInfo -+status_t update_strides_y_and_z(arm_compute::TensorInfo &info, const int y, const int z); ++// Reorder a memory_desc_t and set the strides on a arm_compute::TensorInfo to ++// match an arm_compute::WeightFormat. You are required to specify how various ++// logical dimensions in oneDNN correspond to logical dimensions in arm_compute. ++// info TensorInfo where the strides will be changed to match the reordering ++// md memory descriptor where the stride and padded dimensions will be ++// changed or reordering ++// wf Describes the memory format/layout of the weights ++// I_dim The logical dimension of md corresponding to the input channel of ++// a convolution or the K dimension in a matmul ++// O_dim The logical dimension of md corresponding to the output channel of a ++//   convolution or the N dimension in a matmul ++// spatial_dims The logical dimensions of md corresponding to the spatial ++// dimensions of the weights (H, W, D for example). These will be ++// the next densest after the inner blocks and the input channel. ++// batch_dims The logical dimensions of md related to the batch in a batched ++// matmul, ordered from innermost to outermost. ACL calls these ++// the multi_stride_b. These will become the outermost (least dense) ++// dimensions and will be collapsed. ++void reorder_to_weight_format(arm_compute::TensorInfo &info, memory_desc_t &md, ++ arm_compute::WeightFormat wf, dim_t I_dim, dim_t O_dim, ++ std::vector spatial_dims, std::vector batch_dims = {}); + - // Insert a dimension of size 1 at the index dim_i of TensorInfo - status_t insert_singleton_dimension(arm_compute::TensorInfo &ti, size_t dim_i); + // Logs a custom 'info' line describing an unsupported case + #define LOG_ACL_UNSUPPORTED(msg) \ + do { \ +diff --git a/src/cpu/aarch64/matmul/acl_matmul.cpp b/src/cpu/aarch64/matmul/acl_matmul.cpp +index dce220fb6e..ca1c7eb47e 100644 +--- a/src/cpu/aarch64/matmul/acl_matmul.cpp ++++ b/src/cpu/aarch64/matmul/acl_matmul.cpp +@@ -1,5 +1,5 @@ + /******************************************************************************* +-* Copyright 2021-2022 Arm Ltd. and affiliates ++* Copyright 2021-2023 Arm Ltd. and affiliates + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. +@@ -31,36 +31,19 @@ status_t acl_matmul_t::execute_forward(const exec_ctx_t &ctx) const { + auto wei_base = CTX_IN_MEM(const data_t *, DNNL_ARG_WEIGHTS); + + bool is_transA = pd()->amp_.is_transA; +- bool is_transB = pd()->amp_.is_transB; + bool use_dst_acc = pd()->amp_.use_dst_acc; + + std::lock_guard _lock {this->mtx}; + auto *acl_resource = ctx.get_resource_mapper()->get(this); + acl_matmul_obj_t &acl_obj = acl_resource->get_acl_obj(); + // Run transpose kernel +- if (is_transA && !is_transB) { ++ if (is_transA) { + acl_obj.src_tensor.allocator()->allocate(); + acl_obj.src_acc_tensor.allocator()->import_memory( + const_cast(src_base)); + acl_obj.transA.run(); + acl_obj.wei_tensor.allocator()->import_memory( + const_cast(wei_base)); +- } else if (is_transB && !is_transA) { +- acl_obj.wei_tensor.allocator()->allocate(); +- acl_obj.wei_acc_tensor.allocator()->import_memory( +- const_cast(wei_base)); +- acl_obj.transB.run(); +- acl_obj.src_tensor.allocator()->import_memory( +- const_cast(src_base)); +- } else if (is_transA && is_transB) { +- acl_obj.src_tensor.allocator()->allocate(); +- acl_obj.src_acc_tensor.allocator()->import_memory( +- const_cast(src_base)); +- acl_obj.wei_tensor.allocator()->allocate(); +- acl_obj.wei_acc_tensor.allocator()->import_memory( +- const_cast(wei_base)); +- acl_obj.transA.run(); +- acl_obj.transB.run(); + } else { + acl_obj.src_tensor.allocator()->import_memory( + const_cast(src_base)); +@@ -69,7 +52,7 @@ status_t acl_matmul_t::execute_forward(const exec_ctx_t &ctx) const { + } + + if (use_dst_acc) { +- // Put the result in a new tensor, it will be accumalated to the dst ++ // Put the result in a new tensor, it will be accumulated to the dst + // during the post ops + acl_obj.dst_tensor.allocator()->allocate(); + } else { +@@ -82,7 +65,6 @@ status_t acl_matmul_t::execute_forward(const exec_ctx_t &ctx) const { + acl_obj.src_tensor.allocator()->free(); + acl_obj.wei_tensor.allocator()->free(); + if (is_transA) acl_obj.src_acc_tensor.allocator()->free(); +- if (is_transB) acl_obj.wei_acc_tensor.allocator()->free(); + void *dst = acl_obj.dst_tensor.buffer(); + pd()->post_ops.execute(ctx, dst); +diff --git a/src/cpu/aarch64/matmul/acl_matmul.hpp b/src/cpu/aarch64/matmul/acl_matmul.hpp +index cdc942e995..832b1dbb68 100644 +--- a/src/cpu/aarch64/matmul/acl_matmul.hpp ++++ b/src/cpu/aarch64/matmul/acl_matmul.hpp +@@ -32,20 +32,15 @@ struct acl_resource_t : public resource_t { + + status_t configure(const acl_matmul_conf_t &) { + if (!acl_obj_) return status::out_of_memory; +- acl_obj_->src_tensor.allocator()->init(amp.src_info); +- acl_obj_->wei_tensor.allocator()->init(amp.wei_info); +- acl_obj_->dst_tensor.allocator()->init(amp.dst_info); ++ acl_obj_->src_tensor.allocator()->init(amp.src_tensor_info); ++ acl_obj_->wei_tensor.allocator()->init(amp.wei_tensor_info); ++ acl_obj_->dst_tensor.allocator()->init(amp.dst_tensor_info); + // Configure transpose kernel for src, wei or both + if (amp.is_transA) { + acl_obj_->src_acc_tensor.allocator()->init(amp.src_acc_info); + acl_obj_->transA.configure( + &acl_obj_->src_acc_tensor, &acl_obj_->src_tensor); + } +- if (amp.is_transB) { +- acl_obj_->wei_acc_tensor.allocator()->init(amp.wei_acc_info); +- acl_obj_->transB.configure( +- &acl_obj_->wei_acc_tensor, &acl_obj_->wei_tensor); +- } + // Configure GEMM + acl_obj_->gemm.configure(&acl_obj_->src_tensor, &acl_obj_->wei_tensor, + nullptr, &acl_obj_->dst_tensor, amp.alpha, 0.0f, amp.gemm_info); +@@ -72,12 +67,20 @@ struct acl_matmul_t : public primitive_t { + + status_t init(engine_t *engine) { + using smask_t = primitive_attr_t::skip_mask_t; +- bool ok = src_md()->data_type == data_type::f32 +- && weights_md()->data_type == data_type::f32 +- && desc()->accum_data_type == data_type::f32 +- && dst_md()->data_type == data_type::f32 +- && platform::has_data_type_support(data_type::f32) ++ const bool is_fp32_ok ++ = utils::everyone_is(data_type::f32, src_md()->data_type, ++ weights_md()->data_type, dst_md()->data_type, ++ desc()->accum_data_type) ++ && platform::has_data_type_support(data_type::f32); ++ const bool is_fp16_ok ++ = utils::everyone_is(data_type::f16, src_md()->data_type, ++ weights_md()->data_type, dst_md()->data_type) ++ && platform::has_data_type_support(data_type::f16); ++ bool ok = is_dense_data() ++ && utils::one_of(true, is_fp32_ok, is_fp16_ok) + && !has_zero_dim_memory() ++ && set_default_formats() + && attr()->has_default_values( + smask_t::oscale | smask_t::post_ops) + && attr_oscale_ok() && !has_runtime_dims_or_strides(); +@@ -92,9 +95,9 @@ struct acl_matmul_t : public primitive_t { + amp_.use_dst_acc = post_ops.has_sum(); + + // Validate ACL GEMM +- ACL_CHECK_VALID(arm_compute::NEGEMM::validate(&_.src_info, +- &_.wei_info, nullptr, &_.dst_info, amp_.alpha, 0.0f, +- amp_.gemm_info)); ++ ACL_CHECK_VALID(arm_compute::NEGEMM::validate(&_.src_tensor_info, ++ &_.wei_tensor_info, nullptr, &_.dst_tensor_info, ++ amp_.alpha, 0.0f, amp_.gemm_info)); + + return status::success; + } diff --git a/src/cpu/aarch64/matmul/acl_matmul_utils.cpp b/src/cpu/aarch64/matmul/acl_matmul_utils.cpp -index 679baec3a..853277e37 100644 +index 679baec3a4..30bc2c1443 100644 --- a/src/cpu/aarch64/matmul/acl_matmul_utils.cpp +++ b/src/cpu/aarch64/matmul/acl_matmul_utils.cpp -@@ -66,15 +66,12 @@ status_t init_conf_matmul(acl_matmul_conf_t &, memory_desc_t &src_md, +@@ -41,6 +41,7 @@ status_t init_conf_matmul(acl_matmul_conf_t &, memory_desc_t &src_md, + const dim_t src_batch = helper.src_batch(); + const dim_t wei_batch = helper.wei_batch(); + ++ // We can only broadcast on one of src or wei at once + // ACL supports broadcast for 3D shapes, and 4D shapes + // for e.g when ab in abcd is 1x1 + bool batch_ok = IMPLICATION(src_batch > 1, wei_batch == 1) +@@ -53,44 +54,33 @@ status_t init_conf_matmul(acl_matmul_conf_t &, memory_desc_t &src_md, + bool with_bias = md.bias_desc.format_kind != format_kind::undef; + ACL_CHECK_SUPPORT(with_bias, "ACL does not support bias for matmul"); - // Transpose A (src) or B (wei) ++ // The two innermost dimensions can be transposed, but the batch dimensions ++ // must be the outermost + using namespace format_tag; + auto src_tag = memory_desc_matches_one_of_tag( + src_md, abcd, abdc, abc, acb, ab, ba); +- auto wei_tag = memory_desc_matches_one_of_tag( +- wei_md, abcd, abdc, abc, acb, ab, ba); +- auto dst_tag +- = memory_desc_matches_one_of_tag(dst_md, abcd, abc, acb, ab, ba); +- ACL_CHECK_SUPPORT( +- utils::one_of(format_tag::undef, src_tag, wei_tag, dst_tag), ++ auto dst_tag = memory_desc_matches_one_of_tag(dst_md, abcd, abc, ab, ba); ++ ACL_CHECK_SUPPORT(utils::one_of(format_tag::undef, src_tag, dst_tag), + "Format tag is undefined"); + +- // Transpose A (src) or B (wei) ++ // Transpose A (src) amp.is_transA = helper.transA() == 'T'; - amp.is_transB = helper.transB() == 'T'; -+ amp.is_transB = false; ++ ++ auto acl_src_data_t = acl_utils::get_acl_data_t(src_md.data_type); ++ auto acl_wei_data_t = acl_utils::get_acl_data_t(wei_md.data_type); ++ auto acl_dst_data_t = acl_utils::get_acl_data_t(dst_md.data_type); + if (amp.is_transA) amp.src_acc_info = arm_compute::TensorInfo( arm_compute::TensorShape(M, K, 1, src_batch), 1, - arm_compute::DataType::F32); +- arm_compute::DataType::F32); - if (amp.is_transB) - amp.wei_acc_info = arm_compute::TensorInfo( - arm_compute::TensorShape(K, N, wei_batch), 1, - arm_compute::DataType::F32); +- +- amp.src_info = arm_compute::TensorInfo( +- arm_compute::TensorShape(K, M, 1, src_batch), 1, +- arm_compute::DataType::F32); +- amp.wei_info +- = arm_compute::TensorInfo(arm_compute::TensorShape(N, K, wei_batch), +- 1, arm_compute::DataType::F32); +- amp.dst_info = arm_compute::TensorInfo( +- arm_compute::TensorShape(N, M, 1, dst_batch), 1, +- arm_compute::DataType::F32); +- +- // Fast-math mode +- auto math_mode = get_fpmath_mode(); +- bool is_fastmath_enabled +- = utils::one_of(math_mode, fpmath_mode::bf16, fpmath_mode::any); +- amp.gemm_info.set_fast_math(is_fastmath_enabled); ++ acl_src_data_t); ++ ++ amp.src_tensor_info = arm_compute::TensorInfo( ++ arm_compute::TensorShape(K, M, 1, src_batch), 1, acl_src_data_t); ++ amp.wei_tensor_info = arm_compute::TensorInfo( ++ arm_compute::TensorShape(N, K, wei_batch), 1, acl_wei_data_t); ++ amp.dst_tensor_info = arm_compute::TensorInfo( ++ arm_compute::TensorShape(N, M, 1, dst_batch), 1, acl_dst_data_t); - amp.src_info = arm_compute::TensorInfo( - arm_compute::TensorShape(K, M, 1, src_batch), 1, -@@ -103,6 +100,140 @@ status_t init_conf_matmul(acl_matmul_conf_t &, memory_desc_t &src_md, + // Set alpha (output scaling) + amp.alpha = attr.output_scales_.scales_[0]; +@@ -98,10 +88,45 @@ status_t init_conf_matmul(acl_matmul_conf_t &, memory_desc_t &src_md, + // Validate ACL transpose + if (amp.is_transA) ACL_CHECK_VALID(arm_compute::NETranspose::validate( - &.wei_acc_info, &.wei_info)); - -+ arm_compute::WeightFormat expected_weight_format; +- &.src_acc_info, &.src_info)); +- if (amp.is_transB) +- ACL_CHECK_VALID(arm_compute::NETranspose::validate( +- &.wei_acc_info, &.wei_info)); ++ &.src_acc_info, &.src_tensor_info)); ++ ++ bool is_fastmath_enabled = utils::one_of( ++ attr.fpmath_mode_, fpmath_mode::bf16, fpmath_mode::any); ++ amp.gemm_info.set_fast_math(is_fastmath_enabled); + + amp.gemm_info.set_fixed_format(true); ++ ++ // WeightFormat::ANY tells ACL we can handle any format + amp.gemm_info.set_weight_format(arm_compute::WeightFormat::ANY); + -+ auto acl_st = arm_compute::NEGEMM::has_opt_impl( -+ expected_weight_format, -+ &.src_info, -+ &.wei_info, -+ nullptr, -+ &.dst_info, -+ amp.alpha, -+ 0.0f, -+ amp.gemm_info); -+ -+ if(acl_st.error_code() != arm_compute::ErrorCode::OK) { -+ return status::unimplemented; -+ } ++ // Get the format that the ACL kernel will expect the weights to be ++ // in (if a kernel exists). Note that these are referred to as fixed format ++ // kernels, because they require one specific weights format ++ arm_compute::WeightFormat expected_weight_format; ++ ACL_CHECK_VALID(arm_compute::NEGEMM::has_opt_impl(expected_weight_format, ++ &.src_tensor_info, &.wei_tensor_info, nullptr, ++ &.dst_tensor_info, amp.alpha, 0.0f, amp.gemm_info)); + ++ // Set gemm weights info to the one returned by has_opt_impl + amp.gemm_info.set_weight_format(expected_weight_format); + -+ memory_desc_t want_wei_md = wei_md; -+ -+ // We need to transpose second to last dimension and use blocking -+ // as returned by interleave by from expecting strides -+ int interleaved_by = arm_compute::interleave_by(expected_weight_format); -+ int block_by = arm_compute::block_by(expected_weight_format); -+ bool is_fast_math_kernel = arm_compute::is_fixed_format_fast_math(expected_weight_format); -+ if(!is_fast_math_kernel) { -+ amp.gemm_info.set_fast_math(false); -+ } -+ -+ int blocked_first_dimension = -1; -+ int blocked_second_dimension = -1; -+ -+ // Assume that interleaved by is X and blocked by is Y -+ switch(want_wei_md.ndims) { -+ case 2: { -+ // For 2D case the format that we need to pass is BaXb and -+ // when doing fast mode BAXbYa -+ want_wei_md.format_desc.blocking.strides[0] = interleaved_by * block_by; -+ // check to see whether we need to pad -+ if(want_wei_md.dims[0] % block_by != 0) { -+ want_wei_md.padded_dims[0] = utils::div_up(want_wei_md.dims[0], block_by) * block_by; -+ } -+ want_wei_md.format_desc.blocking.strides[1] = interleaved_by * want_wei_md.padded_dims[0]; -+ if(want_wei_md.dims[1] % interleaved_by != 0) { -+ want_wei_md.padded_dims[1] = utils::div_up(want_wei_md.dims[1], interleaved_by) * interleaved_by; -+ } -+ -+ acl_utils::update_strides_y_and_z( -+ amp.wei_info, -+ want_wei_md.format_desc.blocking.strides[1] * wei_d.data_type_size(), -+ want_wei_md.format_desc.blocking.strides[0] * wei_d.data_type_size()); -+ -+ blocked_first_dimension = 1; -+ blocked_second_dimension = 0; -+ -+ break; -+ } -+ -+ case 3: { -+ // For 3D case the format we need to pass is aCbXc and -+ // when doing fast mode is aCBXcYb -+ want_wei_md.format_desc.blocking.strides[1] = interleaved_by*block_by; -+ if(want_wei_md.dims[1] % block_by != 0) { -+ want_wei_md.padded_dims[1] = utils::div_up(want_wei_md.dims[1], block_by) * block_by; -+ } -+ want_wei_md.format_desc.blocking.strides[2] = interleaved_by * want_wei_md.padded_dims[1]; -+ if(want_wei_md.dims[2] % interleaved_by != 0) { -+ want_wei_md.padded_dims[2] = utils::div_up(want_wei_md.dims[2], interleaved_by) * interleaved_by; -+ } -+ want_wei_md.format_desc.blocking.strides[0] = want_wei_md.padded_dims[2] * want_wei_md.padded_dims[1]; -+ -+ acl_utils::update_strides_y_and_z( -+ amp.wei_info, -+ want_wei_md.format_desc.blocking.strides[2] * wei_d.data_type_size(), -+ want_wei_md.format_desc.blocking.strides[0] * wei_d.data_type_size()); -+ -+ blocked_first_dimension = 2; -+ blocked_second_dimension = 1; -+ -+ break; -+ } -+ -+ case 4: { -+ // For 4D case the format we need to pass is abDcXd and -+ // when doing fast mode is abDCxdYc -+ int D_padded = want_wei_md.dims[3]; -+ if(D_padded % interleaved_by != 0) { -+ D_padded = utils::div_up(D_padded, interleaved_by) * interleaved_by; -+ want_wei_md.padded_dims[3] = D_padded; -+ } -+ -+ int C_padded = want_wei_md.dims[2]; -+ if(C_padded % block_by != 0) { -+ C_padded = utils::div_up(C_padded, block_by) * block_by; -+ want_wei_md.padded_dims[2] = C_padded; -+ } -+ -+ want_wei_md.format_desc.blocking.strides[0] = want_wei_md.dims[1]*D_padded*C_padded; -+ want_wei_md.format_desc.blocking.strides[1] = D_padded*C_padded; -+ want_wei_md.format_desc.blocking.strides[2] = interleaved_by*block_by; -+ want_wei_md.format_desc.blocking.strides[3] = interleaved_by*C_padded; -+ -+ acl_utils::update_strides_y_and_z( -+ amp.wei_info, -+ want_wei_md.format_desc.blocking.strides[3] * wei_d.data_type_size(), -+ want_wei_md.format_desc.blocking.strides[1] * wei_d.data_type_size()); ++ // has_opt_impl may return a non fast math kernel, even if we requested one ++ amp.gemm_info.set_fast_math( ++ arm_compute::is_fixed_format_fast_math(expected_weight_format)); + -+ blocked_first_dimension = 3; -+ blocked_second_dimension = 2; ++ // Logical dimension indices ++ dim_t innermost_dim = wei_md.ndims - 1; ++ dim_t N_dim = innermost_dim; ++ dim_t K_dim = innermost_dim - 1; + -+ break; -+ } -+ -+ default: -+ return status::unimplemented; -+ } -+ -+ want_wei_md.format_desc.blocking.inner_nblks = (block_by > 1) + 1; -+ want_wei_md.format_desc.blocking.inner_idxs[0] = blocked_first_dimension; -+ want_wei_md.format_desc.blocking.inner_blks[0] = interleaved_by; -+ if(block_by > 1) { -+ want_wei_md.format_desc.blocking.inner_idxs[1] = blocked_second_dimension; -+ want_wei_md.format_desc.blocking.inner_blks[1] = block_by; -+ } -+ -+ if(is_fast_math_kernel) { -+ want_wei_md.data_type = dnnl_bf16; -+ } -+ -+ wei_md = want_wei_md; ++ // The logical indices of dimensions related to the batch, ordered from ++ // innermost to outermost ++ std::vector batch_dims = {}; ++ for (dim_t i = K_dim - 1; i >= 0; --i) ++ batch_dims.push_back(i); + ++ acl_utils::reorder_to_weight_format(amp.wei_tensor_info, wei_md, ++ expected_weight_format, K_dim, N_dim, {}, batch_dims); + return status::success; } +diff --git a/src/cpu/aarch64/matmul/acl_matmul_utils.hpp b/src/cpu/aarch64/matmul/acl_matmul_utils.hpp +index 0a5ee6a987..67bb2e78eb 100644 +--- a/src/cpu/aarch64/matmul/acl_matmul_utils.hpp ++++ b/src/cpu/aarch64/matmul/acl_matmul_utils.hpp +@@ -1,5 +1,5 @@ + /******************************************************************************* +-* Copyright 2021-2022 Arm Ltd. and affiliates ++* Copyright 2021-2023 Arm Ltd. and affiliates + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. +@@ -29,25 +29,21 @@ namespace aarch64 { + struct acl_matmul_obj_t { + arm_compute::NEGEMM gemm; + arm_compute::NETranspose transA; +- arm_compute::NETranspose transB; + arm_compute::Tensor src_tensor; + arm_compute::Tensor src_acc_tensor; + arm_compute::Tensor wei_tensor; +- arm_compute::Tensor wei_acc_tensor; + arm_compute::Tensor dst_tensor; + }; + struct acl_matmul_conf_t { + bool is_transA; +- bool is_transB; + // If this is true, the result of the matmul goes into a temporarily + // allocated ACL tensor to be accumulated into the oneDNN dst during postops + bool use_dst_acc; +- arm_compute::TensorInfo src_info; ++ arm_compute::TensorInfo src_tensor_info; + arm_compute::TensorInfo src_acc_info; +- arm_compute::TensorInfo wei_info; +- arm_compute::TensorInfo wei_acc_info; +- arm_compute::TensorInfo dst_info; ++ arm_compute::TensorInfo wei_tensor_info; ++ arm_compute::TensorInfo dst_tensor_info; + arm_compute::GEMMInfo gemm_info; + float alpha; + }; diff --git a/third_party/mkl_dnn/onednn_acl_remove_winograd.patch b/third_party/mkl_dnn/onednn_acl_remove_winograd.patch new file mode 100644 index 00000000000000..18abcc8f54e922 --- /dev/null +++ b/third_party/mkl_dnn/onednn_acl_remove_winograd.patch @@ -0,0 +1,326 @@ + ******************************************************************************* + Copyright 2023 Arm Limited and affiliates. + SPDX-License-Identifier: Apache-2.0 + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ******************************************************************************* +diff --git a/src/cpu/aarch64/acl_convolution_utils.cpp b/src/cpu/aarch64/acl_convolution_utils.cpp +index c46d697575..37f8ecbc06 100644 +--- a/src/cpu/aarch64/acl_convolution_utils.cpp ++++ b/src/cpu/aarch64/acl_convolution_utils.cpp +@@ -271,54 +271,6 @@ status_t init_conf_indirect_gemm(acl_conv_conf_t &acp, memory_desc_t &src_md, + return status::success; + } + +-status_t init_conf_wino(acl_conv_conf_t &acp, memory_desc_t &src_md, +- memory_desc_t &weights_md, memory_desc_t &dst_md, +- memory_desc_t &bias_md, const convolution_desc_t &cd, +- const primitive_attr_t &attr) { +- +- // Under these conditions, fallback to faster GEMM-based convolution +- // unless the user explicitly specifies Winograd algorithm +- // clang-format off +- if (one_of(true, src_md.dims[2] > 112, // ih +- src_md.dims[3] > 112, // iw +- src_md.dims[1] < 64, // ic +- dst_md.dims[1] < 64, // oc +- dnnl_get_max_threads() > 28) +- && cd.alg_kind == alg_kind::convolution_auto) { +- return status::unimplemented; +- } +- // clang-format on +- +- // General Compute Library checks, memory tags are also set there +- CHECK(acl_init_conf(acp, src_md, weights_md, dst_md, bias_md, cd, attr)); +- +- const bool shape_ok +- // only unit strides allowed +- = (acp.padstride_info.stride() == std::pair {1, 1}) +- // Note: Compute Library supports arbitrary padding for wino kernels +- // but we only allow small padding to be consistent with oneDNN +- && (acp.padstride_info.pad().first <= 1) // padding left/right +- && (acp.padstride_info.pad().second <= 1) // padding top/bottom +- // only non-dilated convolutions allowed +- && (acp.dilation_info == arm_compute::Size2D(1, 1)); +- +- ACL_CHECK_SUPPORT(!shape_ok, "shape not supported by winograd kernels"); +- +- // clang-format off +- // Validate convolution manually to check for return status +- ACL_CHECK_VALID(arm_compute::NEWinogradConvolutionLayer::validate( +- &acp.src_info, +- &acp.wei_info, +- acp.with_bias ? &acp.bia_info : nullptr, +- &acp.dst_info, +- acp.padstride_info, +- acp.act_info, +- true)); // enable_fast_math flag in ACL Winograd +- // clang-format on +- +- return status::success; +-} +- + } // namespace acl_convolution_utils + + } // namespace aarch64 +diff --git a/src/cpu/aarch64/acl_convolution_utils.hpp b/src/cpu/aarch64/acl_convolution_utils.hpp +index 3e56245faf..0398ab06b9 100644 +--- a/src/cpu/aarch64/acl_convolution_utils.hpp ++++ b/src/cpu/aarch64/acl_convolution_utils.hpp +@@ -66,11 +66,6 @@ status_t init_conf_indirect_gemm(acl_conv_conf_t &acp, memory_desc_t &src_md, + memory_desc_t &bias_md, const convolution_desc_t &cd, + const primitive_attr_t &attr); + +-status_t init_conf_wino(acl_conv_conf_t &acp, memory_desc_t &src_md, +- memory_desc_t &weights_md, memory_desc_t &dst_md, +- memory_desc_t &bias_md, const convolution_desc_t &cd, +- const primitive_attr_t &attr); +- + } // namespace acl_convolution_utils + + template _lock {this->mtx}; +- // Retrieve primitive resource and configured Compute Library objects +- auto *acl_resource +- = ctx.get_resource_mapper()->get(this); +- acl_obj_t &acl_wino_obj +- = acl_resource->get_acl_obj(); +- +- return execute_forward_conv_acl< +- acl_obj_t, pd_t, data_t>( +- ctx, acl_wino_obj, pd()); +-} +- +-} // namespace aarch64 +-} // namespace cpu +-} // namespace impl +-} // namespace dnnl +diff --git a/src/cpu/aarch64/acl_winograd_convolution.hpp b/src/cpu/aarch64/acl_winograd_convolution.hpp +deleted file mode 100644 +index 215635fe3f..0000000000 +--- a/src/cpu/aarch64/acl_winograd_convolution.hpp ++++ /dev/null +@@ -1,146 +0,0 @@ +-/******************************************************************************* +-* Copyright 2020-2022 Arm Ltd. and affiliates +-* +-* Licensed under the Apache License, Version 2.0 (the "License"); +-* you may not use this file except in compliance with the License. +-* You may obtain a copy of the License at +-* +-* http://www.apache.org/licenses/LICENSE-2.0 +-* +-* Unless required by applicable law or agreed to in writing, software +-* distributed under the License is distributed on an "AS IS" BASIS, +-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-* See the License for the specific language governing permissions and +-* limitations under the License. +-*******************************************************************************/ +- +-#ifndef CPU_AARCH64_ACL_WINOGRAD_CONVOLUTION_HPP +-#define CPU_AARCH64_ACL_WINOGRAD_CONVOLUTION_HPP +- +-#include "cpu/cpu_convolution_pd.hpp" +- +-#include "cpu/aarch64/acl_convolution_utils.hpp" +- +-namespace dnnl { +-namespace impl { +-namespace cpu { +-namespace aarch64 { +- +-struct acl_wino_resource_t : public resource_t { +- acl_wino_resource_t() +- : acl_wino_obj_(utils::make_unique< +- acl_obj_t>()) {} +- +- status_t configure(const acl_conv_conf_t &acp) { +- if (!acl_wino_obj_) return status::out_of_memory; +- +- // Init Compute Library tensors based on info from descriptor +- acl_wino_obj_->src_tensor.allocator()->init(acp.src_info); +- acl_wino_obj_->wei_tensor.allocator()->init(acp.wei_info); +- acl_wino_obj_->dst_tensor.allocator()->init(acp.dst_info); +- acl_wino_obj_->bia_tensor.allocator()->init(acp.bia_info); +- +- // clang-format off +- acl_wino_obj_->conv.configure( +- &acl_wino_obj_->src_tensor, +- &acl_wino_obj_->wei_tensor, +- acp.with_bias ? &acl_wino_obj_->bia_tensor : nullptr, +- &acl_wino_obj_->dst_tensor, +- acp.padstride_info, +- acp.act_info, +- true); // to support 5x5, 7x7 filter shapes in addition to 3x3 +- // clang-format on +- +- return status::success; +- } +- +- acl_obj_t &get_acl_obj() const { +- return *acl_wino_obj_; +- } +- +- DNNL_DISALLOW_COPY_AND_ASSIGN(acl_wino_resource_t); +- +-private: +- std::unique_ptr> +- acl_wino_obj_; +-}; // acl_wino_resource_t +- +-struct acl_wino_convolution_fwd_t : public primitive_t { +- struct pd_t : public cpu_convolution_fwd_pd_t { +- pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, +- const typename pd_t::base_class *hint_fwd_pd) +- : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd) +- , acp_() +- , post_ops() {} +- +- DECLARE_COMMON_PD_T( +- "wino:acl", acl_wino_convolution_fwd_t, USE_GLOBAL_SCRATCHPAD); +- +- status_t init(engine_t *engine) { +- bool ok = is_fwd() +- && utils::one_of(desc()->alg_kind, +- alg_kind::convolution_auto, +- alg_kind::convolution_winograd) +- && expect_data_types(data_type::f32, data_type::f32, +- data_type::f32, data_type::f32, data_type::f32) +- && attr()->has_default_values( +- primitive_attr_t::skip_mask_t::post_ops, +- data_type::f32) +- && !has_zero_dim_memory(); +- if (!ok) return status::unimplemented; +- +- CHECK(acl_convolution_utils::init_conf_wino(acp_, src_md_, +- weights_md_, dst_md_, bias_md_, *desc(), *attr())); +- +- set_default_alg_kind(alg_kind::convolution_winograd); +- +- CHECK(post_ops.init( +- engine, attr_.post_ops_, dst_md_, acp_.act_info)); +- acp_.use_dst_acc = post_ops.has_sum(); +- +- return status::success; +- } +- +- acl_conv_conf_t acp_; +- acl_post_ops_t post_ops; +- }; +- +- acl_wino_convolution_fwd_t(const pd_t *apd) : primitive_t(apd) {} +- +- status_t create_resource( +- engine_t *engine, resource_mapper_t &mapper) const override { +- if (mapper.has_resource(this)) return status::success; +- +- auto r = utils::make_unique(); +- if (!r) return status::out_of_memory; +- +- // Configure the resource based on information from primitive descriptor +- CHECK(r->configure(pd()->acp_)); +- mapper.add(this, std::move(r)); +- +- CHECK(pd()->post_ops.create_resource(engine, mapper)); +- +- return status::success; +- } +- +- ~acl_wino_convolution_fwd_t() {} +- +- typedef typename prec_traits::type data_t; +- +- status_t execute(const exec_ctx_t &ctx) const override { +- return execute_forward(ctx); +- } +- +-private: +- // To guard the const execute_forward(), the mutex must be 'mutable' +- mutable std::mutex mtx; +- status_t execute_forward(const exec_ctx_t &ctx) const; +- const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } +-}; // acl_wino_convolution_fwd_t +- +-} // namespace aarch64 +-} // namespace cpu +-} // namespace impl +-} // namespace dnnl +- +-#endif // CPU_AARCH64_ACL_WINOGRAD_CONVOLUTION_HPP +diff --git a/src/cpu/cpu_convolution_list.cpp b/src/cpu/cpu_convolution_list.cpp +index 4142dbc7e7..094c73aa36 100644 +--- a/src/cpu/cpu_convolution_list.cpp ++++ b/src/cpu/cpu_convolution_list.cpp +@@ -65,7 +65,6 @@ using namespace dnnl::impl::cpu::x64; + #if DNNL_AARCH64 && DNNL_AARCH64_USE_ACL + #include "cpu/aarch64/acl_gemm_convolution.hpp" + #include "cpu/aarch64/acl_indirect_gemm_convolution.hpp" +-#include "cpu/aarch64/acl_winograd_convolution.hpp" + #endif + using namespace dnnl::impl::cpu::aarch64; + #endif +@@ -100,7 +99,6 @@ const std::map> &impl_list_map() + CPU_INSTANCE_SSE41(jit_sse41_1x1_convolution_fwd_t) + CPU_INSTANCE_AVX2(jit_avx2_convolution_fwd_t) + CPU_INSTANCE_SSE41(jit_sse41_convolution_fwd_t) +- CPU_INSTANCE_AARCH64_ACL(acl_wino_convolution_fwd_t) + CPU_INSTANCE_AARCH64(jit_sve_512_dw_convolution_fwd_t) + CPU_INSTANCE_AARCH64(jit_sve_512_1x1_convolution_fwd_f32_t) + CPU_INSTANCE_AARCH64(jit_sve_512_convolution_fwd_t) +diff --git a/tests/gtests/test_iface_wino_convolution.cpp b/tests/gtests/test_iface_wino_convolution.cpp +index 03861b1de4..2235ceae36 100644 +--- a/tests/gtests/test_iface_wino_convolution.cpp ++++ b/tests/gtests/test_iface_wino_convolution.cpp +@@ -59,9 +59,6 @@ protected: + input_f16.wino_supported = is_gpu; + input_int8.wino_supported = is_cpu && has_avx512_core; + input_f32.backward_supported = is_cpu && impl::dnnl_thr_syncable(); +-#elif DNNL_AARCH64 && DNNL_AARCH64_USE_ACL +- const bool is_cpu = get_test_engine_kind() == engine::kind::cpu; +- input_f32.wino_supported = is_cpu; + #endif + + #else diff --git a/third_party/mkl_dnn/onednn_acl_reorder.patch b/third_party/mkl_dnn/onednn_acl_reorder.patch new file mode 100644 index 00000000000000..71312f6e87b3f0 --- /dev/null +++ b/third_party/mkl_dnn/onednn_acl_reorder.patch @@ -0,0 +1,349 @@ +diff --git a/src/cpu/aarch64/acl_reorder.cpp b/src/cpu/aarch64/acl_reorder.cpp +new file mode 100644 +index 000000000..061751b55 +--- /dev/null ++++ b/src/cpu/aarch64/acl_reorder.cpp +@@ -0,0 +1,52 @@ ++/******************************************************************************* ++* Copyright 2023 Arm Ltd. and affiliates ++* ++* Licensed under the Apache License, Version 2.0 (the "License"); ++* you may not use this file except in compliance with the License. ++* You may obtain a copy of the License at ++* ++* http://www.apache.org/licenses/LICENSE-2.0 ++* ++* Unless required by applicable law or agreed to in writing, software ++* distributed under the License is distributed on an "AS IS" BASIS, ++* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++* See the License for the specific language governing permissions and ++* limitations under the License. ++*******************************************************************************/ ++ ++#include "cpu/aarch64/acl_reorder.hpp" ++ ++namespace dnnl { ++namespace impl { ++namespace cpu { ++namespace aarch64 { ++ ++status_t acl_reorder_fwd_t::execute_forward(const exec_ctx_t &ctx) const { ++ // Lock here is needed because resource_mapper does not support ++ // concurrent multithreaded access. ++ std::lock_guard _lock {this->mtx}; ++ ++ auto src = CTX_IN_MEM(const void *, DNNL_ARG_FROM); ++ auto dst = CTX_OUT_MEM(void *, DNNL_ARG_TO); ++ ++ // Retrieve primitive resource and configured Compute Library objects ++ auto *acl_resource ++ = ctx.get_resource_mapper()->get(this); ++ ++ acl_reorder_obj_t &acl_obj = acl_resource->get_acl_obj(); ++ ++ acl_obj.src_tensor.allocator()->import_memory(const_cast(src)); ++ acl_obj.dst_tensor.allocator()->import_memory(dst); ++ ++ acl_obj.reorder.run(); ++ ++ acl_obj.src_tensor.allocator()->free(); ++ acl_obj.dst_tensor.allocator()->free(); ++ ++ return status::success; ++} ++ ++} // namespace aarch64 ++} // namespace cpu ++} // namespace impl ++} // namespace dnnl +diff --git a/src/cpu/aarch64/acl_reorder.hpp b/src/cpu/aarch64/acl_reorder.hpp +new file mode 100644 +index 000000000..930ccb40e +--- /dev/null ++++ b/src/cpu/aarch64/acl_reorder.hpp +@@ -0,0 +1,257 @@ ++/******************************************************************************* ++* Copyright 2023 Arm Ltd. and affiliates ++* ++* Licensed under the Apache License, Version 2.0 (the "License"); ++* you may not use this file except in compliance with the License. ++* You may obtain a copy of the License at ++* ++* http://www.apache.org/licenses/LICENSE-2.0 ++* ++* Unless required by applicable law or agreed to in writing, software ++* distributed under the License is distributed on an "AS IS" BASIS, ++* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++* See the License for the specific language governing permissions and ++* limitations under the License. ++*******************************************************************************/ ++#ifndef CPU_AARCH64_ACL_REORDER_HPP ++#define CPU_AARCH64_ACL_REORDER_HPP ++ ++#include "cpu/aarch64/acl_utils.hpp" ++#include "cpu/reorder/cpu_reorder_pd.hpp" ++#include "arm_compute/core/Types.h" ++#include "common/utils.hpp" ++ ++namespace dnnl { ++namespace impl { ++namespace cpu { ++namespace aarch64 { ++ ++struct acl_reorder_obj_t { ++ arm_compute::NEReorderLayer reorder; ++ arm_compute::Tensor src_tensor; ++ arm_compute::Tensor dst_tensor; ++ arm_compute::WeightFormat src_wf; ++ arm_compute::WeightFormat dst_wf; ++}; ++ ++struct acl_reorder_conf_t { ++ arm_compute::TensorInfo src_info; ++ arm_compute::TensorInfo dst_info; ++ arm_compute::WeightFormat src_wf; ++ arm_compute::WeightFormat dst_wf; ++}; ++ ++struct acl_reorder_resource_t : public resource_t { ++ acl_reorder_resource_t() : acl_obj_(utils::make_unique()) {} ++ ++ status_t configure(const acl_reorder_conf_t &app) { ++ if (!acl_obj_) return status::out_of_memory; ++ ++ // Init Compute Library tensors based on info from descriptor ++ acl_obj_->src_tensor.allocator()->init(app.src_info); ++ acl_obj_->dst_tensor.allocator()->init(app.dst_info); ++ ++ // clang-format off ++ acl_obj_->reorder.configure( ++ &acl_obj_->src_tensor, ++ &acl_obj_->dst_tensor, ++ app.src_wf, ++ app.dst_wf ++ ); ++ // clang-format on ++ ++ return status::success; ++ } ++ ++ acl_reorder_obj_t &get_acl_obj() const { return *acl_obj_; } ++ DNNL_DISALLOW_COPY_AND_ASSIGN(acl_reorder_resource_t); ++ ++private: ++ std::unique_ptr acl_obj_; ++}; // acl_reorder_resource_t ++ ++struct acl_reorder_fwd_t : public primitive_t { ++ using primitive_t::primitive_t; ++ struct pd_t : public cpu_reorder_pd_t { ++ ++ using cpu_reorder_pd_t::cpu_reorder_pd_t; ++ ++ DECLARE_COMMON_PD_T("acl", acl_reorder_fwd_t); ++ ++ static status_t create(reorder_pd_t **reorder_pd, engine_t *engine, ++ const primitive_attr_t *attr, engine_t *src_engine, ++ const memory_desc_t *src_md, engine_t *dst_engine, ++ const memory_desc_t *dst_md) { ++ ++ using namespace acl_utils; ++ // using skip_mask_t = dnnl_primitive_attr::skip_mask_t; ++ ++ bool ok = src_md->data_type ++ == dst_md->data_type // ACL only supports matching src/dst data types ++ && utils::one_of(src_md->data_type, ++ data_type::f32) // Only supports f32 for now ++ && attr->has_default_values(); ++ if (!ok) return status::unimplemented; ++ ++ int mask = -1; ++ bool is_set = false; ++ // CHECK(attr->scales_.get(DNNL_ARG_DST, &mask, &is_set)); ++ const memory_desc_wrapper input_d(src_md); ++ if (input_d.has_runtime_dims_or_strides() && is_set && mask > 0) ++ return status::unimplemented; ++ ++ // Create and check primitive descriptor ++ auto _pd = new pd_t(attr, src_engine->kind(), src_md, ++ dst_engine->kind(), dst_md); ++ if (_pd == nullptr) return status::out_of_memory; ++ if (_pd->init(engine, src_engine, dst_engine) != status::success) { ++ delete _pd; ++ return status::unimplemented; ++ } ++ ++ const memory_desc_wrapper src_d(*src_md); ++ const memory_desc_wrapper dst_d(*dst_md); ++ ++ const int ndims = src_d.ndims(); ++ ++ auto src_tag = memory_desc_matches_one_of_tag( ++ *src_md, format_tag::ba, format_tag::cdba); ++ ACL_CHECK_SUPPORT( ++ utils::one_of(format_tag::undef, src_tag), ++ ""); ++ ++ arm_compute::TensorShape acl_tensor_shape_in; ++ arm_compute::TensorShape acl_tensor_shape_out; ++ // Need even amount of dims in dim 0 for ACL kernel (eg mulitple of 8 rows when blocking by 8) ++ int dim_0_rounded_up; ++ ++ // Switch for 2 or 4 dim tensors ++ switch(ndims) ++ { ++ // Currently for Ab4a and Ab8a ++ // No format_tag for these, have to deduce from stride ++ case 2: ++ { ++ int dst_dim_1 = dst_md->dims[1]; ++ int dst_dim_0_stride = dst_md->format_desc.blocking.strides[0]; ++ int dst_dim_1_stride = dst_md->format_desc.blocking.strides[1]; ++ // Interleave of 4 or 8 that stride for dim 1 ++ if (dst_dim_1_stride != 4 && dst_dim_1_stride != 8){ ++ return status::unimplemented; ++ } ++ // Check to ensure it's a blocking transpose ++ if (dst_dim_1 * dst_dim_1_stride != dst_dim_0_stride){ ++ return status::unimplemented; ++ } ++ if(dst_dim_1_stride == 4){ ++ // Set Dest WeightFormat ++ _pd->app_.dst_wf = arm_compute::WeightFormat::OHWIo4; ++ dim_0_rounded_up ++ = utils::rnd_up(src_md->dims[0], 4); ++ } else { ++ // Set Dest WeightFormat ++ _pd->app_.dst_wf = arm_compute::WeightFormat::OHWIo8; ++ dim_0_rounded_up ++ = utils::rnd_up(src_md->dims[0], 8); ++ } ++ acl_tensor_shape_in = arm_compute::TensorShape(src_md->dims[1], src_md->dims[0]); ++ acl_tensor_shape_out = arm_compute::TensorShape(src_md->dims[1], dim_0_rounded_up); ++ ++ break; ++ } ++ // Currently for Acdb4a and Acdb8a ++ case 4: ++ { ++ ++ auto dst_tag = memory_desc_matches_one_of_tag( ++ *dst_md, format_tag::Acdb4a, format_tag::Acdb8a); ++ ACL_CHECK_SUPPORT( ++ utils::one_of(format_tag::undef, dst_tag), ++ ""); ++ if(dst_tag == format_tag::Acdb4a){ ++ // Set Dest WeightFormat ++ _pd->app_.dst_wf = arm_compute::WeightFormat::OHWIo4; ++ dim_0_rounded_up ++ = utils::rnd_up(src_md->dims[0], 4); ++ } ++ else{ ++ // Set Dest WeightFormat ++ _pd->app_.dst_wf = arm_compute::WeightFormat::OHWIo8; ++ dim_0_rounded_up ++ = utils::rnd_up(src_md->dims[0], 8); ++ } ++ // Currently only supporting AxBx1x1 cases ++ if(dst_md->dims[2] != 1 || dst_md->dims[3] != 1){ ++ return status::unimplemented; ++ } ++ ++ acl_tensor_shape_in = arm_compute::TensorShape(src_md->dims[3], src_md->dims[2], src_md->dims[1], src_md->dims[0]); ++ acl_tensor_shape_out = arm_compute::TensorShape(src_md->dims[3], src_md->dims[2], src_md->dims[1], dim_0_rounded_up); ++ break; ++ } ++ default: ++ return status::unimplemented; ++ } ++ ++ // Choose the data layout ++ // bool is_nspc = utils::one_of(src_tag, format_tag::nhwc); ++ const auto acl_layout = arm_compute::DataLayout::NCHW; ++ ++ // Set Source WeightFormat ++ _pd->app_.src_wf = arm_compute::WeightFormat::OHWI; ++ ++ // Create ACL tensor infos ++ const data_type_t data_type = src_d.data_type(); ++ const arm_compute::DataType acl_data_t ++ = acl_utils::get_acl_data_t(data_type); ++ _pd->app_.src_info = arm_compute::TensorInfo( ++ acl_tensor_shape_in, 1, acl_data_t, acl_layout); ++ _pd->app_.dst_info = arm_compute::TensorInfo( ++ acl_tensor_shape_out, 1, acl_data_t, acl_layout); ++ ++ // Init scratch memory, not used so 0 in this implementation ++ _pd->init_scratchpad_md(); ++ ++ return safe_ptr_assign(*reorder_pd, _pd); ++ } // create ++ ++ friend dnnl::impl::impl_list_item_t; ++ acl_reorder_conf_t app_; ++ ++ }; // pd_t ++ ++ acl_reorder_fwd_t(const pd_t *apd) : primitive_t(apd) {} ++ ++ status_t create_resource( ++ engine_t *engine, resource_mapper_t &mapper) const override { ++ if (mapper.has_resource(this)) return status::success; ++ ++ auto r = utils::make_unique(); ++ if (!r) return status::out_of_memory; ++ ++ // Configure the resource based on information from primitive descriptor ++ CHECK(r->configure(pd()->app_)); ++ ++ mapper.add(this, std::move(r)); ++ return status::success; ++ } ++ ++ status_t execute(const exec_ctx_t &ctx) const override { ++ return execute_forward(ctx); ++ } ++ ++private: ++ // To guard the const execute_forward, the mutex must be 'mutable' ++ mutable std::mutex mtx; ++ status_t execute_forward(const exec_ctx_t &ctx) const; ++ const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } ++ ++ ++}; // acl_reorder_fwd_t ++ ++} // namespace aarch64 ++} // namespace cpu ++} // namespace impl ++} // namespace dnnl ++ ++#endif // CPU_AARCH64_ACL_REORDER_HPP +diff --git a/src/cpu/reorder/cpu_reorder_regular_f32_f32.cpp b/src/cpu/reorder/cpu_reorder_regular_f32_f32.cpp +index bccd2f75f..5e5ea331b 100644 +--- a/src/cpu/reorder/cpu_reorder_regular_f32_f32.cpp ++++ b/src/cpu/reorder/cpu_reorder_regular_f32_f32.cpp +@@ -15,6 +15,7 @@ + *******************************************************************************/ + + #include "cpu/reorder/cpu_reorder.hpp" ++#include "cpu/aarch64/acl_reorder.hpp" + + namespace dnnl { + namespace impl { +@@ -27,6 +28,7 @@ const impl_list_map_t ®ular_f32_f32_impl_list_map() { + // f32 -> f32 + {{f32, f32, 0}, { + REG_FAST_DIRECT_COPY_F32_F32 ++ DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::acl_reorder_fwd_t)) + + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) +@@ -64,6 +66,7 @@ const impl_list_map_t ®ular_f32_f32_impl_list_map() { + nullptr, + }}, + {{f32, f32, 4}, { ++ DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::acl_reorder_fwd_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::wino_reorder_t)) + + CPU_REORDER_INSTANCE(rnn_weights_reorder_t) diff --git a/third_party/mkl_dnn/onednn_acl_threadpool_scheduler.patch b/third_party/mkl_dnn/onednn_acl_threadpool_scheduler.patch index 7e3725af270292..0e0cb39e82f1bb 100644 --- a/third_party/mkl_dnn/onednn_acl_threadpool_scheduler.patch +++ b/third_party/mkl_dnn/onednn_acl_threadpool_scheduler.patch @@ -1,3 +1,20 @@ + ******************************************************************************* + Copyright 2023 Arm Limited and affiliates. + SPDX-License-Identifier: Apache-2.0 + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ******************************************************************************* + diff --git a/src/cpu/aarch64/acl_threadpool_scheduler.cpp b/src/cpu/aarch64/acl_threadpool_scheduler.cpp index 418d7f30f..439ca862e 100644 --- a/src/cpu/aarch64/acl_threadpool_scheduler.cpp From 97f57ad350fc24be5e58818235217d9cb9aaac2c Mon Sep 17 00:00:00 2001 From: Austin Anderson Date: Thu, 29 Jun 2023 12:35:27 -0700 Subject: [PATCH 031/376] Fix confusing equality checks and some other stuff --- ci/official/any.sh | 8 +- ci/official/bazelrcs/gpu.bazelrc | 127 ------------------ ci/official/code_check_changed_files.sh | 4 +- ci/official/code_check_full.sh | 4 +- ci/official/envs/local_cpu | 2 +- ci/official/envs/nightly_cpu | 2 +- ci/official/libtensorflow.sh | 12 +- ci/official/pycpp.sh | 4 +- .../utilities/code_check_changed_files.bats | 2 +- ci/official/utilities/docker.sh | 2 +- ci/official/wheel.sh | 10 +- 11 files changed, 25 insertions(+), 152 deletions(-) delete mode 100644 ci/official/bazelrcs/gpu.bazelrc diff --git a/ci/official/any.sh b/ci/official/any.sh index f5a6278c99ad2f..658a803793c31d 100755 --- a/ci/official/any.sh +++ b/ci/official/any.sh @@ -8,8 +8,8 @@ set -o allexport && source "$TFCI" && set +o allexport cd "$TFCI_GIT_DIR" && mkdir -p build tfrun() { "$@"; } -[[ "$TFCI_COPYBARA_ENABLE" = 1 ]] && source ./ci/official/utilities/copybara.sh -[[ "$TFCI_DOCKER_ENABLE" = 1 ]] && source ./ci/official/utilities/docker.sh +[[ "$TFCI_COPYBARA_ENABLE" == 1 ]] && source ./ci/official/utilities/copybara.sh +[[ "$TFCI_DOCKER_ENABLE" == 1 ]] && source ./ci/official/utilities/docker.sh ./ci/official/utilities/generate_index_html.sh build/index.html # Parse options and build targets into arrays, so that shelllint doesn't yell @@ -25,7 +25,7 @@ config=( $(echo "$CONFIG_OPTIONS" ) ) test_flags=( $(echo "$TEST_FLAGS" ) ) set -e -[[ "$TFCI_NVIDIA_SMI_ENABLE" = 1 ]] && tfrun nvidia-smi +[[ "$TFCI_NVIDIA_SMI_ENABLE" == 1 ]] && tfrun nvidia-smi if [[ -s build_targets.txt ]]; then tfrun bazel "${TFCI_BAZEL_BAZELRC_ARGS[@]}" "${config[@]}" "${filtered_build_targets[@]}" @@ -33,7 +33,7 @@ fi if [[ "${PIP_WHEEL}" -eq "1" ]]; then # Update the version numbers to build a "nightly" package - [[ "$TFCI_NIGHTLY_UPDATE_VERSION_ENABLE" = 1 ]] && tfrun python3 tensorflow/tools/ci_build/update_version.py --nightly + [[ "$TFCI_NIGHTLY_UPDATE_VERSION_ENABLE" == 1 ]] && tfrun python3 tensorflow/tools/ci_build/update_version.py --nightly tfrun bazel "${TFCI_BAZEL_BAZELRC_ARGS[@]}" build "${TFCI_BAZEL_CACHE_ARGS[@]}" tensorflow/tools/pip_package:build_pip_package tfrun ./bazel-bin/tensorflow/tools/pip_package/build_pip_package build "${TFCI_BUILD_PIP_PACKAGE_ARGS[@]}" diff --git a/ci/official/bazelrcs/gpu.bazelrc b/ci/official/bazelrcs/gpu.bazelrc deleted file mode 100644 index 50ea575205967c..00000000000000 --- a/ci/official/bazelrcs/gpu.bazelrc +++ /dev/null @@ -1,127 +0,0 @@ -# This bazelrc can build a GPU-supporting TF package. - -# Convenient cache configurations -# Use a cache directory mounted to /tf/cache. Very useful! -build:sigbuild_local_cache --disk_cache=/tf/cache -# Use the public-access TF DevInfra cache (read only) -build:sigbuild_remote_cache --remote_cache="https://storage.googleapis.com/tensorflow-devinfra-bazel-cache/september2022" --remote_upload_local_results=false -# Write to the TF DevInfra cache (only works for internal TF CI) -build:sigbuild_remote_cache_push --remote_cache="https://storage.googleapis.com/tensorflow-devinfra-bazel-cache/september2022" --google_default_credentials -# Change the value of CACHEBUSTER when upgrading the toolchain, or when testing -# different compilation methods. E.g. for a PR to test a new CUDA version, set -# the CACHEBUSTER to the PR number. -build --action_env=CACHEBUSTER=501872366 - -# Use Python 3.X as installed in container image -build --action_env PYTHON_BIN_PATH="/usr/bin/python3" -build --action_env PYTHON_LIB_PATH="/usr/lib/tf_python" -build --python_path="/usr/bin/python3" - -# Build TensorFlow v2 -build --define=tf_api_version=2 --action_env=TF2_BEHAVIOR=1 - -# Target the AVX instruction set -build --copt=-mavx --host_copt=-mavx - -# Disable clang extention that rejects type definitions within offsetof. -# This was added in clang-16 by https://reviews.llvm.org/D133574. -# Can be removed once upb is updated, since a type definition is used within -# offset of in the current version of ubp. -# See https://github.com/protocolbuffers/upb/blob/9effcbcb27f0a665f9f345030188c0b291e32482/upb/upb.c#L183. -build --copt=-Wno-gnu-offsetof-extensions - -# Use lld as the linker -build --linkopt="-fuse-ld=lld" -build --linkopt="-lm" - -# Store performance profiling log in the mounted artifact directory. -# The profile can be viewed by visiting chrome://tracing in a Chrome browser. -# See https://docs.bazel.build/versions/main/skylark/performance.html#performance-profiling -build --profile=/tf/pkg/profile.json.gz - -# CUDA: Set up compilation CUDA version and paths -build --@local_config_cuda//:enable_cuda -build --@local_config_cuda//:cuda_compiler=clang -build --repo_env TF_NEED_CUDA=1 -build --config cuda_clang -build --action_env=TF_CUDA_VERSION="11" -build --action_env=TF_CUDNN_VERSION="8" -build --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-11.8" -build --action_env=GCC_HOST_COMPILER_PATH="/dt9/usr/bin/gcc" -build --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-16/bin/clang" -build --action_env=TF_CUDA_CLANG="1" -build --action_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/tensorrt/lib" -build --crosstool_top="@sigbuild-r2.14-clang_config_cuda//crosstool:toolchain" - -# CUDA: Enable TensorRT optimizations -# https://developer.nvidia.com/tensorrt -build --repo_env TF_NEED_TENSORRT=1 - -# CUDA: Select supported compute capabilities (supported graphics cards). -# This is the same as the official TensorFlow builds. -# See https://developer.nvidia.com/cuda-gpus#compute -# TODO(angerson, perfinion): What does sm_ vs compute_ mean? -# TODO(angerson, perfinion): How can users select a good value for this? -build --repo_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_35,sm_50,sm_60,sm_70,sm_75,compute_80" - -# Test-related settings below this point. -test --build_tests_only --keep_going --test_output=errors --verbose_failures=true -test --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" -# Local test jobs has to be 4 because parallel_gpu_execute is fragile, I think -test --test_timeout=300,450,1200,3600 --local_test_jobs=4 --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute -# Give only the list of failed tests at the end of the log -test --test_summary=short - -# "nonpip" tests are regular py_test tests. -# Pass --config=nonpip to run the same suite of tests. If you want to run just -# one test for investigation, you don't need --config=nonpip; just run the -# bazel test invocation as normal. -test:nonpip_filters --test_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 -test:nonpip_filters --build_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 -test:nonpip_filters --test_lang_filters=py --test_size_filters=small,medium -test:nonpip --config=nonpip_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... - -# For building libtensorflow archives -test:libtensorflow_test -- //tensorflow/tools/lib_package:libtensorflow_test //tensorflow/tools/lib_package:libtensorflow_java_test -build:libtensorflow_build -- //tensorflow/tools/lib_package:libtensorflow.tar.gz //tensorflow/tools/lib_package:libtensorflow_jni.tar.gz //tensorflow/java:libtensorflow.jar //tensorflow/java:libtensorflow-src.jar //tensorflow/tools/lib_package:libtensorflow_proto.zip - -# For outputting Build Event Protocol files -build:build_event_export --build_event_json_file=/tf/pkg/bep.json - -# For Remote Build Execution. -build:rbe --google_default_credentials -build:rbe --bes_backend=buildeventservice.googleapis.com -build:rbe --bes_results_url="https://source.cloud.google.com/results/invocations" -build:rbe --bes_timeout=600s -build:rbe --define=EXECUTOR=remote -build:rbe --jobs=800 -build:rbe --remote_executor=grpcs://remotebuildexecution.googleapis.com -build:rbe --remote_timeout=3600 -build:rbe --spawn_strategy=remote,worker,standalone,local -build:rbe --remote_download_toplevel -build:rbe --action_env=PATH="/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/local/go/bin" -build:rbe --linkopt=-lrt --host_linkopt=-lrt --linkopt=-lm --host_linkopt=-lm # Unclear why this is here -build:rbe --host_crosstool_top="@sigbuild-r2.14-clang_config_cuda//crosstool:toolchain" -build:rbe --crosstool_top="@sigbuild-r2.14-clang_config_cuda//crosstool:toolchain" -build:rbe --extra_toolchains="@sigbuild-r2.14-clang_config_cuda//crosstool:toolchain-linux-x86_64" -build:rbe --extra_execution_platforms="@sigbuild-r2.14-clang_config_platform//:platform" -build:rbe --host_platform="@sigbuild-r2.14-clang_config_platform//:platform" -build:rbe --platforms="@sigbuild-r2.14-clang_config_platform//:platform" -# Python config is the same across all containers because the binary is the same -build:rbe --repo_env=TF_PYTHON_CONFIG_REPO="@sigbuild-r2.14-clang_config_python" -build:rbe --remote_instance_name=projects/tensorflow-testing/instances/default_instance -build:rbe --project_id="tensorflow-testing" - -# For Remote build execution -- GPU configuration -build:rbe --repo_env=REMOTE_GPU_TESTING=1 -test:rbe --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" -build:rbe --repo_env=TF_CUDA_CONFIG_REPO="@sigbuild-r2.14-clang_config_cuda" -build:rbe --repo_env=TF_TENSORRT_CONFIG_REPO="@sigbuild-r2.14-clang_config_tensorrt" -build:rbe --repo_env=TF_NCCL_CONFIG_REPO="@sigbuild-r2.14-clang_config_nccl" -build:rbe --repo_env=TF_PYTHON_CONFIG_REPO="@sigbuild-r2.14-clang_config_python" - -# For continuous builds -test:pycpp_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11 -test:pycpp_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11 -test:pycpp_filters --test_lang_filters=cc,py --test_size_filters=small,medium -test:pycpp --config=pycpp_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... diff --git a/ci/official/code_check_changed_files.sh b/ci/official/code_check_changed_files.sh index 48e8b4920c0766..e18e5f4ae91634 100755 --- a/ci/official/code_check_changed_files.sh +++ b/ci/official/code_check_changed_files.sh @@ -8,8 +8,8 @@ set -o allexport && source "$TFCI" && set +o allexport cd "$TFCI_GIT_DIR" && mkdir -p build tfrun() { "$@"; } -[[ "$TFCI_COPYBARA_ENABLE" = 1 ]] && source ./ci/official/utilities/copybara.sh -[[ "$TFCI_DOCKER_ENABLE" = 1 ]] && source ./ci/official/utilities/docker.sh +[[ "$TFCI_COPYBARA_ENABLE" == 1 ]] && source ./ci/official/utilities/copybara.sh +[[ "$TFCI_DOCKER_ENABLE" == 1 ]] && source ./ci/official/utilities/docker.sh ./ci/official/utilities/generate_index_html.sh build/index.html tfrun bats ./ci/official/utilities/code_check_changed_files.bats --timing --output build diff --git a/ci/official/code_check_full.sh b/ci/official/code_check_full.sh index d2f7ef4b4ecdf1..dfadd98de8301a 100755 --- a/ci/official/code_check_full.sh +++ b/ci/official/code_check_full.sh @@ -8,8 +8,8 @@ set -o allexport && source "$TFCI" && set +o allexport cd "$TFCI_GIT_DIR" && mkdir -p build tfrun() { "$@"; } -[[ "$TFCI_COPYBARA_ENABLE" = 1 ]] && source ./ci/official/utilities/copybara.sh -[[ "$TFCI_DOCKER_ENABLE" = 1 ]] && source ./ci/official/utilities/docker.sh +[[ "$TFCI_COPYBARA_ENABLE" == 1 ]] && source ./ci/official/utilities/copybara.sh +[[ "$TFCI_DOCKER_ENABLE" == 1 ]] && source ./ci/official/utilities/docker.sh ./ci/official/utilities/generate_index_html.sh build/index.html tfrun bats ./ci/official/utilities/code_check_full.bats --timing --output build diff --git a/ci/official/envs/local_cpu b/ci/official/envs/local_cpu index cf3d1137e77fb8..820d48f1f30557 100644 --- a/ci/official/envs/local_cpu +++ b/ci/official/envs/local_cpu @@ -1,4 +1,4 @@ -TFCI_BAZEL_BAZELRC_ARGS=(--bazelrc ./ci/official/utilities/bazelrcs/cpu.bazelrc) +TFCI_BAZEL_BAZELRC_ARGS=(--bazelrc ./ci/official/bazelrcs/cpu.bazelrc) TFCI_BAZEL_CACHE_ARGS=(--config sigbuild_remote_cache) TFCI_BUILD_PIP_PACKAGE_ARGS=("--cpu") TFCI_COPYBARA_ENABLE=0 diff --git a/ci/official/envs/nightly_cpu b/ci/official/envs/nightly_cpu index 02ec5f0a33621f..56ce9767a83495 100644 --- a/ci/official/envs/nightly_cpu +++ b/ci/official/envs/nightly_cpu @@ -1,4 +1,4 @@ -TFCI_BAZEL_BAZELRC_ARGS=(--bazelrc ./ci/official/utilities/bazelrcs/cpu.bazelrc) +TFCI_BAZEL_BAZELRC_ARGS=(--bazelrc ./ci/official/bazelrcs/cpu.bazelrc) TFCI_BAZEL_CACHE_ARGS=(--config sigbuild_remote_cache_push) TFCI_BUILD_PIP_PACKAGE_ARGS=("--cpu" "--nightly_flag") TFCI_COPYBARA_ENABLE=1 diff --git a/ci/official/libtensorflow.sh b/ci/official/libtensorflow.sh index bb5472d8233998..28071d8d4521bd 100755 --- a/ci/official/libtensorflow.sh +++ b/ci/official/libtensorflow.sh @@ -8,24 +8,24 @@ set -o allexport && source "$TFCI" && set +o allexport cd "$TFCI_GIT_DIR" && mkdir -p build tfrun() { "$@"; } -[[ "$TFCI_COPYBARA_ENABLE" = 1 ]] && source ./ci/official/utilities/copybara.sh -[[ "$TFCI_DOCKER_ENABLE" = 1 ]] && source ./ci/official/utilities/docker.sh +[[ "$TFCI_COPYBARA_ENABLE" == 1 ]] && source ./ci/official/utilities/copybara.sh +[[ "$TFCI_DOCKER_ENABLE" == 1 ]] && source ./ci/official/utilities/docker.sh ./ci/official/utilities/generate_index_html.sh build/index.html # Record GPU count and CUDA version status -[[ "$TFCI_NVIDIA_SMI_ENABLE" = 1 ]] && tfrun nvidia-smi +[[ "$TFCI_NVIDIA_SMI_ENABLE" == 1 ]] && tfrun nvidia-smi # Update the version numbers for Nightly only -[[ "$TFCI_NIGHTLY_UPDATE_VERSION_ENABLE" = 1 ]] && tfrun python3 tensorflow/tools/ci_build/update_version.py --nightly +[[ "$TFCI_NIGHTLY_UPDATE_VERSION_ENABLE" == 1 ]] && tfrun python3 tensorflow/tools/ci_build/update_version.py --nightly tfrun bazel "${TFCI_BAZEL_BAZELRC_ARGS[@]}" test "${TFCI_BAZEL_CACHE_ARGS[@]}" --config=libtensorflow_test tfrun bazel "${TFCI_BAZEL_BAZELRC_ARGS[@]}" build "${TFCI_BAZEL_CACHE_ARGS[@]}" --config=libtensorflow_build tfrun ./ci/official/utilities/repack_libtensorflow.sh build "$TFCI_LIB_SUFFIX" -if [[ "$TFCI_UPLOAD_LIB_ENABLE" = 1 ]]; then +if [[ "$TFCI_UPLOAD_LIB_ENABLE" == 1 ]]; then gsutil cp build/*.tar.gz "$TFCI_UPLOAD_LIB_GCS_URI" - if [[ "$TFCI_UPLOAD_LIB_LATEST_ENABLE" = 1 ]]; then + if [[ "$TFCI_UPLOAD_LIB_LATEST_ENABLE" == 1 ]]; then gsutil cp build/*.tar.gz "$TFCI_UPLOAD_LIB_LATEST_GCS_URI" fi fi diff --git a/ci/official/pycpp.sh b/ci/official/pycpp.sh index 59d3e0ddc3b74e..5eeda2ec1ff89a 100755 --- a/ci/official/pycpp.sh +++ b/ci/official/pycpp.sh @@ -8,8 +8,8 @@ set -o allexport && source "$TFCI" && set +o allexport cd "$TFCI_GIT_DIR" && mkdir -p build tfrun() { "$@"; } -[[ "$TFCI_COPYBARA_ENABLE" = 1 ]] && source ./ci/official/utilities/copybara.sh -[[ "$TFCI_DOCKER_ENABLE" = 1 ]] && source ./ci/official/utilities/docker.sh +[[ "$TFCI_COPYBARA_ENABLE" == 1 ]] && source ./ci/official/utilities/copybara.sh +[[ "$TFCI_DOCKER_ENABLE" == 1 ]] && source ./ci/official/utilities/docker.sh ./ci/official/utilities/generate_index_html.sh build/index.html # TODO(b/284172313) Revert this difference between presubmits and continuous. RBE serverside behavior is causing flakes, diff --git a/ci/official/utilities/code_check_changed_files.bats b/ci/official/utilities/code_check_changed_files.bats index d912242a2b17c8..8704ddb53064a9 100644 --- a/ci/official/utilities/code_check_changed_files.bats +++ b/ci/official/utilities/code_check_changed_files.bats @@ -23,7 +23,7 @@ setup_file() { # Note that you could generate a list of all the affected targets with e.g.: # bazel query $(paste -sd "+" $BATS_FILE_TMPDIR/changed_files) --keep_going # Only shows Added, Changed, Modified, Renamed, and Type-changed files - if [[ "$(git rev-parse --abbrev-ref HEAD)" = "pull_branch" ]]; then + if [[ "$(git rev-parse --abbrev-ref HEAD)" == "pull_branch" ]]; then # TF's CI runs 'git fetch origin "pull/PR#/merge:pull_branch"' # To get the as-merged branch during the CI tests git diff --diff-filter ACMRT --name-only pull_branch^ pull_branch > $BATS_FILE_TMPDIR/changed_files diff --git a/ci/official/utilities/docker.sh b/ci/official/utilities/docker.sh index 4d69e5a5cf6602..b84ee381e518be 100755 --- a/ci/official/utilities/docker.sh +++ b/ci/official/utilities/docker.sh @@ -7,7 +7,7 @@ set -euxo pipefail -o history set -o allexport && source "$TFCI" && set +o allexport trap "docker rm -f tf" EXIT -if [[ "$TFCI_DOCKER_PULL_ENABLE" = 1 ]]; then +if [[ "$TFCI_DOCKER_PULL_ENABLE" == 1 ]]; then docker pull "$TFCI_DOCKER_IMAGE" fi docker run "${TFCI_DOCKER_GPU_ARGS[@]}" --name tf -w "$TFCI_GIT_DIR" -itd --rm \ diff --git a/ci/official/wheel.sh b/ci/official/wheel.sh index bcc426364c39bf..665a24d64d69e0 100755 --- a/ci/official/wheel.sh +++ b/ci/official/wheel.sh @@ -8,21 +8,21 @@ set -o allexport && source "$TFCI" && set +o allexport cd "$TFCI_GIT_DIR" && mkdir -p build tfrun() { "$@"; } -[[ "$TFCI_COPYBARA_ENABLE" = 1 ]] && source ./ci/official/utilities/copybara.sh -[[ "$TFCI_DOCKER_ENABLE" = 1 ]] && source ./ci/official/utilities/docker.sh +[[ "$TFCI_COPYBARA_ENABLE" == 1 ]] && source ./ci/official/utilities/copybara.sh +[[ "$TFCI_DOCKER_ENABLE" == 1 ]] && source ./ci/official/utilities/docker.sh ./ci/official/utilities/generate_index_html.sh build/index.html # Record GPU count and CUDA version status -[[ "$TFCI_NVIDIA_SMI_ENABLE" = 1 ]] && tfrun nvidia-smi +[[ "$TFCI_NVIDIA_SMI_ENABLE" == 1 ]] && tfrun nvidia-smi # Update the version numbers for Nightly only -[[ "$TFCI_NIGHTLY_UPDATE_VERSION_ENABLE" = 1 ]] && tfrun python3 tensorflow/tools/ci_build/update_version.py --nightly +[[ "$TFCI_NIGHTLY_UPDATE_VERSION_ENABLE" == 1 ]] && tfrun python3 tensorflow/tools/ci_build/update_version.py --nightly tfrun bazel "${TFCI_BAZEL_BAZELRC_ARGS[@]}" build "${TFCI_BAZEL_CACHE_ARGS[@]}" //tensorflow/tools/pip_package:build_pip_package tfrun ./bazel-bin/tensorflow/tools/pip_package/build_pip_package build "${TFCI_BUILD_PIP_PACKAGE_ARGS[@]}" tfrun ./ci/official/utilities/rename_and_verify_wheels.sh build -if [[ "$TFCI_UPLOAD_ENABLE" = 1 ]]; then +if [[ "$TFCI_UPLOAD_ENABLE" == 1 ]]; then twine upload "${TFCI_UPLOAD_PYPI_ARGS[@]}" build/*.whl gsutil cp build/*.whl "$TFCI_UPLOAD_GCS_DESTINATION" fi From e8521328710ab3cb8b9bdd61c4be9d2945d80a1c Mon Sep 17 00:00:00 2001 From: David Svantesson Date: Fri, 30 Jun 2023 11:09:00 +0000 Subject: [PATCH 032/376] Fix reorder shape check --- third_party/mkl_dnn/onednn_acl_reorder.patch | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/third_party/mkl_dnn/onednn_acl_reorder.patch b/third_party/mkl_dnn/onednn_acl_reorder.patch index 71312f6e87b3f0..05ef1160e1469b 100644 --- a/third_party/mkl_dnn/onednn_acl_reorder.patch +++ b/third_party/mkl_dnn/onednn_acl_reorder.patch @@ -58,10 +58,10 @@ index 000000000..061751b55 +} // namespace dnnl diff --git a/src/cpu/aarch64/acl_reorder.hpp b/src/cpu/aarch64/acl_reorder.hpp new file mode 100644 -index 000000000..930ccb40e +index 000000000..91d23e06d --- /dev/null +++ b/src/cpu/aarch64/acl_reorder.hpp -@@ -0,0 +1,257 @@ +@@ -0,0 +1,260 @@ +/******************************************************************************* +* Copyright 2023 Arm Ltd. and affiliates +* @@ -196,6 +196,9 @@ index 000000000..930ccb40e + // No format_tag for these, have to deduce from stride + case 2: + { ++ if(dst_md->dims[0] == 1 || dst_md->dims[1] == 1){ ++ return status::unimplemented; ++ } + int dst_dim_1 = dst_md->dims[1]; + int dst_dim_0_stride = dst_md->format_desc.blocking.strides[0]; + int dst_dim_1_stride = dst_md->format_desc.blocking.strides[1]; From f9a9515fd87f58184edd1b8706d1665b629364b8 Mon Sep 17 00:00:00 2001 From: Austin Anderson Date: Fri, 30 Jun 2023 16:47:31 -0700 Subject: [PATCH 033/376] Fill in nightly config options --- ci/official/any.sh | 1 - ci/official/envs/nightly_cpu_py310 | 22 +++++++++++++++++++ ci/official/envs/nightly_cpu_py311 | 22 +++++++++++++++++++ .../envs/{nightly_cpu => nightly_cpu_py39} | 12 +++++----- ci/official/envs/nightly_nvidia_py310 | 22 +++++++++++++++++++ ci/official/envs/nightly_nvidia_py311 | 22 +++++++++++++++++++ ci/official/envs/nightly_nvidia_py39 | 22 +++++++++++++++++++ 7 files changed, 116 insertions(+), 7 deletions(-) create mode 100644 ci/official/envs/nightly_cpu_py310 create mode 100644 ci/official/envs/nightly_cpu_py311 rename ci/official/envs/{nightly_cpu => nightly_cpu_py39} (77%) create mode 100644 ci/official/envs/nightly_nvidia_py310 create mode 100644 ci/official/envs/nightly_nvidia_py311 create mode 100644 ci/official/envs/nightly_nvidia_py39 diff --git a/ci/official/any.sh b/ci/official/any.sh index 658a803793c31d..5610f9ffed9a0c 100755 --- a/ci/official/any.sh +++ b/ci/official/any.sh @@ -16,7 +16,6 @@ tfrun() { "$@"; } # about readability. We can't pipe into 'read -ra' to create an array because # piped commands run in subshells, which can't store variables outside of the # subshell environment. -# See https://g3doc.corp.google.com/devtools/staticanalysis/pipeline/analyzers/shell/lint/g3doc/findings/SC2086.md?cl=head # Ignore grep failures since we're using it for basic filtering set +e filtered_build_targets=( $(echo "$BUILD_TARGETS" | tr ' ' '\n' | grep . | tee build_targets.txt) ) diff --git a/ci/official/envs/nightly_cpu_py310 b/ci/official/envs/nightly_cpu_py310 new file mode 100644 index 00000000000000..c52ab0eb11d431 --- /dev/null +++ b/ci/official/envs/nightly_cpu_py310 @@ -0,0 +1,22 @@ +TFCI_BAZEL_BAZELRC_ARGS=(--bazelrc ./ci/official/bazelrcs/cpu.bazelrc) +TFCI_BAZEL_CACHE_ARGS=(--config sigbuild_remote_cache_push) +TFCI_BUILD_PIP_PACKAGE_ARGS=(--cpu --nightly_flag) +TFCI_COPYBARA_ENABLE=1 +TFCI_DOCKER_ENABLE=1 +TFCI_DOCKER_GPU_ARGS=() +TFCI_DOCKER_IMAGE=tensorflow/build:latest-python3.10 +TFCI_DOCKER_PULL_ENABLE=1 +TFCI_GIT_DIR=$KOKORO_ARTIFACTS_DIR/github/tensorflow +TFCI_LIB_SUFFIX="-cpu-linux-x86_64" +TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 +TFCI_NVIDIA_SMI_ENABLE=1 +TFCI_UPLOAD_LIB_ENABLE= +TFCI_UPLOAD_LIB_LATEST_ENABLE= +TFCI_UPLOAD_LIB_LATEST_URI="gs://libtensorflow-nightly/latest" +TFCI_UPLOAD_LIB_URI="gs://libtensorflow-nightly/$(date -I)" +#TFCI_UPLOAD_LIB_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" +TFCI_UPLOAD_WHL_GCS_ENABLE= +TFCI_UPLOAD_WHL_GCS_URI= +#TFCI_UPLOAD_WHL_GCS_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" +TFCI_UPLOAD_WHL_PYPI_ARGS=( --config-file "$KOKORO_KEYSTORE_DIR/73361_tensorflow_pypirc_using_global_api_token" --repository testpypi ) +TFCI_UPLOAD_WHL_PYPI_ENABLE= diff --git a/ci/official/envs/nightly_cpu_py311 b/ci/official/envs/nightly_cpu_py311 new file mode 100644 index 00000000000000..928f20532259f1 --- /dev/null +++ b/ci/official/envs/nightly_cpu_py311 @@ -0,0 +1,22 @@ +TFCI_BAZEL_BAZELRC_ARGS=(--bazelrc ./ci/official/bazelrcs/cpu.bazelrc) +TFCI_BAZEL_CACHE_ARGS=(--config sigbuild_remote_cache_push) +TFCI_BUILD_PIP_PACKAGE_ARGS=(--cpu --nightly_flag) +TFCI_COPYBARA_ENABLE=1 +TFCI_DOCKER_ENABLE=1 +TFCI_DOCKER_GPU_ARGS=() +TFCI_DOCKER_IMAGE=tensorflow/build:latest-python3.11 +TFCI_DOCKER_PULL_ENABLE=1 +TFCI_GIT_DIR=$KOKORO_ARTIFACTS_DIR/github/tensorflow +TFCI_LIB_SUFFIX="-cpu-linux-x86_64" +TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 +TFCI_NVIDIA_SMI_ENABLE=1 +TFCI_UPLOAD_LIB_ENABLE= +TFCI_UPLOAD_LIB_LATEST_ENABLE= +TFCI_UPLOAD_LIB_LATEST_URI="gs://libtensorflow-nightly/latest" +TFCI_UPLOAD_LIB_URI="gs://libtensorflow-nightly/$(date -I)" +#TFCI_UPLOAD_LIB_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" +TFCI_UPLOAD_WHL_GCS_ENABLE= +TFCI_UPLOAD_WHL_GCS_URI= +#TFCI_UPLOAD_WHL_GCS_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" +TFCI_UPLOAD_WHL_PYPI_ARGS=( --config-file "$KOKORO_KEYSTORE_DIR/73361_tensorflow_pypirc_using_global_api_token" --repository testpypi ) +TFCI_UPLOAD_WHL_PYPI_ENABLE= diff --git a/ci/official/envs/nightly_cpu b/ci/official/envs/nightly_cpu_py39 similarity index 77% rename from ci/official/envs/nightly_cpu rename to ci/official/envs/nightly_cpu_py39 index 56ce9767a83495..03c114995af013 100644 --- a/ci/official/envs/nightly_cpu +++ b/ci/official/envs/nightly_cpu_py39 @@ -1,17 +1,17 @@ TFCI_BAZEL_BAZELRC_ARGS=(--bazelrc ./ci/official/bazelrcs/cpu.bazelrc) TFCI_BAZEL_CACHE_ARGS=(--config sigbuild_remote_cache_push) -TFCI_BUILD_PIP_PACKAGE_ARGS=("--cpu" "--nightly_flag") +TFCI_BUILD_PIP_PACKAGE_ARGS=(--cpu --nightly_flag) TFCI_COPYBARA_ENABLE=1 TFCI_DOCKER_ENABLE=1 TFCI_DOCKER_GPU_ARGS=() -TFCI_DOCKER_IMAGE= +TFCI_DOCKER_IMAGE=tensorflow/build:latest-python3.9 TFCI_DOCKER_PULL_ENABLE=1 -TFCI_GIT_DIR=/tf/tensorflow +TFCI_GIT_DIR=$KOKORO_ARTIFACTS_DIR/github/tensorflow TFCI_LIB_SUFFIX="-cpu-linux-x86_64" TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 TFCI_NVIDIA_SMI_ENABLE=1 -TFCI_UPLOAD_LIB_ENABLE=1 -TFCI_UPLOAD_LIB_LATEST_ENABLE=1 +TFCI_UPLOAD_LIB_ENABLE= +TFCI_UPLOAD_LIB_LATEST_ENABLE= TFCI_UPLOAD_LIB_LATEST_URI="gs://libtensorflow-nightly/latest" TFCI_UPLOAD_LIB_URI="gs://libtensorflow-nightly/$(date -I)" #TFCI_UPLOAD_LIB_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" @@ -19,4 +19,4 @@ TFCI_UPLOAD_WHL_GCS_ENABLE= TFCI_UPLOAD_WHL_GCS_URI= #TFCI_UPLOAD_WHL_GCS_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" TFCI_UPLOAD_WHL_PYPI_ARGS=( --config-file "$KOKORO_KEYSTORE_DIR/73361_tensorflow_pypirc_using_global_api_token" --repository testpypi ) -TFCI_UPLOAD_WHL_PYPI_ENABLE=1 +TFCI_UPLOAD_WHL_PYPI_ENABLE= diff --git a/ci/official/envs/nightly_nvidia_py310 b/ci/official/envs/nightly_nvidia_py310 new file mode 100644 index 00000000000000..891ea3d1181f68 --- /dev/null +++ b/ci/official/envs/nightly_nvidia_py310 @@ -0,0 +1,22 @@ +TFCI_BAZEL_BAZELRC_ARGS=(--bazelrc ./ci/official/bazelrcs/nvidia.bazelrc) +TFCI_BAZEL_CACHE_ARGS=(--config sigbuild_remote_cache_push) +TFCI_BUILD_PIP_PACKAGE_ARGS=(--gpu --nightly_flag) +TFCI_COPYBARA_ENABLE=1 +TFCI_DOCKER_ENABLE=1 +TFCI_DOCKER_GPU_ARGS=(--gpus all) +TFCI_DOCKER_IMAGE=tensorflow/build:latest-python3.10 +TFCI_DOCKER_PULL_ENABLE=1 +TFCI_GIT_DIR=$KOKORO_ARTIFACTS_DIR/github/tensorflow +TFCI_LIB_SUFFIX="-gpu-linux-x86_64" +TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 +TFCI_NVIDIA_SMI_ENABLE=1 +TFCI_UPLOAD_LIB_ENABLE= +TFCI_UPLOAD_LIB_LATEST_ENABLE= +TFCI_UPLOAD_LIB_LATEST_URI="gs://libtensorflow-nightly/latest" +TFCI_UPLOAD_LIB_URI="gs://libtensorflow-nightly/$(date -I)" +#TFCI_UPLOAD_LIB_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" +TFCI_UPLOAD_WHL_GCS_ENABLE= +TFCI_UPLOAD_WHL_GCS_URI= +#TFCI_UPLOAD_WHL_GCS_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" +TFCI_UPLOAD_WHL_PYPI_ARGS=( --config-file "$KOKORO_KEYSTORE_DIR/73361_tensorflow_pypirc_using_global_api_token" --repository testpypi ) +TFCI_UPLOAD_WHL_PYPI_ENABLE= diff --git a/ci/official/envs/nightly_nvidia_py311 b/ci/official/envs/nightly_nvidia_py311 new file mode 100644 index 00000000000000..7515b5b0a310c3 --- /dev/null +++ b/ci/official/envs/nightly_nvidia_py311 @@ -0,0 +1,22 @@ +TFCI_BAZEL_BAZELRC_ARGS=(--bazelrc ./ci/official/bazelrcs/nvidia.bazelrc) +TFCI_BAZEL_CACHE_ARGS=(--config sigbuild_remote_cache_push) +TFCI_BUILD_PIP_PACKAGE_ARGS=(--gpu --nightly_flag) +TFCI_COPYBARA_ENABLE=1 +TFCI_DOCKER_ENABLE=1 +TFCI_DOCKER_GPU_ARGS=(--gpus all) +TFCI_DOCKER_IMAGE=tensorflow/build:latest-python3.11 +TFCI_DOCKER_PULL_ENABLE=1 +TFCI_GIT_DIR=$KOKORO_ARTIFACTS_DIR/github/tensorflow +TFCI_LIB_SUFFIX="-gpu-linux-x86_64" +TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 +TFCI_NVIDIA_SMI_ENABLE=1 +TFCI_UPLOAD_LIB_ENABLE= +TFCI_UPLOAD_LIB_LATEST_ENABLE= +TFCI_UPLOAD_LIB_LATEST_URI="gs://libtensorflow-nightly/latest" +TFCI_UPLOAD_LIB_URI="gs://libtensorflow-nightly/$(date -I)" +#TFCI_UPLOAD_LIB_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" +TFCI_UPLOAD_WHL_GCS_ENABLE= +TFCI_UPLOAD_WHL_GCS_URI= +#TFCI_UPLOAD_WHL_GCS_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" +TFCI_UPLOAD_WHL_PYPI_ARGS=( --config-file "$KOKORO_KEYSTORE_DIR/73361_tensorflow_pypirc_using_global_api_token" --repository testpypi ) +TFCI_UPLOAD_WHL_PYPI_ENABLE= diff --git a/ci/official/envs/nightly_nvidia_py39 b/ci/official/envs/nightly_nvidia_py39 new file mode 100644 index 00000000000000..8312efd24cf55e --- /dev/null +++ b/ci/official/envs/nightly_nvidia_py39 @@ -0,0 +1,22 @@ +TFCI_BAZEL_BAZELRC_ARGS=(--bazelrc ./ci/official/bazelrcs/nvidia.bazelrc) +TFCI_BAZEL_CACHE_ARGS=(--config sigbuild_remote_cache_push) +TFCI_BUILD_PIP_PACKAGE_ARGS=(--gpu --nightly_flag) +TFCI_COPYBARA_ENABLE=1 +TFCI_DOCKER_ENABLE=1 +TFCI_DOCKER_GPU_ARGS=(--gpus all) +TFCI_DOCKER_IMAGE=tensorflow/build:latest-python3.9 +TFCI_DOCKER_PULL_ENABLE=1 +TFCI_GIT_DIR=$KOKORO_ARTIFACTS_DIR/github/tensorflow +TFCI_LIB_SUFFIX="-gpu-linux-x86_64" +TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 +TFCI_NVIDIA_SMI_ENABLE=1 +TFCI_UPLOAD_LIB_ENABLE= +TFCI_UPLOAD_LIB_LATEST_ENABLE= +TFCI_UPLOAD_LIB_LATEST_URI="gs://libtensorflow-nightly/latest" +TFCI_UPLOAD_LIB_URI="gs://libtensorflow-nightly/$(date -I)" +#TFCI_UPLOAD_LIB_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" +TFCI_UPLOAD_WHL_GCS_ENABLE= +TFCI_UPLOAD_WHL_GCS_URI= +#TFCI_UPLOAD_WHL_GCS_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" +TFCI_UPLOAD_WHL_PYPI_ARGS=( --config-file "$KOKORO_KEYSTORE_DIR/73361_tensorflow_pypirc_using_global_api_token" --repository testpypi ) +TFCI_UPLOAD_WHL_PYPI_ENABLE= From ba3ffc555b3dbac9a9e6be1b8099d3db8346a49c Mon Sep 17 00:00:00 2001 From: Renato Arantes Date: Tue, 30 May 2023 08:07:36 +0000 Subject: [PATCH 034/376] Improving the performance of TF models for aarch64. --- tensorflow/core/common_runtime/mkl_layout_pass.cc | 13 ++++++++++--- tensorflow/core/grappler/optimizers/remapper.cc | 9 +++++++++ 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/tensorflow/core/common_runtime/mkl_layout_pass.cc b/tensorflow/core/common_runtime/mkl_layout_pass.cc index 599d44d3af6858..d07041609928dd 100644 --- a/tensorflow/core/common_runtime/mkl_layout_pass.cc +++ b/tensorflow/core/common_runtime/mkl_layout_pass.cc @@ -394,7 +394,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { CopyAttrsAll, RewriteIfAtleastOneMklInput, GetRewriteCause()}); rinfo_.push_back({csinfo_.avg_pool, mkl_op_registry::GetMklOpName(csinfo_.avg_pool), - CopyAttrsAll, AlwaysRewrite, GetRewriteCause()}); + CopyAttrsAll, RewriteIfX86, GetRewriteCause()}); rinfo_.push_back({csinfo_.avg_pool_grad, mkl_op_registry::GetMklOpName(csinfo_.avg_pool_grad), CopyAttrsAll, AlwaysRewrite, GetRewriteCause()}); @@ -712,7 +712,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { #endif // !ENABLE_ONEDNN_V3 rinfo_.push_back({csinfo_.softmax, mkl_op_registry::GetMklOpName(csinfo_.softmax), - CopyAttrsAll, AlwaysRewrite, GetRewriteCause()}); + CopyAttrsAll, RewriteIfX86, GetRewriteCause()}); #ifndef ENABLE_ONEDNN_V3 rinfo_.push_back({csinfo_.squared_difference, @@ -725,7 +725,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { #endif // !ENABLE_ONEDNN_V3 rinfo_.push_back({csinfo_.transpose, mkl_op_registry::GetMklOpName(csinfo_.transpose), - CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange}); + CopyAttrsAll, RewriteIfX86, kRewriteForOpNameChange}); // Add info about which ops to add workspace edge to and the slots. wsinfo_.push_back({csinfo_.lrn, csinfo_.lrn_grad, 0, 2, 1, 3}); @@ -1452,6 +1452,13 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // Default rewrite rule to be used in scenario 1 for rewrite. // @return - true (since we want to always rewrite) static bool AlwaysRewrite(const Node* n) { return true; } + static bool RewriteIfX86(const Node* n) { +#ifdef DNNL_AARCH64_USE_ACL + return false; +#else + return true; +#endif + } // Rewrite rule which considers "context" of the current node to decide if we // should rewrite. By "context" we currently mean all the inputs of current diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc index 3ee0c18b696b52..3b5795edb52492 100644 --- a/tensorflow/core/grappler/optimizers/remapper.cc +++ b/tensorflow/core/grappler/optimizers/remapper.cc @@ -758,6 +758,10 @@ bool FindContractionWithBias(const RemapperContext& ctx, int node_index, IsMatMul(*contraction_node_def) || IsDepthwiseConv2dNative(*contraction_node_def); +#ifdef DNNL_AARCH64_USE_ACL + if (IsDepthwiseConv2dNative(*contraction_node_def)) is_contraction = false; +#endif + if (!is_contraction || !HaveSameDataType(node_def, contraction_node_def) || HasControlFaninOrFanout(*contraction_node_view) || !HasAtMostOneFanoutAtPort0(*contraction_node_view) || @@ -4438,6 +4442,7 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item, continue; } +#ifndef DNNL_AARCH64_USE_ACL // Remap {Conv2D,Conv3D}+BiasAdd+Add into the _FusedConv2D/3D. if (FindContractionWithBiasAddAndAdd(ctx, i, &contract_with_bias_and_add)) { @@ -4446,6 +4451,7 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item, &invalidated_nodes, &nodes_to_delete)); continue; } +#endif PadWithConv3D pad_with_conv3d; // Remap Pad+{Conv3D,_FusedConv3D} into the _FusedConv3D. @@ -4482,6 +4488,7 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item, continue; } +#ifndef DNNL_AARCH64_USE_ACL // Fuse Conv2d + BiasAdd/FusedBatchNorm + Swish. std::map fusedconv2dSwish_matched_nodes_map; std::set fusedconv2dSwish_remove_node_indices; @@ -4493,6 +4500,8 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item, &invalidated_nodes, &nodes_to_delete)); continue; } +#endif + // Remap Maximum(x, alpha * x) pattern, fuse them into the LeakyRelu(x). std::map mulmax_matched_nodes_map; std::set mulmax_remove_node_indices; From bf05c18391e392df7566c26286dc271d6826c00a Mon Sep 17 00:00:00 2001 From: zehuiw Date: Thu, 6 Jul 2023 15:00:42 -0700 Subject: [PATCH 035/376] Fix doc rendering error for TFLite inference This g3 doc is auto synced to tensorflow.org where the rendering is not correct: https://www.tensorflow.org/lite/guide/inference.md#run_inference_with_dynamic_shape_model --- tensorflow/lite/g3doc/guide/inference.md | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tensorflow/lite/g3doc/guide/inference.md b/tensorflow/lite/g3doc/guide/inference.md index 64054f6cc5330a..2ea6ec267157e6 100644 --- a/tensorflow/lite/g3doc/guide/inference.md +++ b/tensorflow/lite/g3doc/guide/inference.md @@ -616,16 +616,15 @@ running inference in different languages. All the examples assume that the input shape is defined as `[1/None, 10]`, and need to be resized to `[3, 10]`. -
- -###### C++ {.new-tab} +C++ example: ```c++ // Resize input tensors before allocate tensors interpreter->ResizeInputTensor(/*tensor_index=*/0, std::vector{3,10}); interpreter->AllocateTensors(); ``` -###### Python {.new-tab} + +Python example: ```python # Load the TFLite model in TFLite Interpreter From 211d99d4830e4987fe72abe3ed0006c7cbee69c8 Mon Sep 17 00:00:00 2001 From: Austin Anderson Date: Thu, 6 Jul 2023 17:57:28 -0700 Subject: [PATCH 036/376] Dedup setup --- ci/official/any.sh | 13 +---------- ci/official/code_check_changed_files.sh | 13 +---------- ci/official/code_check_full.sh | 13 +---------- ci/official/envs/local_cpu | 1 + ci/official/envs/nightly_cpu_py310 | 5 +++-- ci/official/envs/nightly_cpu_py311 | 5 +++-- ci/official/envs/nightly_cpu_py39 | 5 +++-- ci/official/envs/nightly_nvidia_py310 | 5 +++-- ci/official/envs/nightly_nvidia_py311 | 5 +++-- ci/official/envs/nightly_nvidia_py39 | 5 +++-- ci/official/libtensorflow.sh | 13 +---------- ci/official/pycpp.sh | 13 +---------- ci/official/utilities/setup.sh | 29 +++++++++++++++++++++++++ ci/official/wheel.sh | 13 +---------- 14 files changed, 54 insertions(+), 84 deletions(-) create mode 100755 ci/official/utilities/setup.sh diff --git a/ci/official/any.sh b/ci/official/any.sh index 5610f9ffed9a0c..8bfba811d74419 100755 --- a/ci/official/any.sh +++ b/ci/official/any.sh @@ -1,16 +1,5 @@ #!/bin/bash -# -e: abort script if one command fails -# -u: error if undefined variable used -# -o pipefail: entire command fails if pipe fails. watch out for yes | ... -# -o history: record shell history -set -euxo pipefail -o history -set -o allexport && source "$TFCI" && set +o allexport - -cd "$TFCI_GIT_DIR" && mkdir -p build -tfrun() { "$@"; } -[[ "$TFCI_COPYBARA_ENABLE" == 1 ]] && source ./ci/official/utilities/copybara.sh -[[ "$TFCI_DOCKER_ENABLE" == 1 ]] && source ./ci/official/utilities/docker.sh -./ci/official/utilities/generate_index_html.sh build/index.html +source "${BASH_SOURCE%/*}/utilities/setup.sh" # Parse options and build targets into arrays, so that shelllint doesn't yell # about readability. We can't pipe into 'read -ra' to create an array because diff --git a/ci/official/code_check_changed_files.sh b/ci/official/code_check_changed_files.sh index e18e5f4ae91634..8155d9c1c5e2ab 100755 --- a/ci/official/code_check_changed_files.sh +++ b/ci/official/code_check_changed_files.sh @@ -1,15 +1,4 @@ #!/bin/bash -# -e: abort script if one command fails -# -u: error if undefined variable used -# -o pipefail: entire command fails if pipe fails. watch out for yes | ... -# -o history: record shell history -set -euxo pipefail -o history -set -o allexport && source "$TFCI" && set +o allexport - -cd "$TFCI_GIT_DIR" && mkdir -p build -tfrun() { "$@"; } -[[ "$TFCI_COPYBARA_ENABLE" == 1 ]] && source ./ci/official/utilities/copybara.sh -[[ "$TFCI_DOCKER_ENABLE" == 1 ]] && source ./ci/official/utilities/docker.sh -./ci/official/utilities/generate_index_html.sh build/index.html +source "${BASH_SOURCE%/*}/utilities/setup.sh" tfrun bats ./ci/official/utilities/code_check_changed_files.bats --timing --output build diff --git a/ci/official/code_check_full.sh b/ci/official/code_check_full.sh index dfadd98de8301a..4567143ba0ad59 100755 --- a/ci/official/code_check_full.sh +++ b/ci/official/code_check_full.sh @@ -1,15 +1,4 @@ #!/bin/bash -# -e: abort script if one command fails -# -u: error if undefined variable used -# -o pipefail: entire command fails if pipe fails. watch out for yes | ... -# -o history: record shell history -set -euxo pipefail -o history -set -o allexport && source "$TFCI" && set +o allexport - -cd "$TFCI_GIT_DIR" && mkdir -p build -tfrun() { "$@"; } -[[ "$TFCI_COPYBARA_ENABLE" == 1 ]] && source ./ci/official/utilities/copybara.sh -[[ "$TFCI_DOCKER_ENABLE" == 1 ]] && source ./ci/official/utilities/docker.sh -./ci/official/utilities/generate_index_html.sh build/index.html +source "${BASH_SOURCE%/*}/utilities/setup.sh" tfrun bats ./ci/official/utilities/code_check_full.bats --timing --output build diff --git a/ci/official/envs/local_cpu b/ci/official/envs/local_cpu index 820d48f1f30557..914dd4c856afd4 100644 --- a/ci/official/envs/local_cpu +++ b/ci/official/envs/local_cpu @@ -7,6 +7,7 @@ TFCI_DOCKER_GPU_ARGS=() TFCI_DOCKER_IMAGE=tensorflow/build:latest-python3.9 TFCI_DOCKER_PULL_ENABLE= TFCI_GIT_DIR=/usr/local/google/home/angerson/repos/tensorflow +TFCI_INDEX_HTML_ENABLE=1 TFCI_LIB_SUFFIX="-cpu-linux-x86_64" TFCI_NIGHTLY_UPDATE_VERSION_ENABLE= TFCI_NVIDIA_SMI_ENABLE= diff --git a/ci/official/envs/nightly_cpu_py310 b/ci/official/envs/nightly_cpu_py310 index c52ab0eb11d431..9cff1f3803ff5a 100644 --- a/ci/official/envs/nightly_cpu_py310 +++ b/ci/official/envs/nightly_cpu_py310 @@ -1,3 +1,5 @@ +#TFCI_UPLOAD_LIB_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" +#TFCI_UPLOAD_WHL_GCS_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" TFCI_BAZEL_BAZELRC_ARGS=(--bazelrc ./ci/official/bazelrcs/cpu.bazelrc) TFCI_BAZEL_CACHE_ARGS=(--config sigbuild_remote_cache_push) TFCI_BUILD_PIP_PACKAGE_ARGS=(--cpu --nightly_flag) @@ -7,6 +9,7 @@ TFCI_DOCKER_GPU_ARGS=() TFCI_DOCKER_IMAGE=tensorflow/build:latest-python3.10 TFCI_DOCKER_PULL_ENABLE=1 TFCI_GIT_DIR=$KOKORO_ARTIFACTS_DIR/github/tensorflow +TFCI_INDEX_HTML_ENABLE=1 TFCI_LIB_SUFFIX="-cpu-linux-x86_64" TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 TFCI_NVIDIA_SMI_ENABLE=1 @@ -14,9 +17,7 @@ TFCI_UPLOAD_LIB_ENABLE= TFCI_UPLOAD_LIB_LATEST_ENABLE= TFCI_UPLOAD_LIB_LATEST_URI="gs://libtensorflow-nightly/latest" TFCI_UPLOAD_LIB_URI="gs://libtensorflow-nightly/$(date -I)" -#TFCI_UPLOAD_LIB_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" TFCI_UPLOAD_WHL_GCS_ENABLE= TFCI_UPLOAD_WHL_GCS_URI= -#TFCI_UPLOAD_WHL_GCS_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" TFCI_UPLOAD_WHL_PYPI_ARGS=( --config-file "$KOKORO_KEYSTORE_DIR/73361_tensorflow_pypirc_using_global_api_token" --repository testpypi ) TFCI_UPLOAD_WHL_PYPI_ENABLE= diff --git a/ci/official/envs/nightly_cpu_py311 b/ci/official/envs/nightly_cpu_py311 index 928f20532259f1..e28e8f6cf3c413 100644 --- a/ci/official/envs/nightly_cpu_py311 +++ b/ci/official/envs/nightly_cpu_py311 @@ -1,3 +1,5 @@ +#TFCI_UPLOAD_LIB_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" +#TFCI_UPLOAD_WHL_GCS_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" TFCI_BAZEL_BAZELRC_ARGS=(--bazelrc ./ci/official/bazelrcs/cpu.bazelrc) TFCI_BAZEL_CACHE_ARGS=(--config sigbuild_remote_cache_push) TFCI_BUILD_PIP_PACKAGE_ARGS=(--cpu --nightly_flag) @@ -7,6 +9,7 @@ TFCI_DOCKER_GPU_ARGS=() TFCI_DOCKER_IMAGE=tensorflow/build:latest-python3.11 TFCI_DOCKER_PULL_ENABLE=1 TFCI_GIT_DIR=$KOKORO_ARTIFACTS_DIR/github/tensorflow +TFCI_INDEX_HTML_ENABLE=1 TFCI_LIB_SUFFIX="-cpu-linux-x86_64" TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 TFCI_NVIDIA_SMI_ENABLE=1 @@ -14,9 +17,7 @@ TFCI_UPLOAD_LIB_ENABLE= TFCI_UPLOAD_LIB_LATEST_ENABLE= TFCI_UPLOAD_LIB_LATEST_URI="gs://libtensorflow-nightly/latest" TFCI_UPLOAD_LIB_URI="gs://libtensorflow-nightly/$(date -I)" -#TFCI_UPLOAD_LIB_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" TFCI_UPLOAD_WHL_GCS_ENABLE= TFCI_UPLOAD_WHL_GCS_URI= -#TFCI_UPLOAD_WHL_GCS_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" TFCI_UPLOAD_WHL_PYPI_ARGS=( --config-file "$KOKORO_KEYSTORE_DIR/73361_tensorflow_pypirc_using_global_api_token" --repository testpypi ) TFCI_UPLOAD_WHL_PYPI_ENABLE= diff --git a/ci/official/envs/nightly_cpu_py39 b/ci/official/envs/nightly_cpu_py39 index 03c114995af013..6c34a60b89cde9 100644 --- a/ci/official/envs/nightly_cpu_py39 +++ b/ci/official/envs/nightly_cpu_py39 @@ -1,3 +1,5 @@ +#TFCI_UPLOAD_LIB_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" +#TFCI_UPLOAD_WHL_GCS_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" TFCI_BAZEL_BAZELRC_ARGS=(--bazelrc ./ci/official/bazelrcs/cpu.bazelrc) TFCI_BAZEL_CACHE_ARGS=(--config sigbuild_remote_cache_push) TFCI_BUILD_PIP_PACKAGE_ARGS=(--cpu --nightly_flag) @@ -7,6 +9,7 @@ TFCI_DOCKER_GPU_ARGS=() TFCI_DOCKER_IMAGE=tensorflow/build:latest-python3.9 TFCI_DOCKER_PULL_ENABLE=1 TFCI_GIT_DIR=$KOKORO_ARTIFACTS_DIR/github/tensorflow +TFCI_INDEX_HTML_ENABLE=1 TFCI_LIB_SUFFIX="-cpu-linux-x86_64" TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 TFCI_NVIDIA_SMI_ENABLE=1 @@ -14,9 +17,7 @@ TFCI_UPLOAD_LIB_ENABLE= TFCI_UPLOAD_LIB_LATEST_ENABLE= TFCI_UPLOAD_LIB_LATEST_URI="gs://libtensorflow-nightly/latest" TFCI_UPLOAD_LIB_URI="gs://libtensorflow-nightly/$(date -I)" -#TFCI_UPLOAD_LIB_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" TFCI_UPLOAD_WHL_GCS_ENABLE= TFCI_UPLOAD_WHL_GCS_URI= -#TFCI_UPLOAD_WHL_GCS_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" TFCI_UPLOAD_WHL_PYPI_ARGS=( --config-file "$KOKORO_KEYSTORE_DIR/73361_tensorflow_pypirc_using_global_api_token" --repository testpypi ) TFCI_UPLOAD_WHL_PYPI_ENABLE= diff --git a/ci/official/envs/nightly_nvidia_py310 b/ci/official/envs/nightly_nvidia_py310 index 891ea3d1181f68..dbfd3ca756b4f6 100644 --- a/ci/official/envs/nightly_nvidia_py310 +++ b/ci/official/envs/nightly_nvidia_py310 @@ -1,3 +1,5 @@ +#TFCI_UPLOAD_LIB_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" +#TFCI_UPLOAD_WHL_GCS_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" TFCI_BAZEL_BAZELRC_ARGS=(--bazelrc ./ci/official/bazelrcs/nvidia.bazelrc) TFCI_BAZEL_CACHE_ARGS=(--config sigbuild_remote_cache_push) TFCI_BUILD_PIP_PACKAGE_ARGS=(--gpu --nightly_flag) @@ -7,6 +9,7 @@ TFCI_DOCKER_GPU_ARGS=(--gpus all) TFCI_DOCKER_IMAGE=tensorflow/build:latest-python3.10 TFCI_DOCKER_PULL_ENABLE=1 TFCI_GIT_DIR=$KOKORO_ARTIFACTS_DIR/github/tensorflow +TFCI_INDEX_HTML_ENABLE=1 TFCI_LIB_SUFFIX="-gpu-linux-x86_64" TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 TFCI_NVIDIA_SMI_ENABLE=1 @@ -14,9 +17,7 @@ TFCI_UPLOAD_LIB_ENABLE= TFCI_UPLOAD_LIB_LATEST_ENABLE= TFCI_UPLOAD_LIB_LATEST_URI="gs://libtensorflow-nightly/latest" TFCI_UPLOAD_LIB_URI="gs://libtensorflow-nightly/$(date -I)" -#TFCI_UPLOAD_LIB_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" TFCI_UPLOAD_WHL_GCS_ENABLE= TFCI_UPLOAD_WHL_GCS_URI= -#TFCI_UPLOAD_WHL_GCS_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" TFCI_UPLOAD_WHL_PYPI_ARGS=( --config-file "$KOKORO_KEYSTORE_DIR/73361_tensorflow_pypirc_using_global_api_token" --repository testpypi ) TFCI_UPLOAD_WHL_PYPI_ENABLE= diff --git a/ci/official/envs/nightly_nvidia_py311 b/ci/official/envs/nightly_nvidia_py311 index 7515b5b0a310c3..c988660c54e086 100644 --- a/ci/official/envs/nightly_nvidia_py311 +++ b/ci/official/envs/nightly_nvidia_py311 @@ -1,3 +1,5 @@ +#TFCI_UPLOAD_LIB_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" +#TFCI_UPLOAD_WHL_GCS_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" TFCI_BAZEL_BAZELRC_ARGS=(--bazelrc ./ci/official/bazelrcs/nvidia.bazelrc) TFCI_BAZEL_CACHE_ARGS=(--config sigbuild_remote_cache_push) TFCI_BUILD_PIP_PACKAGE_ARGS=(--gpu --nightly_flag) @@ -7,6 +9,7 @@ TFCI_DOCKER_GPU_ARGS=(--gpus all) TFCI_DOCKER_IMAGE=tensorflow/build:latest-python3.11 TFCI_DOCKER_PULL_ENABLE=1 TFCI_GIT_DIR=$KOKORO_ARTIFACTS_DIR/github/tensorflow +TFCI_INDEX_HTML_ENABLE=1 TFCI_LIB_SUFFIX="-gpu-linux-x86_64" TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 TFCI_NVIDIA_SMI_ENABLE=1 @@ -14,9 +17,7 @@ TFCI_UPLOAD_LIB_ENABLE= TFCI_UPLOAD_LIB_LATEST_ENABLE= TFCI_UPLOAD_LIB_LATEST_URI="gs://libtensorflow-nightly/latest" TFCI_UPLOAD_LIB_URI="gs://libtensorflow-nightly/$(date -I)" -#TFCI_UPLOAD_LIB_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" TFCI_UPLOAD_WHL_GCS_ENABLE= TFCI_UPLOAD_WHL_GCS_URI= -#TFCI_UPLOAD_WHL_GCS_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" TFCI_UPLOAD_WHL_PYPI_ARGS=( --config-file "$KOKORO_KEYSTORE_DIR/73361_tensorflow_pypirc_using_global_api_token" --repository testpypi ) TFCI_UPLOAD_WHL_PYPI_ENABLE= diff --git a/ci/official/envs/nightly_nvidia_py39 b/ci/official/envs/nightly_nvidia_py39 index 8312efd24cf55e..9c0984080a211b 100644 --- a/ci/official/envs/nightly_nvidia_py39 +++ b/ci/official/envs/nightly_nvidia_py39 @@ -1,3 +1,5 @@ +#TFCI_UPLOAD_LIB_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" +#TFCI_UPLOAD_WHL_GCS_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" TFCI_BAZEL_BAZELRC_ARGS=(--bazelrc ./ci/official/bazelrcs/nvidia.bazelrc) TFCI_BAZEL_CACHE_ARGS=(--config sigbuild_remote_cache_push) TFCI_BUILD_PIP_PACKAGE_ARGS=(--gpu --nightly_flag) @@ -7,6 +9,7 @@ TFCI_DOCKER_GPU_ARGS=(--gpus all) TFCI_DOCKER_IMAGE=tensorflow/build:latest-python3.9 TFCI_DOCKER_PULL_ENABLE=1 TFCI_GIT_DIR=$KOKORO_ARTIFACTS_DIR/github/tensorflow +TFCI_INDEX_HTML_ENABLE=1 TFCI_LIB_SUFFIX="-gpu-linux-x86_64" TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 TFCI_NVIDIA_SMI_ENABLE=1 @@ -14,9 +17,7 @@ TFCI_UPLOAD_LIB_ENABLE= TFCI_UPLOAD_LIB_LATEST_ENABLE= TFCI_UPLOAD_LIB_LATEST_URI="gs://libtensorflow-nightly/latest" TFCI_UPLOAD_LIB_URI="gs://libtensorflow-nightly/$(date -I)" -#TFCI_UPLOAD_LIB_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" TFCI_UPLOAD_WHL_GCS_ENABLE= TFCI_UPLOAD_WHL_GCS_URI= -#TFCI_UPLOAD_WHL_GCS_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" TFCI_UPLOAD_WHL_PYPI_ARGS=( --config-file "$KOKORO_KEYSTORE_DIR/73361_tensorflow_pypirc_using_global_api_token" --repository testpypi ) TFCI_UPLOAD_WHL_PYPI_ENABLE= diff --git a/ci/official/libtensorflow.sh b/ci/official/libtensorflow.sh index 28071d8d4521bd..31ed771f40b382 100755 --- a/ci/official/libtensorflow.sh +++ b/ci/official/libtensorflow.sh @@ -1,16 +1,5 @@ #!/bin/bash -# -e: abort script if one command fails -# -u: error if undefined variable used -# -o pipefail: entire command fails if pipe fails. watch out for yes | ... -# -o history: record shell history -set -euxo pipefail -o history -set -o allexport && source "$TFCI" && set +o allexport - -cd "$TFCI_GIT_DIR" && mkdir -p build -tfrun() { "$@"; } -[[ "$TFCI_COPYBARA_ENABLE" == 1 ]] && source ./ci/official/utilities/copybara.sh -[[ "$TFCI_DOCKER_ENABLE" == 1 ]] && source ./ci/official/utilities/docker.sh -./ci/official/utilities/generate_index_html.sh build/index.html +source "${BASH_SOURCE%/*}/utilities/setup.sh" # Record GPU count and CUDA version status [[ "$TFCI_NVIDIA_SMI_ENABLE" == 1 ]] && tfrun nvidia-smi diff --git a/ci/official/pycpp.sh b/ci/official/pycpp.sh index 5eeda2ec1ff89a..7c5f254b982948 100755 --- a/ci/official/pycpp.sh +++ b/ci/official/pycpp.sh @@ -1,16 +1,5 @@ #!/bin/bash -# -e: abort script if one command fails -# -u: error if undefined variable used -# -o pipefail: entire command fails if pipe fails. watch out for yes | ... -# -o history: record shell history -set -euxo pipefail -o history -set -o allexport && source "$TFCI" && set +o allexport - -cd "$TFCI_GIT_DIR" && mkdir -p build -tfrun() { "$@"; } -[[ "$TFCI_COPYBARA_ENABLE" == 1 ]] && source ./ci/official/utilities/copybara.sh -[[ "$TFCI_DOCKER_ENABLE" == 1 ]] && source ./ci/official/utilities/docker.sh -./ci/official/utilities/generate_index_html.sh build/index.html +source "${BASH_SOURCE%/*}/utilities/setup.sh" # TODO(b/284172313) Revert this difference between presubmits and continuous. RBE serverside behavior is causing flakes, # so we're temporarily allowing flaky tests again for presubmits. diff --git a/ci/official/utilities/setup.sh b/ci/official/utilities/setup.sh new file mode 100755 index 00000000000000..6e96ed0bc19132 --- /dev/null +++ b/ci/official/utilities/setup.sh @@ -0,0 +1,29 @@ +#!/bin/bash +# -e: abort script if one command fails +# -u: error if undefined variable used +# -o pipefail: entire command fails if pipe fails. watch out for yes | ... +# -o history: record shell history +# -o allexport: export all functions and variables to be available to subscripts +set -euxo pipefail -o history -o allexport + +# Import all variables as set in $TFCI, which should be a file like those in +# the envs directory that sets all TFCI_ variables, e.g. /path/to/envs/local_cpu +source "$TFCI" + +# Make a "build" directory for outputting all build artifacts (TF's .gitignore +# ignores the "build" directory) +cd "$TFCI_GIT_DIR" && mkdir -p build + +# Setup tfrun, a helper function for executing steps that can either be run +# locally or run under Docker. docker.sh, below, redefines it as "docker exec". +tfrun() { "$@"; } + +# For Google-internal jobs, run copybara, which will overwrite the source tree. +# Never useful for outside users. +[[ "$TFCI_COPYBARA_ENABLE" == 1 ]] && source ./ci/official/utilities/copybara.sh + +# Run all "tfrun" commands under Docker. See docker.sh for details +[[ "$TFCI_DOCKER_ENABLE" == 1 ]] && source ./ci/official/utilities/docker.sh + +# Generate an overview page describing the build +[[ "$TFCI_INDEX_HTML_ENABLE" == 1 ]] && ./ci/official/utilities/generate_index_html.sh build/index.html diff --git a/ci/official/wheel.sh b/ci/official/wheel.sh index 665a24d64d69e0..359517a405224c 100755 --- a/ci/official/wheel.sh +++ b/ci/official/wheel.sh @@ -1,16 +1,5 @@ #!/bin/bash -# -e: abort script if one command fails -# -u: error if undefined variable used -# -o pipefail: entire command fails if pipe fails. watch out for yes | ... -# -o history: record shell history -set -euxo pipefail -o history -set -o allexport && source "$TFCI" && set +o allexport - -cd "$TFCI_GIT_DIR" && mkdir -p build -tfrun() { "$@"; } -[[ "$TFCI_COPYBARA_ENABLE" == 1 ]] && source ./ci/official/utilities/copybara.sh -[[ "$TFCI_DOCKER_ENABLE" == 1 ]] && source ./ci/official/utilities/docker.sh -./ci/official/utilities/generate_index_html.sh build/index.html +source "${BASH_SOURCE%/*}/utilities/setup.sh" # Record GPU count and CUDA version status [[ "$TFCI_NVIDIA_SMI_ENABLE" == 1 ]] && tfrun nvidia-smi From 5a85a6859d9780dadf027b6d2190caef32df2e9a Mon Sep 17 00:00:00 2001 From: Austin Anderson Date: Thu, 6 Jul 2023 18:37:42 -0700 Subject: [PATCH 037/376] Cleanup --- ci/official/pycpp.sh | 2 -- ci/official/utilities/setup.sh | 14 ++++++++++++++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/ci/official/pycpp.sh b/ci/official/pycpp.sh index 7c5f254b982948..984e27d021a9ac 100755 --- a/ci/official/pycpp.sh +++ b/ci/official/pycpp.sh @@ -1,8 +1,6 @@ #!/bin/bash source "${BASH_SOURCE%/*}/utilities/setup.sh" -# TODO(b/284172313) Revert this difference between presubmits and continuous. RBE serverside behavior is causing flakes, -# so we're temporarily allowing flaky tests again for presubmits. tfrun bazel "${TFCI_BAZEL_BAZELRC_ARGS[@]}" test "${TFCI_BAZEL_CACHE_ARGS[@]}" --config=rbe --config=pycpp --config=build_event_export tfrun bazel analyze-profile build/profile.json.gz diff --git a/ci/official/utilities/setup.sh b/ci/official/utilities/setup.sh index 6e96ed0bc19132..ba1988964b0fd9 100755 --- a/ci/official/utilities/setup.sh +++ b/ci/official/utilities/setup.sh @@ -1,6 +1,15 @@ #!/bin/bash +# Common setup for all TF scripts. +# +# Make as FEW changes to this file as possible. It should not contain utility +# functions (except for tfrun); use dedicated scripts instead and reference them +# specifically. Use your best judgment to keep the scripts in this directory +# lean and easy to follow. When in doubt, remember that for CI scripts, "keep it +# simple" is MUCH more important than "don't repeat yourself." + # -e: abort script if one command fails # -u: error if undefined variable used +# -x: log all commands # -o pipefail: entire command fails if pipe fails. watch out for yes | ... # -o history: record shell history # -o allexport: export all functions and variables to be available to subscripts @@ -16,6 +25,11 @@ cd "$TFCI_GIT_DIR" && mkdir -p build # Setup tfrun, a helper function for executing steps that can either be run # locally or run under Docker. docker.sh, below, redefines it as "docker exec". +# Important: "tfrun foo | bar" is "( tfrun foo ) | bar", not tfrun (foo | bar). +# Therefore, "tfrun" commands cannot include pipes -- which is probably for the +# better. If a pipe is necessary for something, it is probably complex. Write a +# well-documented script under utilities/ to encapsulate the functionality +# instead. tfrun() { "$@"; } # For Google-internal jobs, run copybara, which will overwrite the source tree. From 9e5137e7d6a1bcb3f817838aa9fa9247ef295e3e Mon Sep 17 00:00:00 2001 From: Michael Hudgins Date: Fri, 7 Jul 2023 11:35:04 -0400 Subject: [PATCH 038/376] Allow release branch builds for arm64 This should enable the arm64 cd build to run on pushes to the release branches in addition to just tags. The final step of uploading a binary should not be run in the case of a push to the release branch and only on the release tag. --- .github/workflows/arm-cd.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/arm-cd.yml b/.github/workflows/arm-cd.yml index a191c65a98f35f..ad6bcb9edc0e28 100644 --- a/.github/workflows/arm-cd.yml +++ b/.github/workflows/arm-cd.yml @@ -19,6 +19,8 @@ on: push: tags: - v2.** + branches: + - r2.** schedule: - cron: '0 8 * * *' @@ -66,5 +68,6 @@ jobs: CI_DOCKER_BUILD_EXTRA_PARAMS="--build-arg py_major_minor_version=${{ matrix.pyver }} --build-arg is_nightly=${is_nightly} --build-arg tf_project_name=${tf_project_name}" \ ./tensorflow/tools/ci_build/ci_build.sh cpu.arm64 bash tensorflow/tools/ci_build/rel/ubuntu/cpu_arm64_pip.sh - name: Upload pip wheel to PyPI + if: github.event_name == 'schedule' || (github.event_name == 'push' && contains(github.ref, 'refs/tags/')) # only if it is a scheduled nightly or tagged shell: bash run: python3 -m twine upload --verbose /home/ubuntu/actions-runner/_work/tensorflow/tensorflow/whl/* -u "__token__" -p ${{ secrets.AWS_PYPI_ACCOUNT_TOKEN }} From 9c57db4e22b0d33ed864e577ef7a2cb2fa67f1ff Mon Sep 17 00:00:00 2001 From: Austin Anderson Date: Fri, 7 Jul 2023 17:31:00 -0700 Subject: [PATCH 039/376] Handle feedback --- ci/official/any.sh | 14 +++++++------ ci/official/bazelrcs/cpu.bazelrc | 4 ++-- ci/official/bazelrcs/cpu_gcc.bazelrc | 4 ++-- ci/official/bazelrcs/nvidia.bazelrc | 4 ++-- ci/official/envs/local_cpu | 4 ++-- ci/official/envs/nightly_cpu_py310 | 2 +- ci/official/envs/nightly_cpu_py311 | 2 +- ci/official/envs/nightly_cpu_py39 | 2 +- ci/official/envs/nightly_nvidia_py310 | 2 +- ci/official/envs/nightly_nvidia_py311 | 2 +- ci/official/envs/nightly_nvidia_py39 | 2 +- ci/official/libtensorflow.sh | 12 +++++++---- ci/official/pycpp.sh | 2 +- ci/official/utilities/copybara.sh | 7 ------- ci/official/utilities/docker.sh | 7 ------- ci/official/utilities/repack_libtensorflow.sh | 7 ------- ci/official/utilities/setup.sh | 20 ++++++++++++++----- ci/official/wheel.sh | 12 +++++++---- 18 files changed, 54 insertions(+), 55 deletions(-) diff --git a/ci/official/any.sh b/ci/official/any.sh index 8bfba811d74419..0548e4e9e1bccd 100755 --- a/ci/official/any.sh +++ b/ci/official/any.sh @@ -7,15 +7,17 @@ source "${BASH_SOURCE%/*}/utilities/setup.sh" # subshell environment. # Ignore grep failures since we're using it for basic filtering set +e -filtered_build_targets=( $(echo "$BUILD_TARGETS" | tr ' ' '\n' | grep . | tee build_targets.txt) ) -nonpip_targets=( $(echo "$TEST_TARGETS" | tr ' ' '\n' | grep -E "^//tensorflow/" | tee nonpip_targets.txt) ) +filtered_build_targets=( $(echo "$BUILD_TARGETS" | tr ' ' '\n' | grep .) ) +nonpip_targets=( $(echo "$TEST_TARGETS" | tr ' ' '\n' | grep -E "^//tensorflow/" ) ) config=( $(echo "$CONFIG_OPTIONS" ) ) test_flags=( $(echo "$TEST_FLAGS" ) ) set -e -[[ "$TFCI_NVIDIA_SMI_ENABLE" == 1 ]] && tfrun nvidia-smi +if [[ "$TFCI_NVIDIA_SMI_ENABLE" == 1 ]]; then + tfrun nvidia-smi +fi -if [[ -s build_targets.txt ]]; then +if [[ "${#filtered_build_targets[@]}" -ne 0 ]]; then tfrun bazel "${TFCI_BAZEL_BAZELRC_ARGS[@]}" "${config[@]}" "${filtered_build_targets[@]}" fi @@ -23,11 +25,11 @@ if [[ "${PIP_WHEEL}" -eq "1" ]]; then # Update the version numbers to build a "nightly" package [[ "$TFCI_NIGHTLY_UPDATE_VERSION_ENABLE" == 1 ]] && tfrun python3 tensorflow/tools/ci_build/update_version.py --nightly - tfrun bazel "${TFCI_BAZEL_BAZELRC_ARGS[@]}" build "${TFCI_BAZEL_CACHE_ARGS[@]}" tensorflow/tools/pip_package:build_pip_package + tfrun bazel "${TFCI_BAZEL_BAZELRC_ARGS[@]}" build "${TFCI_BAZEL_COMMON_ARGS[@]}" tensorflow/tools/pip_package:build_pip_package tfrun ./bazel-bin/tensorflow/tools/pip_package/build_pip_package build "${TFCI_BUILD_PIP_PACKAGE_ARGS[@]}" tfrun ./ci/official/utilities/rename_and_verify_wheels.sh fi -if [[ -s nonpip_targets.txt ]]; then +if [[ "${#nonpip_targets[@]}" -ne 0 ]]; then tfrun bazel "${TFCI_BAZEL_BAZELRC_ARGS[@]}" test "${config[@]}" "${test_flags[@]}" "${nonpip_targets[@]}" fi diff --git a/ci/official/bazelrcs/cpu.bazelrc b/ci/official/bazelrcs/cpu.bazelrc index 3a324603bdf0ce..2a1597e24a5e2e 100644 --- a/ci/official/bazelrcs/cpu.bazelrc +++ b/ci/official/bazelrcs/cpu.bazelrc @@ -37,7 +37,7 @@ build --copt=-Wno-gnu-offsetof-extensions # Store performance profiling log in the mounted artifact directory. # The profile can be viewed by visiting chrome://tracing in a Chrome browser. # See https://docs.bazel.build/versions/main/skylark/performance.html#performance-profiling -build --profile=/tf/pkg/profile.json.gz +build --profile=build/profile.json.gz # Use the NVCC toolchain to compile for manylinux2014 build --crosstool_top="@sigbuild-r2.14-clang_config_cuda//crosstool:toolchain" @@ -63,7 +63,7 @@ test:libtensorflow_test -- //tensorflow/tools/lib_package:libtensorflow_test //t build:libtensorflow_build -- //tensorflow/tools/lib_package:libtensorflow.tar.gz //tensorflow/tools/lib_package:libtensorflow_jni.tar.gz //tensorflow/java:libtensorflow.jar //tensorflow/java:libtensorflow-src.jar //tensorflow/tools/lib_package:libtensorflow_proto.zip # For outputting Build Event Protocol files -build:build_event_export --build_event_json_file=/tf/pkg/bep.json +build:build_event_export --build_event_json_file=build/bep.json # For Remote Build Execution. build:rbe --google_default_credentials diff --git a/ci/official/bazelrcs/cpu_gcc.bazelrc b/ci/official/bazelrcs/cpu_gcc.bazelrc index cc74fd978cfade..311d59b27a05dd 100644 --- a/ci/official/bazelrcs/cpu_gcc.bazelrc +++ b/ci/official/bazelrcs/cpu_gcc.bazelrc @@ -26,7 +26,7 @@ build --copt=-mavx --host_copt=-mavx # Store performance profiling log in the mounted artifact directory. # The profile can be viewed by visiting chrome://tracing in a Chrome browser. # See https://docs.bazel.build/versions/main/skylark/performance.html#performance-profiling -build --profile=/tf/pkg/profile.json.gz +build --profile=build/profile.json.gz # Use the NVCC toolchain to compile for manylinux2014 build --crosstool_top="@sigbuild-r2.14_config_cuda//crosstool:toolchain" @@ -52,7 +52,7 @@ test:libtensorflow_test -- //tensorflow/tools/lib_package:libtensorflow_test //t build:libtensorflow_build -- //tensorflow/tools/lib_package:libtensorflow.tar.gz //tensorflow/tools/lib_package:libtensorflow_jni.tar.gz //tensorflow/java:libtensorflow.jar //tensorflow/java:libtensorflow-src.jar //tensorflow/tools/lib_package:libtensorflow_proto.zip # For outputting Build Event Protocol files -build:build_event_export --build_event_json_file=/tf/pkg/bep.json +build:build_event_export --build_event_json_file=build/bep.json # For Remote Build Execution. build:rbe --google_default_credentials diff --git a/ci/official/bazelrcs/nvidia.bazelrc b/ci/official/bazelrcs/nvidia.bazelrc index 50ea575205967c..3a6773579e9d37 100644 --- a/ci/official/bazelrcs/nvidia.bazelrc +++ b/ci/official/bazelrcs/nvidia.bazelrc @@ -37,7 +37,7 @@ build --linkopt="-lm" # Store performance profiling log in the mounted artifact directory. # The profile can be viewed by visiting chrome://tracing in a Chrome browser. # See https://docs.bazel.build/versions/main/skylark/performance.html#performance-profiling -build --profile=/tf/pkg/profile.json.gz +build --profile=build/profile.json.gz # CUDA: Set up compilation CUDA version and paths build --@local_config_cuda//:enable_cuda @@ -86,7 +86,7 @@ test:libtensorflow_test -- //tensorflow/tools/lib_package:libtensorflow_test //t build:libtensorflow_build -- //tensorflow/tools/lib_package:libtensorflow.tar.gz //tensorflow/tools/lib_package:libtensorflow_jni.tar.gz //tensorflow/java:libtensorflow.jar //tensorflow/java:libtensorflow-src.jar //tensorflow/tools/lib_package:libtensorflow_proto.zip # For outputting Build Event Protocol files -build:build_event_export --build_event_json_file=/tf/pkg/bep.json +build:build_event_export --build_event_json_file=build/bep.json # For Remote Build Execution. build:rbe --google_default_credentials diff --git a/ci/official/envs/local_cpu b/ci/official/envs/local_cpu index 914dd4c856afd4..79e64bfc28bcf9 100644 --- a/ci/official/envs/local_cpu +++ b/ci/official/envs/local_cpu @@ -1,12 +1,12 @@ TFCI_BAZEL_BAZELRC_ARGS=(--bazelrc ./ci/official/bazelrcs/cpu.bazelrc) -TFCI_BAZEL_CACHE_ARGS=(--config sigbuild_remote_cache) +TFCI_BAZEL_COMMON_ARGS=(--config sigbuild_remote_cache) TFCI_BUILD_PIP_PACKAGE_ARGS=("--cpu") TFCI_COPYBARA_ENABLE=0 TFCI_DOCKER_ENABLE=1 TFCI_DOCKER_GPU_ARGS=() TFCI_DOCKER_IMAGE=tensorflow/build:latest-python3.9 TFCI_DOCKER_PULL_ENABLE= -TFCI_GIT_DIR=/usr/local/google/home/angerson/repos/tensorflow +TFCI_GIT_DIR=. TFCI_INDEX_HTML_ENABLE=1 TFCI_LIB_SUFFIX="-cpu-linux-x86_64" TFCI_NIGHTLY_UPDATE_VERSION_ENABLE= diff --git a/ci/official/envs/nightly_cpu_py310 b/ci/official/envs/nightly_cpu_py310 index 9cff1f3803ff5a..eabe2dcc845a1e 100644 --- a/ci/official/envs/nightly_cpu_py310 +++ b/ci/official/envs/nightly_cpu_py310 @@ -1,7 +1,7 @@ #TFCI_UPLOAD_LIB_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" #TFCI_UPLOAD_WHL_GCS_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" TFCI_BAZEL_BAZELRC_ARGS=(--bazelrc ./ci/official/bazelrcs/cpu.bazelrc) -TFCI_BAZEL_CACHE_ARGS=(--config sigbuild_remote_cache_push) +TFCI_BAZEL_COMMON_ARGS=(--config sigbuild_remote_cache_push) TFCI_BUILD_PIP_PACKAGE_ARGS=(--cpu --nightly_flag) TFCI_COPYBARA_ENABLE=1 TFCI_DOCKER_ENABLE=1 diff --git a/ci/official/envs/nightly_cpu_py311 b/ci/official/envs/nightly_cpu_py311 index e28e8f6cf3c413..0201e5aa44c0d4 100644 --- a/ci/official/envs/nightly_cpu_py311 +++ b/ci/official/envs/nightly_cpu_py311 @@ -1,7 +1,7 @@ #TFCI_UPLOAD_LIB_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" #TFCI_UPLOAD_WHL_GCS_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" TFCI_BAZEL_BAZELRC_ARGS=(--bazelrc ./ci/official/bazelrcs/cpu.bazelrc) -TFCI_BAZEL_CACHE_ARGS=(--config sigbuild_remote_cache_push) +TFCI_BAZEL_COMMON_ARGS=(--config sigbuild_remote_cache_push) TFCI_BUILD_PIP_PACKAGE_ARGS=(--cpu --nightly_flag) TFCI_COPYBARA_ENABLE=1 TFCI_DOCKER_ENABLE=1 diff --git a/ci/official/envs/nightly_cpu_py39 b/ci/official/envs/nightly_cpu_py39 index 6c34a60b89cde9..436bd41e169143 100644 --- a/ci/official/envs/nightly_cpu_py39 +++ b/ci/official/envs/nightly_cpu_py39 @@ -1,7 +1,7 @@ #TFCI_UPLOAD_LIB_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" #TFCI_UPLOAD_WHL_GCS_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" TFCI_BAZEL_BAZELRC_ARGS=(--bazelrc ./ci/official/bazelrcs/cpu.bazelrc) -TFCI_BAZEL_CACHE_ARGS=(--config sigbuild_remote_cache_push) +TFCI_BAZEL_COMMON_ARGS=(--config sigbuild_remote_cache_push) TFCI_BUILD_PIP_PACKAGE_ARGS=(--cpu --nightly_flag) TFCI_COPYBARA_ENABLE=1 TFCI_DOCKER_ENABLE=1 diff --git a/ci/official/envs/nightly_nvidia_py310 b/ci/official/envs/nightly_nvidia_py310 index dbfd3ca756b4f6..214efe40d42db3 100644 --- a/ci/official/envs/nightly_nvidia_py310 +++ b/ci/official/envs/nightly_nvidia_py310 @@ -1,7 +1,7 @@ #TFCI_UPLOAD_LIB_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" #TFCI_UPLOAD_WHL_GCS_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" TFCI_BAZEL_BAZELRC_ARGS=(--bazelrc ./ci/official/bazelrcs/nvidia.bazelrc) -TFCI_BAZEL_CACHE_ARGS=(--config sigbuild_remote_cache_push) +TFCI_BAZEL_COMMON_ARGS=(--config sigbuild_remote_cache_push) TFCI_BUILD_PIP_PACKAGE_ARGS=(--gpu --nightly_flag) TFCI_COPYBARA_ENABLE=1 TFCI_DOCKER_ENABLE=1 diff --git a/ci/official/envs/nightly_nvidia_py311 b/ci/official/envs/nightly_nvidia_py311 index c988660c54e086..9a4a8f173eb2a6 100644 --- a/ci/official/envs/nightly_nvidia_py311 +++ b/ci/official/envs/nightly_nvidia_py311 @@ -1,7 +1,7 @@ #TFCI_UPLOAD_LIB_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" #TFCI_UPLOAD_WHL_GCS_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" TFCI_BAZEL_BAZELRC_ARGS=(--bazelrc ./ci/official/bazelrcs/nvidia.bazelrc) -TFCI_BAZEL_CACHE_ARGS=(--config sigbuild_remote_cache_push) +TFCI_BAZEL_COMMON_ARGS=(--config sigbuild_remote_cache_push) TFCI_BUILD_PIP_PACKAGE_ARGS=(--gpu --nightly_flag) TFCI_COPYBARA_ENABLE=1 TFCI_DOCKER_ENABLE=1 diff --git a/ci/official/envs/nightly_nvidia_py39 b/ci/official/envs/nightly_nvidia_py39 index 9c0984080a211b..4e729536b1d60e 100644 --- a/ci/official/envs/nightly_nvidia_py39 +++ b/ci/official/envs/nightly_nvidia_py39 @@ -1,7 +1,7 @@ #TFCI_UPLOAD_LIB_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" #TFCI_UPLOAD_WHL_GCS_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" TFCI_BAZEL_BAZELRC_ARGS=(--bazelrc ./ci/official/bazelrcs/nvidia.bazelrc) -TFCI_BAZEL_CACHE_ARGS=(--config sigbuild_remote_cache_push) +TFCI_BAZEL_COMMON_ARGS=(--config sigbuild_remote_cache_push) TFCI_BUILD_PIP_PACKAGE_ARGS=(--gpu --nightly_flag) TFCI_COPYBARA_ENABLE=1 TFCI_DOCKER_ENABLE=1 diff --git a/ci/official/libtensorflow.sh b/ci/official/libtensorflow.sh index 31ed771f40b382..db3647bec7ded0 100755 --- a/ci/official/libtensorflow.sh +++ b/ci/official/libtensorflow.sh @@ -2,13 +2,17 @@ source "${BASH_SOURCE%/*}/utilities/setup.sh" # Record GPU count and CUDA version status -[[ "$TFCI_NVIDIA_SMI_ENABLE" == 1 ]] && tfrun nvidia-smi +if [[ "$TFCI_NVIDIA_SMI_ENABLE" == 1 ]]; then + tfrun nvidia-smi +fi # Update the version numbers for Nightly only -[[ "$TFCI_NIGHTLY_UPDATE_VERSION_ENABLE" == 1 ]] && tfrun python3 tensorflow/tools/ci_build/update_version.py --nightly +if [[ "$TFCI_NIGHTLY_UPDATE_VERSION_ENABLE" == 1 ]]; then + tfrun python3 tensorflow/tools/ci_build/update_version.py --nightly +fi -tfrun bazel "${TFCI_BAZEL_BAZELRC_ARGS[@]}" test "${TFCI_BAZEL_CACHE_ARGS[@]}" --config=libtensorflow_test -tfrun bazel "${TFCI_BAZEL_BAZELRC_ARGS[@]}" build "${TFCI_BAZEL_CACHE_ARGS[@]}" --config=libtensorflow_build +tfrun bazel "${TFCI_BAZEL_BAZELRC_ARGS[@]}" test "${TFCI_BAZEL_COMMON_ARGS[@]}" --config=libtensorflow_test +tfrun bazel "${TFCI_BAZEL_BAZELRC_ARGS[@]}" build "${TFCI_BAZEL_COMMON_ARGS[@]}" --config=libtensorflow_build tfrun ./ci/official/utilities/repack_libtensorflow.sh build "$TFCI_LIB_SUFFIX" diff --git a/ci/official/pycpp.sh b/ci/official/pycpp.sh index 984e27d021a9ac..f6aef18b96cb39 100755 --- a/ci/official/pycpp.sh +++ b/ci/official/pycpp.sh @@ -1,6 +1,6 @@ #!/bin/bash source "${BASH_SOURCE%/*}/utilities/setup.sh" -tfrun bazel "${TFCI_BAZEL_BAZELRC_ARGS[@]}" test "${TFCI_BAZEL_CACHE_ARGS[@]}" --config=rbe --config=pycpp --config=build_event_export +tfrun bazel "${TFCI_BAZEL_BAZELRC_ARGS[@]}" test "${TFCI_BAZEL_COMMON_ARGS[@]}" --config=pycpp tfrun bazel analyze-profile build/profile.json.gz diff --git a/ci/official/utilities/copybara.sh b/ci/official/utilities/copybara.sh index 34f24c802fdc27..fd9dac640d22f2 100755 --- a/ci/official/utilities/copybara.sh +++ b/ci/official/utilities/copybara.sh @@ -1,11 +1,4 @@ #!/bin/bash -# -e: abort script if one command fails -# -u: error if undefined variable used -# -o pipefail: entire command fails if pipe fails. watch out for yes | ... -# -o history: record shell history -set -euxo pipefail -o history -set -o allexport && source "$TFCI" && set +o allexport - # Destroy any existing github code rm -rf "$TFCI_GIT_DIR" mkdir -p "$TFCI_GIT_DIR" diff --git a/ci/official/utilities/docker.sh b/ci/official/utilities/docker.sh index b84ee381e518be..e0168a4b94bc77 100755 --- a/ci/official/utilities/docker.sh +++ b/ci/official/utilities/docker.sh @@ -1,11 +1,4 @@ #!/bin/bash -# -e: abort script if one command fails -# -u: error if undefined variable used -# -o pipefail: entire command fails if pipe fails. watch out for yes | ... -# -o history: record shell history -set -euxo pipefail -o history -set -o allexport && source "$TFCI" && set +o allexport - trap "docker rm -f tf" EXIT if [[ "$TFCI_DOCKER_PULL_ENABLE" == 1 ]]; then docker pull "$TFCI_DOCKER_IMAGE" diff --git a/ci/official/utilities/repack_libtensorflow.sh b/ci/official/utilities/repack_libtensorflow.sh index fefce92f747ce3..7492642148fa78 100755 --- a/ci/official/utilities/repack_libtensorflow.sh +++ b/ci/official/utilities/repack_libtensorflow.sh @@ -19,13 +19,6 @@ # # Repacks libtensorflow tarballs into $DIR with provided $TARBALL_SUFFIX, # and also repacks libtensorflow-src.jar into a standardized format. -# -# -e: abort script if one command fails -# -u: error if undefined variable used -# -o pipefail: entire command fails if pipe fails. watch out for yes | ... -# -o history: record shell history -set -euxo pipefail -o history -set -o allexport && source "$TFCI" && set +o allexport # Helper function to copy a srcjar after moving any source files # directly under the root to the "maven-style" src/main/java layout diff --git a/ci/official/utilities/setup.sh b/ci/official/utilities/setup.sh index ba1988964b0fd9..3d5773ffc32875 100755 --- a/ci/official/utilities/setup.sh +++ b/ci/official/utilities/setup.sh @@ -13,15 +13,19 @@ # -o pipefail: entire command fails if pipe fails. watch out for yes | ... # -o history: record shell history # -o allexport: export all functions and variables to be available to subscripts +# (affects 'source $TFCI') set -euxo pipefail -o history -o allexport # Import all variables as set in $TFCI, which should be a file like those in # the envs directory that sets all TFCI_ variables, e.g. /path/to/envs/local_cpu -source "$TFCI" +if [[ -n "$TFCI" ]]; then + source "$TFCI" +fi # Make a "build" directory for outputting all build artifacts (TF's .gitignore # ignores the "build" directory) -cd "$TFCI_GIT_DIR" && mkdir -p build +cd "$TFCI_GIT_DIR" +mkdir -p build # Setup tfrun, a helper function for executing steps that can either be run # locally or run under Docker. docker.sh, below, redefines it as "docker exec". @@ -34,10 +38,16 @@ tfrun() { "$@"; } # For Google-internal jobs, run copybara, which will overwrite the source tree. # Never useful for outside users. -[[ "$TFCI_COPYBARA_ENABLE" == 1 ]] && source ./ci/official/utilities/copybara.sh +if [[ "$TFCI_COPYBARA_ENABLE" == 1 ]]; then + source ./ci/official/utilities/copybara.sh +fi # Run all "tfrun" commands under Docker. See docker.sh for details -[[ "$TFCI_DOCKER_ENABLE" == 1 ]] && source ./ci/official/utilities/docker.sh +if [[ "$TFCI_DOCKER_ENABLE" == 1 ]]; then + source ./ci/official/utilities/docker.sh +fi # Generate an overview page describing the build -[[ "$TFCI_INDEX_HTML_ENABLE" == 1 ]] && ./ci/official/utilities/generate_index_html.sh build/index.html +if [[ "$TFCI_INDEX_HTML_ENABLE" == 1 ]]; then + ./ci/official/utilities/generate_index_html.sh build/index.html +fi diff --git a/ci/official/wheel.sh b/ci/official/wheel.sh index 359517a405224c..81a7c260d5df98 100755 --- a/ci/official/wheel.sh +++ b/ci/official/wheel.sh @@ -2,12 +2,16 @@ source "${BASH_SOURCE%/*}/utilities/setup.sh" # Record GPU count and CUDA version status -[[ "$TFCI_NVIDIA_SMI_ENABLE" == 1 ]] && tfrun nvidia-smi +if [[ "$TFCI_NVIDIA_SMI_ENABLE" == 1 ]]; then + tfrun nvidia-smi +fi # Update the version numbers for Nightly only -[[ "$TFCI_NIGHTLY_UPDATE_VERSION_ENABLE" == 1 ]] && tfrun python3 tensorflow/tools/ci_build/update_version.py --nightly +if [[ "$TFCI_NIGHTLY_UPDATE_VERSION_ENABLE" == 1 ]]; then + tfrun python3 tensorflow/tools/ci_build/update_version.py --nightly +fi -tfrun bazel "${TFCI_BAZEL_BAZELRC_ARGS[@]}" build "${TFCI_BAZEL_CACHE_ARGS[@]}" //tensorflow/tools/pip_package:build_pip_package +tfrun bazel "${TFCI_BAZEL_BAZELRC_ARGS[@]}" build "${TFCI_BAZEL_COMMON_ARGS[@]}" //tensorflow/tools/pip_package:build_pip_package tfrun ./bazel-bin/tensorflow/tools/pip_package/build_pip_package build "${TFCI_BUILD_PIP_PACKAGE_ARGS[@]}" tfrun ./ci/official/utilities/rename_and_verify_wheels.sh build @@ -16,4 +20,4 @@ if [[ "$TFCI_UPLOAD_ENABLE" == 1 ]]; then gsutil cp build/*.whl "$TFCI_UPLOAD_GCS_DESTINATION" fi -tfrun bazel "${TFCI_BAZEL_BAZELRC_ARGS[@]}" test "${TFCI_BAZEL_CACHE_ARGS[@]}" --config=nonpip +tfrun bazel "${TFCI_BAZEL_BAZELRC_ARGS[@]}" test "${TFCI_BAZEL_COMMON_ARGS[@]}" --config=nonpip From 7a77ef6fe4812e1da95f7eb87a7b88c48de7ca9a Mon Sep 17 00:00:00 2001 From: Austin Anderson Date: Fri, 7 Jul 2023 17:40:13 -0700 Subject: [PATCH 040/376] Update TFCI sourcing --- ci/official/utilities/setup.sh | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/ci/official/utilities/setup.sh b/ci/official/utilities/setup.sh index 3d5773ffc32875..21b832c4e25afa 100755 --- a/ci/official/utilities/setup.sh +++ b/ci/official/utilities/setup.sh @@ -16,10 +16,17 @@ # (affects 'source $TFCI') set -euxo pipefail -o history -o allexport -# Import all variables as set in $TFCI, which should be a file like those in -# the envs directory that sets all TFCI_ variables, e.g. /path/to/envs/local_cpu -if [[ -n "$TFCI" ]]; then +# "TFCI" may optionally be set to the name of an env-type file with TFCI +# variables in it, OR may be left empty if the user has already exported the +# relevant variables in their environment. Because of 'set -o allexport' above +# (which is equivalent to "set -a"), every variable in the file is exported +# for other files to use. +if [[ -n "${TFCI:-}" ]]; then source "$TFCI" +else + echo '==TFCI==: The $TFCI variable is not set. This is fine as long as you' + echo 'already sourced a TFCI env file with "set -a; source ; set +a".' + echo 'If you have not, you will see a lot of undefined variable errors.' fi # Make a "build" directory for outputting all build artifacts (TF's .gitignore From 6391b8edb33283390522fffd02b349d3a49c6327 Mon Sep 17 00:00:00 2001 From: Andrew Goodbody Date: Mon, 10 Jul 2023 10:55:14 +0100 Subject: [PATCH 041/376] [Linaro:ARM_CI] Drop building with Python 3.8 as not supported Python 3.8 is no longer supported so drop attempts to build using it. --- .github/workflows/arm-cd.yml | 2 +- .github/workflows/arm-ci-extended.yml | 2 +- .../toolchains/cpus/aarch64/aarch64_compiler_configure.bzl | 2 -- 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/.github/workflows/arm-cd.yml b/.github/workflows/arm-cd.yml index a191c65a98f35f..05159d99516914 100644 --- a/.github/workflows/arm-cd.yml +++ b/.github/workflows/arm-cd.yml @@ -30,7 +30,7 @@ jobs: strategy: fail-fast: false matrix: - pyver: ['3.8', '3.9', '3.10'] + pyver: ['3.9', '3.10'] experimental: [false] include: - pyver: '3.11' diff --git a/.github/workflows/arm-ci-extended.yml b/.github/workflows/arm-ci-extended.yml index 7c386590addf70..8dd4f437cde18b 100644 --- a/.github/workflows/arm-ci-extended.yml +++ b/.github/workflows/arm-ci-extended.yml @@ -29,7 +29,7 @@ jobs: strategy: fail-fast: false matrix: - pyver: ['3.8', '3.9', '3.10', '3.11'] + pyver: ['3.9', '3.10', '3.11'] steps: - name: Stop old running containers (if any) shell: bash diff --git a/tensorflow/tools/toolchains/cpus/aarch64/aarch64_compiler_configure.bzl b/tensorflow/tools/toolchains/cpus/aarch64/aarch64_compiler_configure.bzl index fff3dd70496e89..a2bdd6a7eedafe 100644 --- a/tensorflow/tools/toolchains/cpus/aarch64/aarch64_compiler_configure.bzl +++ b/tensorflow/tools/toolchains/cpus/aarch64/aarch64_compiler_configure.bzl @@ -34,7 +34,6 @@ def aarch64_compiler_configure(): ml2014_tf_aarch64_configs( name_container_map = { "ml2014_aarch64": "docker://localhost/tensorflow-build-aarch64", - "ml2014_aarch64-python3.8": "docker://localhost/tensorflow-build-aarch64:latest-python3.8", "ml2014_aarch64-python3.9": "docker://localhost/tensorflow-build-aarch64:latest-python3.9", "ml2014_aarch64-python3.10": "docker://localhost/tensorflow-build-aarch64:latest-python3.10", "ml2014_aarch64-python3.11": "docker://localhost/tensorflow-build-aarch64:latest-python3.11", @@ -72,7 +71,6 @@ def aarch64_compiler_configure(): ml2014_tf_aarch64_configs( name_container_map = { "ml2014_clang_aarch64": "docker://localhost/tensorflow-build-aarch64", - "ml2014_clang_aarch64-python3.8": "docker://localhost/tensorflow-build-aarch64:latest-python3.8", "ml2014_clang_aarch64-python3.9": "docker://localhost/tensorflow-build-aarch64:latest-python3.9", "ml2014_clang_aarch64-python3.10": "docker://localhost/tensorflow-build-aarch64:latest-python3.10", "ml2014_clang_aarch64-python3.11": "docker://localhost/tensorflow-build-aarch64:latest-python3.11", From 10785e9ca342539148464da2b959c7e61b319d62 Mon Sep 17 00:00:00 2001 From: Andrew Goodbody Date: Mon, 10 Jul 2023 10:59:21 +0100 Subject: [PATCH 042/376] Fix ambiguity in use of overloaded functions in XLA gcc complains of abiguity in some overloaded functions so cast the parameter to overcome this. --- .../xla/python/pjrt_ifrt/xla_sharding_test.cc | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_test.cc b/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_test.cc index d6b3bd5784cc7b..408c978fab375f 100644 --- a/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_test.cc +++ b/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_test.cc @@ -82,7 +82,7 @@ TEST(HloShardingTest, DisassembleWithReplication) { TEST(HloShardingTest, IndexDomainsWithTile) { auto device_list = CreateDummyDevices(2); // 2-way sharded along axis 0, 1-way sharded along axis 1. - auto xla_hlo_sharding = xla::HloSharding::Tile(xla::TileAssignment({2, 1})); + auto xla_hlo_sharding = xla::HloSharding::Tile(xla::TileAssignment((absl::Span){2, 1})); std::shared_ptr sharding = HloSharding::Create(device_list, xla_hlo_sharding); @@ -100,7 +100,7 @@ TEST(HloShardingTest, IndexDomainsWithTile) { TEST(HloShardingTest, DisassembleWithTile) { auto device_list = CreateDummyDevices(2); // 2-way sharded along axis 0, 1-way sharded along axis 1. - auto xla_hlo_sharding = xla::HloSharding::Tile(xla::TileAssignment({2, 1})); + auto xla_hlo_sharding = xla::HloSharding::Tile(xla::TileAssignment((absl::Span){2, 1})); std::shared_ptr sharding = HloSharding::Create(device_list, xla_hlo_sharding); @@ -120,7 +120,7 @@ TEST(HloShardingTest, DisassembleWithTile) { TEST(HloShardingTest, IndexDomainsWithUnevenTile) { auto device_list = CreateDummyDevices(2); // 2-way sharded along axis 0, 1-way sharded along axis 1. - auto xla_hlo_sharding = xla::HloSharding::Tile(xla::TileAssignment({2, 1})); + auto xla_hlo_sharding = xla::HloSharding::Tile(xla::TileAssignment((absl::Span){2, 1})); std::shared_ptr sharding = HloSharding::Create(device_list, xla_hlo_sharding); @@ -138,7 +138,7 @@ TEST(HloShardingTest, IndexDomainsWithUnevenTile) { TEST(HloShardingTest, DisassembleWithUnevenTile) { auto device_list = CreateDummyDevices(2); // 2-way sharded along axis 0, 1-way sharded along axis 1. - auto xla_hlo_sharding = xla::HloSharding::Tile(xla::TileAssignment({2, 1})); + auto xla_hlo_sharding = xla::HloSharding::Tile(xla::TileAssignment((absl::Span){2, 1})); std::shared_ptr sharding = HloSharding::Create(device_list, xla_hlo_sharding); @@ -300,7 +300,7 @@ TEST(HloShardingTest, DisassembleWithSubgroupMaximalSlowPath) { TEST(HloShardingTest, DisassembleFailsWithInvalidDeviceCount) { auto device_list = CreateDummyDevices(1); // 2-way sharded along axis 0, 1-way sharded along axis 1. - auto xla_hlo_sharding = xla::HloSharding::Tile(xla::TileAssignment({2, 1})); + auto xla_hlo_sharding = xla::HloSharding::Tile(xla::TileAssignment((absl::Span){2, 1})); std::shared_ptr sharding = HloSharding::Create(device_list, xla_hlo_sharding); @@ -314,7 +314,7 @@ TEST(HloShardingTest, DisassembleFailsWithInvalidDeviceCount) { TEST(HloShardingTest, DisassembleFailsWithMismatchingShapeDimsSize) { auto device_list = CreateDummyDevices(2); // 2-way sharded along axis 0, 1-way sharded along axis 1. - auto xla_hlo_sharding = xla::HloSharding::Tile(xla::TileAssignment({2, 1})); + auto xla_hlo_sharding = xla::HloSharding::Tile(xla::TileAssignment((absl::Span){2, 1})); std::shared_ptr sharding = HloSharding::Create(device_list, xla_hlo_sharding); From a0bc8c2af10684d0206e7be246c810e35a89b663 Mon Sep 17 00:00:00 2001 From: David Svantesson Date: Mon, 10 Jul 2023 12:27:50 +0000 Subject: [PATCH 043/376] Address review feedback --- tensorflow/tensorflow.bzl | 4 ++-- tensorflow/workspace2.bzl | 4 ++-- third_party/compute_library/BUILD | 7 +++++++ .../{acl_acl_reorder.patch => acl_reorder.patch} | 0 third_party/compute_library/build_defs.bzl | 4 ++-- ...eorder_padded.patch => onednn_acl_reorder_padded.patch} | 0 6 files changed, 13 insertions(+), 6 deletions(-) rename third_party/compute_library/{acl_acl_reorder.patch => acl_reorder.patch} (100%) rename third_party/mkl_dnn/{onednn_reorder_padded.patch => onednn_acl_reorder_padded.patch} (100%) diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index 2fe9c936529542..de91916c364182 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -1498,7 +1498,7 @@ def tf_cc_test( "-lpthread", "-lm", ], - clean_dep("@compute_library//:build_with_acl"): [ + clean_dep("@org_tensorflow//third_party/compute_library:build_with_acl"): [ "-fopenmp", "-lm", ], @@ -1541,7 +1541,7 @@ def tf_cc_shared_test( "-lpthread", "-lm", ], - clean_dep("@compute_library//:build_with_acl"): [ + clean_dep("@org_tensorflow//third_party/compute_library:build_with_acl"): [ "-fopenmp", "-lm", ], diff --git a/tensorflow/workspace2.bzl b/tensorflow/workspace2.bzl index 551fb36a288c9a..46325e286b2338 100644 --- a/tensorflow/workspace2.bzl +++ b/tensorflow/workspace2.bzl @@ -208,7 +208,7 @@ def _tf_repositories(): "//third_party/mkl_dnn:onednn_acl_fixed_format_kernels.patch", "//third_party/mkl_dnn:onednn_acl_depthwise_convolution.patch", "//third_party/mkl_dnn:onednn_acl_threadpool_scheduler.patch", - "//third_party/mkl_dnn:onednn_reorder_padded.patch", + "//third_party/mkl_dnn:onednn_acl_reorder_padded.patch", "//third_party/mkl_dnn:onednn_acl_reorder_update.patch", "//third_party/mkl_dnn:onednn_acl_reorder.patch", ], @@ -221,7 +221,7 @@ def _tf_repositories(): name = "compute_library", sha256 = "4c22983f08cbc26a7b66c695ee6850d39ea1346a6c76a902323dd10217df4606", strip_prefix = "ComputeLibrary-23.05", - patch_file = ["//third_party/compute_library:compute_library.patch", "//third_party/compute_library:acl_acl_reorder.patch"], + patch_file = ["//third_party/compute_library:compute_library.patch", "//third_party/compute_library:acl_reorder.patch"], urls = tf_mirror_urls("https://github.com/ARM-software/ComputeLibrary/archive/v23.05.tar.gz"), ) diff --git a/third_party/compute_library/BUILD b/third_party/compute_library/BUILD index e69de29bb2d1d6..6ccd503a3c7ba8 100644 --- a/third_party/compute_library/BUILD +++ b/third_party/compute_library/BUILD @@ -0,0 +1,7 @@ +config_setting( + name = "build_with_acl", + define_values = { + "build_with_acl": "true", + }, + visibility = ["//visibility:public"], +) \ No newline at end of file diff --git a/third_party/compute_library/acl_acl_reorder.patch b/third_party/compute_library/acl_reorder.patch similarity index 100% rename from third_party/compute_library/acl_acl_reorder.patch rename to third_party/compute_library/acl_reorder.patch diff --git a/third_party/compute_library/build_defs.bzl b/third_party/compute_library/build_defs.bzl index 3898798a42d6de..5c5f8f6df5bd2a 100644 --- a/third_party/compute_library/build_defs.bzl +++ b/third_party/compute_library/build_defs.bzl @@ -1,6 +1,6 @@ def if_enable_acl(if_true, if_false = []): return select({ - "@compute_library//:build_with_acl": if_true, + "@org_tensorflow//third_party/compute_library:build_with_acl": if_true, "//conditions:default": if_false, }) @@ -15,6 +15,6 @@ def acl_deps(): inclusion in the deps attribute of rules. """ return select({ - "@compute_library//:build_with_acl": ["@compute_library//:arm_compute_core"], + "@org_tensorflow//third_party/compute_library:build_with_acl": ["@compute_library//:arm_compute_core"], "//conditions:default": [], }) diff --git a/third_party/mkl_dnn/onednn_reorder_padded.patch b/third_party/mkl_dnn/onednn_acl_reorder_padded.patch similarity index 100% rename from third_party/mkl_dnn/onednn_reorder_padded.patch rename to third_party/mkl_dnn/onednn_acl_reorder_padded.patch From dee8aded404b216523ac118a5807485357b71d64 Mon Sep 17 00:00:00 2001 From: David Svantesson Date: Mon, 10 Jul 2023 12:29:52 +0000 Subject: [PATCH 044/376] Update to ACL 23.05.1 --- tensorflow/workspace2.bzl | 7 +- .../acl_fixed_format_kernels_striding.patch | 70 ----------------- .../compute_library/acl_openmp_fix.patch | 46 ----------- third_party/compute_library/acl_reorder.patch | 42 ---------- third_party/compute_library/build_defs.bzl | 2 +- .../compute_library/compute_library.patch | 77 ------------------- third_party/mkl_dnn/mkldnn_acl.BUILD | 2 +- 7 files changed, 5 insertions(+), 241 deletions(-) delete mode 100644 third_party/compute_library/acl_fixed_format_kernels_striding.patch delete mode 100644 third_party/compute_library/acl_openmp_fix.patch delete mode 100644 third_party/compute_library/acl_reorder.patch delete mode 100644 third_party/compute_library/compute_library.patch diff --git a/tensorflow/workspace2.bzl b/tensorflow/workspace2.bzl index 46325e286b2338..3b919ce34c4811 100644 --- a/tensorflow/workspace2.bzl +++ b/tensorflow/workspace2.bzl @@ -219,10 +219,9 @@ def _tf_repositories(): tf_http_archive( name = "compute_library", - sha256 = "4c22983f08cbc26a7b66c695ee6850d39ea1346a6c76a902323dd10217df4606", - strip_prefix = "ComputeLibrary-23.05", - patch_file = ["//third_party/compute_library:compute_library.patch", "//third_party/compute_library:acl_reorder.patch"], - urls = tf_mirror_urls("https://github.com/ARM-software/ComputeLibrary/archive/v23.05.tar.gz"), + sha256 = "c4ca329a78da380163b2d86e91ba728349b6f0ee97d66e260a694ef37f0b0d93", + strip_prefix = "ComputeLibrary-23.05.1", + urls = tf_mirror_urls("https://github.com/ARM-software/ComputeLibrary/archive/v23.05.1.tar.gz"), ) tf_http_archive( diff --git a/third_party/compute_library/acl_fixed_format_kernels_striding.patch b/third_party/compute_library/acl_fixed_format_kernels_striding.patch deleted file mode 100644 index 8e501a1d6d9c79..00000000000000 --- a/third_party/compute_library/acl_fixed_format_kernels_striding.patch +++ /dev/null @@ -1,70 +0,0 @@ - ******************************************************************************* - Copyright 2022 Arm Limited and affiliates. - SPDX-License-Identifier: Apache-2.0 - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ******************************************************************************* - -diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp -index 77da83070..985f96761 100644 ---- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp -+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp -@@ -495,48 +495,6 @@ void Fallback::run(ITensorPack &tensors) - { - ldb = b->info()->strides_in_bytes().y() / sizeof(TypeInput); - multi_stride_b = b->info()->strides_in_bytes().z() / sizeof(TypeInput); -- const arm_compute::WeightFormat wf = assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format); -- if(is_fixed_format(wf)) -- { -- // The 4D tensor of dimension O'HWI' created for the -- // OHWIoi format is in reality seen -- // as a 2D tensor at arm_gemm level, where the rows are -- // O'/ and the columns are * -- // H * W * I'. -- ITensorInfo *tensor_info = b->info(); -- const DataLayout data_layout = tensor_info->data_layout(); -- const TensorShape tensor_shape = tensor_info->tensor_shape(); -- const int tensor_height = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT)]; -- const int tensor_width = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH)]; -- int tensor_channels = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL)]; -- const int interleave_by = arm_compute::interleave_by(wf); -- const int blocked_by = arm_compute::block_by(wf); -- // We need to find a new stride that is distance from the data for one -- // set of output channels to the next -- if(ldb == tensor_channels && multi_stride_b == tensor_channels * tensor_width) -- { -- // In this case dimensions that are packed are height, width and channel -- // so we need to stride it by interleave_by -- if(tensor_channels % blocked_by != 0) -- { -- // We need to pad -- tensor_channels = arm_gemm::iceildiv(tensor_channels, blocked_by) * blocked_by; -- } -- ldb = interleave_by * tensor_height * tensor_width * tensor_channels; -- } -- else if(multi_stride_b == 0 || (ldb == tensor_width && multi_stride_b == tensor_height * tensor_width)) -- { -- // In this case dimension that is packed is only height -- // so we need to stride only height by interleave_by -- ldb = interleave_by * tensor_height; -- } -- else -- { -- // If dimensions are not packed as above error is thrown -- // as at the moment other forms of packing are not supported -- ARM_COMPUTE_ERROR("Unsupported packing for fixed format kernel"); -- } -- } - in1_ptr = reinterpret_cast(b->buffer() + b->info()->offset_first_element_in_bytes()); - } - diff --git a/third_party/compute_library/acl_openmp_fix.patch b/third_party/compute_library/acl_openmp_fix.patch deleted file mode 100644 index 512148c8eca114..00000000000000 --- a/third_party/compute_library/acl_openmp_fix.patch +++ /dev/null @@ -1,46 +0,0 @@ - ******************************************************************************* - Copyright 2022 Arm Limited and affiliates. - SPDX-License-Identifier: Apache-2.0 - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ******************************************************************************* - -diff --git a/src/runtime/OMP/OMPScheduler.cpp b/src/runtime/OMP/OMPScheduler.cpp -index aad24b4f0..78d1523af 100644 ---- a/src/runtime/OMP/OMPScheduler.cpp -+++ b/src/runtime/OMP/OMPScheduler.cpp -@@ -90,18 +116,21 @@ void OMPScheduler::schedule_op(ICPPKernel *kernel, const Hints &hints, const Win - void OMPScheduler::run_workloads(std::vector &workloads) - { - const unsigned int amount_of_work = static_cast(workloads.size()); -- if(amount_of_work < 1 || _num_threads == 1) -+ const unsigned int num_threads_to_use = std::min(_num_threads, amount_of_work ); -+ -+ if(amount_of_work < 1 || num_threads_to_use == 1) - { - return; - } - - ThreadInfo info; - info.cpu_info = &cpu_info(); -- info.num_threads = _num_threads; -- #pragma omp parallel for firstprivate(info) num_threads(_num_threads) default(shared) proc_bind(close) schedule(static, 1) -+ info.num_threads = num_threads_to_use; -+ #pragma omp parallel for firstprivate(info) num_threads(num_threads_to_use) default(shared) proc_bind(close) schedule(static, 1) - for(unsigned int wid = 0; wid < amount_of_work; ++wid) - { - const int tid = omp_get_thread_num(); -+ - info.thread_id = tid; - workloads[wid](info); - } diff --git a/third_party/compute_library/acl_reorder.patch b/third_party/compute_library/acl_reorder.patch deleted file mode 100644 index 7f4a7d9f4f8d68..00000000000000 --- a/third_party/compute_library/acl_reorder.patch +++ /dev/null @@ -1,42 +0,0 @@ - ******************************************************************************* - Copyright 2023 Arm Limited and affiliates. - SPDX-License-Identifier: Apache-2.0 - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ******************************************************************************* - -diff --git a/arm_compute/runtime/NEON/functions/NEReorderLayer.h b/arm_compute/runtime/NEON/functions/NEReorderLayer.h -index a9ce8e3e6..eb777f192 100644 ---- a/arm_compute/runtime/NEON/functions/NEReorderLayer.h -+++ b/arm_compute/runtime/NEON/functions/NEReorderLayer.h -@@ -49,7 +49,7 @@ public: - /** Prevent instances of this class from being moved (As this class contains non movable objects) */ - NEReorderLayer &operator=(NEReorderLayer &&) = delete; - /** Default destructor */ -- ~NEReorderLayer() = default; -+ ~NEReorderLayer(); - /** Set the input and output tensors. - * - * Valid data layouts: -diff --git a/src/runtime/NEON/functions/NEReorderLayer.cpp b/src/runtime/NEON/functions/NEReorderLayer.cpp -index 2ab1029f0..427bf8c50 100644 ---- a/src/runtime/NEON/functions/NEReorderLayer.cpp -+++ b/src/runtime/NEON/functions/NEReorderLayer.cpp -@@ -29,6 +29,7 @@ - - namespace arm_compute - { -+NEReorderLayer::~NEReorderLayer() = default; - - NEReorderLayer::NEReorderLayer() - : _reorder_kernel(std::make_unique()) diff --git a/third_party/compute_library/build_defs.bzl b/third_party/compute_library/build_defs.bzl index 5c5f8f6df5bd2a..74102fd3e6d051 100644 --- a/third_party/compute_library/build_defs.bzl +++ b/third_party/compute_library/build_defs.bzl @@ -15,6 +15,6 @@ def acl_deps(): inclusion in the deps attribute of rules. """ return select({ - "@org_tensorflow//third_party/compute_library:build_with_acl": ["@compute_library//:arm_compute_core"], + "@org_tensorflow//third_party/compute_library:build_with_acl": ["@compute_library//:arm_compute"], "//conditions:default": [], }) diff --git a/third_party/compute_library/compute_library.patch b/third_party/compute_library/compute_library.patch deleted file mode 100644 index a35bdbfb552a71..00000000000000 --- a/third_party/compute_library/compute_library.patch +++ /dev/null @@ -1,77 +0,0 @@ - ******************************************************************************* - Copyright 2023 Arm Limited and affiliates. - SPDX-License-Identifier: Apache-2.0 - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ******************************************************************************* -diff --git a/BUILD.bazel b/BUILD.bazel -index f897a1a6a..e27c5a99b 100644 ---- a/BUILD.bazel -+++ b/BUILD.bazel -@@ -138,9 +138,7 @@ cc_library( - "ENABLE_NEON", - "ARM_COMPUTE_CPU_ENABLED", - "ARM_COMPUTE_ENABLE_NEON", -- "ARM_COMPUTE_ENABLE_FP16", - "ARM_COMPUTE_ENABLE_I8MM", -- "ENABLE_FP16_KERNELS", - "ENABLE_FP32_KERNELS", - "ENABLE_QASYMM8_KERNELS", - "ENABLE_QASYMM8_SIGNED_KERNELS", -@@ -174,17 +172,6 @@ cc_library( - visibility = ["//visibility:public"], - ) - --#--------------------------------------------------------------------- --# Rule for creating file "arm_compute_version.embed" --genrule( -- name = "create_version_file", -- srcs = [".git/HEAD"], -- outs = ["arm_compute_version.embed"], -- cmd = "$(location //scripts:print_version_file) bazel-build-options `cat $(location :.git/HEAD)` > $@", -- tools = ["//scripts:print_version_file"], -- visibility = ["//visibility:public"], --) -- - #--------------------------------------------------------------------- - # Graph library - -@@ -192,7 +179,7 @@ cc_library( - name = "arm_compute_graph", - srcs = ["//src:arm_compute_graph_srcs"], - copts = [ -- "-march=armv8.2-a+fp16", -+ "-march=armv8-a", - ] + select({ - "//:debug_flag": [ - "-O0", -@@ -330,10 +317,10 @@ cc_library( - "core/NEON/kernels/**/*.hpp", - "**/*.inl", - ]) + [ -- "//:create_version_file", -+ "arm_compute_version.embed" - ], - copts = [ -- "-march=armv8.2-a+fp16", -+ "-march=armv8-a", - ] + select({ - "//:debug_flag": [ - "-O0", -diff --git a/arm_compute_version.embed b/arm_compute_version.embed -new file mode 100644 -index 000000000..3b3c7d838 ---- /dev/null -+++ b/arm_compute_version.embed -@@ -0,0 +1 @@ -+"arm_compute_version=v23.05 Build options: {} Git hash=b'N/A'" diff --git a/third_party/mkl_dnn/mkldnn_acl.BUILD b/third_party/mkl_dnn/mkldnn_acl.BUILD index cfbd515d7815c2..a1085427ec08da 100644 --- a/third_party/mkl_dnn/mkldnn_acl.BUILD +++ b/third_party/mkl_dnn/mkldnn_acl.BUILD @@ -173,6 +173,6 @@ cc_library( ], visibility = ["//visibility:public"], deps = [ - "@compute_library//:arm_compute_core", + "@compute_library//:arm_compute", ], ) From 804d8d4fb7cbcbe5820e470c6880f67ce9406e3b Mon Sep 17 00:00:00 2001 From: Fergus Henderson Date: Mon, 10 Jul 2023 07:09:20 -0700 Subject: [PATCH 045/376] Handle find_builtin_op_v3 / find_custom_op_v3 in InterpreterCreateWithOpResolver. PiperOrigin-RevId: 546862127 --- tensorflow/lite/core/c/c_api.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tensorflow/lite/core/c/c_api.cc b/tensorflow/lite/core/c/c_api.cc index 6cef29fc8259f5..7880445c3fccdc 100644 --- a/tensorflow/lite/core/c/c_api.cc +++ b/tensorflow/lite/core/c/c_api.cc @@ -441,6 +441,8 @@ TfLiteInterpreter* InterpreterCreateWithOpResolver( optional_options->op_resolver_callbacks.find_custom_op_v1 != nullptr || optional_options->op_resolver_callbacks.find_builtin_op_v2 != nullptr || optional_options->op_resolver_callbacks.find_custom_op_v2 != nullptr || + optional_options->op_resolver_callbacks.find_builtin_op_v3 != nullptr || + optional_options->op_resolver_callbacks.find_custom_op_v3 != nullptr || optional_options->op_resolver_callbacks.find_builtin_op_external != nullptr || optional_options->op_resolver_callbacks.find_custom_op_external != From f7bb24938d515318e9cd533ae519943c56dfa8bd Mon Sep 17 00:00:00 2001 From: Christian Sigg Date: Mon, 10 Jul 2023 07:12:19 -0700 Subject: [PATCH 046/376] Combine Triton and TritonGPU dialects into a single target. PiperOrigin-RevId: 546862747 --- tensorflow/compiler/xla/service/gpu/BUILD | 2 +- third_party/triton/cl545371535.patch | 29 ---- third_party/triton/cl545644269.patch | 180 ++++++++++++++++++++++ third_party/triton/workspace.bzl | 2 +- 4 files changed, 182 insertions(+), 31 deletions(-) delete mode 100644 third_party/triton/cl545371535.patch create mode 100644 third_party/triton/cl545644269.patch diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 1e79ceed0e64cb..619cccc3c88c4f 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -463,7 +463,7 @@ cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:ToLLVMIRTranslation", "@llvm-project//mlir:Transforms", - "@triton//:TritonDialect", + "@triton//:TritonDialects", "@triton//:TritonTransforms", ] + if_cuda_is_configured([ "@triton//:TritonGPUToLLVM", diff --git a/third_party/triton/cl545371535.patch b/third_party/triton/cl545371535.patch deleted file mode 100644 index f010d586bb02da..00000000000000 --- a/third_party/triton/cl545371535.patch +++ /dev/null @@ -1,29 +0,0 @@ -diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp -index cd8d1b82e..f0f1127d5 100644 ---- a/lib/Dialect/Triton/IR/Ops.cpp -+++ b/lib/Dialect/Triton/IR/Ops.cpp -@@ -6,7 +6,7 @@ - #include "mlir/IR/OperationSupport.h" - #include "triton/Dialect/Triton/IR/Dialect.h" - #include "triton/Dialect/Triton/IR/Types.h" --#include "triton/Dialect/TritonGPU/IR/Attributes.h" -+//#include "triton/Dialect/TritonGPU/IR/Attributes.h" - - namespace mlir { - namespace triton { -@@ -404,6 +404,7 @@ LogicalResult mlir::triton::DotOp::verify() { - auto bTy = getOperand(1).getType().cast(); - if (aTy.getElementType() != bTy.getElementType()) - return emitError("element types of operands A and B must match"); -+#if 0 // TODO(csigg): avoid cyclic BUILD dependency. - auto aEncoding = - aTy.getEncoding().dyn_cast_or_null(); - auto bEncoding = -@@ -415,6 +416,7 @@ LogicalResult mlir::triton::DotOp::verify() { - return emitError("mismatching encoding between A and B operands"); - if (aEncoding.getMMAv2kWidth() != bEncoding.getMMAv2kWidth()) - return emitError("mismatching kWidth between A and B operands"); -+#endif - return mlir::success(); - } - diff --git a/third_party/triton/cl545644269.patch b/third_party/triton/cl545644269.patch new file mode 100644 index 00000000000000..9f453888c70c19 --- /dev/null +++ b/third_party/triton/cl545644269.patch @@ -0,0 +1,180 @@ +diff --git a/BUILD b/BUILD +index a5a813485..c7f8aa5a6 100644 +--- a/BUILD ++++ b/BUILD +@@ -275,8 +275,7 @@ cc_library( + copts = _no_unused_variable, + includes = ["include"], + deps = [ +- ":TritonDialect", +- ":TritonGPUDialect", ++ ":TritonDialects", + ":TritonTools", + ":triton_gpu_attr_inc_gen", + "@llvm-project//llvm:Support", +@@ -291,44 +290,53 @@ cc_library( + ) + + cc_library( +- name = "TritonDialect", +- srcs = glob(["lib/Dialect/Triton/IR/*.cpp"]), +- hdrs = glob(["include/triton/Dialect/Triton/IR/*.h"]), ++ name = "TritonDialects", ++ srcs = glob([ ++ "lib/Dialect/Triton/IR/*.cpp", ++ "lib/Dialect/TritonGPU/IR/*.cpp", ++ ]) + [ ++ "include/triton/Analysis/Utility.h", # Avoid circular dependency. ++ ], ++ hdrs = glob([ ++ "include/triton/Dialect/Triton/IR/*.h", ++ "include/triton/Dialect/TritonGPU/IR/*.h", ++ ]), + copts = _no_unused_variable, + includes = ["include"], + deps = [ +- ":TritonGPUAttributes", + ":triton_dialect_inc_gen", ++ ":triton_gpu_attr_inc_gen", ++ ":triton_gpu_dialect_inc_gen", ++ ":triton_gpu_ops_inc_gen", ++ ":triton_gpu_transforms_inc_gen", + ":triton_interfaces_inc_gen", + ":triton_ops_inc_gen", + "@llvm-project//llvm:Support", ++ "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:ControlFlowInterfaces", ++ "@llvm-project//mlir:DestinationStyleOpInterface", + "@llvm-project//mlir:FuncDialect", ++ "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", ++ "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", ++ "@llvm-project//mlir:Transforms", + ], + ) + +-cc_library( +- name = "TritonGPUAttributes", +- hdrs = ["include/triton/Dialect/TritonGPU/IR/Attributes.h"], +- includes = ["include"], +- deps = ["triton_gpu_attr_inc_gen"], +-) +- + cc_library( + name = "TritonTransforms", + srcs = glob(["lib/Dialect/Triton/Transforms/*.cpp"]), + hdrs = glob(["include/triton/Dialect/Triton/Transforms/*.h"]), + includes = ["include"], + deps = [ +- ":TritonDialect", ++ ":TritonDialects", + ":triton_combine_inc_gen", + ":triton_transforms_inc_gen", + "@llvm-project//llvm:Support", +@@ -347,36 +355,6 @@ cc_library( + alwayslink = True, # TritonDialect uses getCanonicalizationPatterns(). + ) + +-cc_library( +- name = "TritonGPUDialect", +- srcs = glob(["lib/Dialect/TritonGPU/IR/*.cpp"]), +- hdrs = [ +- "include/triton/Analysis/Utility.h", # Avoid circular dependency. +- "include/triton/Dialect/TritonGPU/IR/Dialect.h", +- "include/triton/Dialect/TritonGPU/IR/Traits.h", +- ], +- copts = _no_unused_variable, +- includes = ["include"], +- deps = [ +- ":TritonDialect", +- ":TritonGPUAttributes", +- ":triton_gpu_attr_inc_gen", +- ":triton_gpu_dialect_inc_gen", +- ":triton_gpu_ops_inc_gen", +- ":triton_gpu_transforms_inc_gen", +- "@llvm-project//llvm:Support", +- "@llvm-project//mlir:Analysis", +- "@llvm-project//mlir:DestinationStyleOpInterface", +- "@llvm-project//mlir:GPUDialect", +- "@llvm-project//mlir:IR", +- "@llvm-project//mlir:LLVMDialect", +- "@llvm-project//mlir:Pass", +- "@llvm-project//mlir:Support", +- "@llvm-project//mlir:TensorDialect", +- "@llvm-project//mlir:Transforms", +- ], +-) +- + cc_library( + name = "TritonGPUTransforms", + srcs = glob([ +@@ -388,9 +366,7 @@ cc_library( + includes = ["include"], + deps = [ + ":TritonAnalysis", +- ":TritonDialect", +- ":TritonGPUAttributes", +- ":TritonGPUDialect", ++ ":TritonDialects", + ":triton_gpu_transforms_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", +@@ -428,8 +404,7 @@ cc_library( + ], + deps = [ + ":TritonAnalysis", +- ":TritonDialect", +- ":TritonGPUDialect", ++ ":TritonDialects", + ":triton_conversion_triton_gpu_to_llvm_passes_inc_gen", + ":triton_conversion_triton_to_triton_gpu_passes_inc_gen", + "@llvm-project//llvm:Support", +@@ -466,8 +441,8 @@ cc_library( + hdrs = glob(["include/triton/Conversion/TritonToTritonGPU/*.h"]), + includes = ["include"], + deps = [ +- ":TritonDialect", +- ":TritonGPUDialect", ++ ":TritonAnalysis", ++ ":TritonDialects", + ":TritonGPUTransforms", + ":triton_conversion_triton_gpu_to_llvm_passes_inc_gen", + ":triton_conversion_triton_to_triton_gpu_passes_inc_gen", +@@ -513,9 +488,7 @@ cc_library( + "@llvm-project//mlir:ROCDLToLLVMIRTranslation", + "@llvm-project//mlir:ToLLVMIRTranslation", + "@llvm-project//mlir:Transforms", +- # copybara:uncomment_begin +- # "//third_party/py/triton/google:find_cuda", +- # copybara:uncomment_end ++ # copybara:uncomment "//third_party/py/triton/google:find_cuda", + ], + ) + +@@ -579,8 +552,7 @@ cc_binary( + ], + includes = ["include"], + deps = [ +- ":TritonDialect", +- ":TritonGPUDialect", ++ ":TritonDialects", + ":TritonGPUToLLVM", + ":TritonGPUTransforms", + ":TritonToTritonGPU", +@@ -618,8 +590,7 @@ cc_binary( + ], + includes = ["include"], + deps = [ +- ":TritonDialect", +- ":TritonGPUDialect", ++ ":TritonDialects", + ":TritonGPUToLLVM", + ":TritonGPUTransforms", + ":TritonHSACO", diff --git a/third_party/triton/workspace.bzl b/third_party/triton/workspace.bzl index 0488d04b39c462..1ccf24e7d75fc7 100644 --- a/third_party/triton/workspace.bzl +++ b/third_party/triton/workspace.bzl @@ -16,6 +16,6 @@ def repo(): # For temporary changes which haven't landed upstream yet. patch_file = [ "//third_party/triton:cl536931041.patch", - "//third_party/triton:cl545371535.patch", + "//third_party/triton:cl545644269.patch", ], ) From 6ab40bcee13cac68319986bccfee70ddecbfa71c Mon Sep 17 00:00:00 2001 From: Cesar Magana De Leon Date: Mon, 10 Jul 2023 08:23:06 -0700 Subject: [PATCH 047/376] Testing saved_model_aot Wrapper PiperOrigin-RevId: 546878835 --- tensorflow/core/tfrt/saved_model/python/BUILD | 34 ++++++++------ .../python/saved_model_aot_compile.py | 45 ------------------- .../python/saved_model_aot_compile_test.py | 37 +++++++++++++++ .../python/saved_model_aot_compile_wrapper.cc | 8 +--- .../saved_model/saved_model_aot_compile.cc | 3 +- .../saved_model/saved_model_aot_compile.h | 4 +- 6 files changed, 64 insertions(+), 67 deletions(-) delete mode 100644 tensorflow/core/tfrt/saved_model/python/saved_model_aot_compile.py create mode 100644 tensorflow/core/tfrt/saved_model/python/saved_model_aot_compile_test.py diff --git a/tensorflow/core/tfrt/saved_model/python/BUILD b/tensorflow/core/tfrt/saved_model/python/BUILD index 92d1944a8cf1a7..2b4fb5668cc835 100644 --- a/tensorflow/core/tfrt/saved_model/python/BUILD +++ b/tensorflow/core/tfrt/saved_model/python/BUILD @@ -1,5 +1,5 @@ load("//tensorflow:tensorflow.default.bzl", "tf_python_pybind_extension") -load("//tensorflow:pytype.default.bzl", "pytype_strict_binary") +load("//tensorflow:pytype.default.bzl", "pytype_strict_binary", "pytype_strict_contrib_test") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -29,19 +29,6 @@ pytype_strict_binary( ], ) -py_binary( - name = "saved_model_aot_compile_py", - srcs = ["saved_model_aot_compile.py"], - main = "saved_model_aot_compile.py", - python_version = "PY3", - srcs_version = "PY3", - deps = [ - ":_pywrap_saved_model_aot_compile", - "//tensorflow/core/tfrt/graph_executor/python:_pywrap_graph_execution_options", - "@absl_py//absl:app", - ], -) - tf_python_pybind_extension( name = "_pywrap_saved_model_aot_compile", srcs = ["saved_model_aot_compile_wrapper.cc"], @@ -89,3 +76,22 @@ tf_python_pybind_extension( "@pybind11_abseil//pybind11_abseil:status_casters", ], ) + +pytype_strict_contrib_test( + name = "saved_model_aot_compile_test", + size = "small", + srcs = [ + "saved_model_aot_compile_test.py", + ], + data = [ + "//learning/brain/tfrt/cpp_tests/gpu_inference:testdata", + ], + python_version = "PY3", + deps = [ + ":_pywrap_saved_model_aot_compile", + "//base/python:pywrapbase", + "//tensorflow/python/platform:client_testlib", + "//testing/pybase", + "//third_party/py/lingvo:compat", + ], +) diff --git a/tensorflow/core/tfrt/saved_model/python/saved_model_aot_compile.py b/tensorflow/core/tfrt/saved_model/python/saved_model_aot_compile.py deleted file mode 100644 index da11aa9b22aa68..00000000000000 --- a/tensorflow/core/tfrt/saved_model/python/saved_model_aot_compile.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -"""Test .py file for pybind11 files for AotOptions and AotCompileSavedModel, currently unable to test due to nullptr in AotOptions.""" - - -from absl import app -from tensorflow.core.tfrt.graph_executor.python import _pywrap_graph_execution_options -from tensorflow.core.tfrt.saved_model.python import _pywrap_saved_model_aot_compile - - -def main(unused_argv): - if not _pywrap_saved_model_aot_compile: - return - try: - # Test for creating an instance of GraphExecutionOptions - test = _pywrap_graph_execution_options.GraphExecutionOptions() - print(test) - - # Executes AoTOptions and AotCompileSavedModel for Wrapping Tests - _pywrap_saved_model_aot_compile.AotOptions() - - # TODO(cesarmagana): Once AotCompileSavedModel is complete - # update this test script to read from CNS - _pywrap_saved_model_aot_compile.AotCompileSavedModel("random") - - # Could also do except status.StatusNotOk if testing for AotCompileSavedModel - except Exception as exception: # pylint: disable=broad-exception-caught - print(exception) - - -if __name__ == "__main__": - app.run(main) diff --git a/tensorflow/core/tfrt/saved_model/python/saved_model_aot_compile_test.py b/tensorflow/core/tfrt/saved_model/python/saved_model_aot_compile_test.py new file mode 100644 index 00000000000000..2e849fb743855c --- /dev/null +++ b/tensorflow/core/tfrt/saved_model/python/saved_model_aot_compile_test.py @@ -0,0 +1,37 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import os + +import lingvo.compat as tf + +from tensorflow.core.tfrt.saved_model.python import _pywrap_saved_model_aot_compile +from tensorflow.python.platform import test + + +class SavedModelAotCompileTest(test.TestCase): + + def testVerify_saved_model(self): + outputpath = os.getenv("TEST_UNDECLARED_OUTPUTS_DIR") + filepath = "learning/brain/tfrt/cpp_tests/gpu_inference/test_data/translate_converted_placed/" + _pywrap_saved_model_aot_compile.AotCompileSavedModel( + filepath, _pywrap_saved_model_aot_compile.AotOptions(), outputpath + ) + + # Verifies that .pbtxt is created correctly in the output directory + self.assertTrue(tf.io.gfile.exists(outputpath + "/aot_saved_model.pbtxt")) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/core/tfrt/saved_model/python/saved_model_aot_compile_wrapper.cc b/tensorflow/core/tfrt/saved_model/python/saved_model_aot_compile_wrapper.cc index b8b0f6985007d4..7c1e31fab55e7e 100644 --- a/tensorflow/core/tfrt/saved_model/python/saved_model_aot_compile_wrapper.cc +++ b/tensorflow/core/tfrt/saved_model/python/saved_model_aot_compile_wrapper.cc @@ -25,12 +25,8 @@ namespace py = pybind11; PYBIND11_MODULE(_pywrap_saved_model_aot_compile, m) { py::google::ImportStatusModule(); - py::class_(m, "AotOptions", - py::dynamic_attr()) - .def(py::init<>()) - .def_readwrite( - "graph_execution_options", - &tensorflow::tfrt_stub::AotOptions::graph_execution_options); + py::class_(m, "AotOptions") + .def(py::init<>()); m.doc() = "pybind11 AotOptions Python - C++ Wrapper"; m.def("AotCompileSavedModel", &tensorflow::tfrt_stub::AotCompileSavedModel, diff --git a/tensorflow/core/tfrt/saved_model/saved_model_aot_compile.cc b/tensorflow/core/tfrt/saved_model/saved_model_aot_compile.cc index 2711a631822539..aa5fcd51e98195 100644 --- a/tensorflow/core/tfrt/saved_model/saved_model_aot_compile.cc +++ b/tensorflow/core/tfrt/saved_model/saved_model_aot_compile.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/tfrt/saved_model/saved_model_aot_compile.h" +#include #include #include "absl/status/status.h" @@ -33,7 +34,7 @@ limitations under the License. namespace tensorflow::tfrt_stub { -AotOptions::AotOptions() : graph_execution_options(GetGlobalRuntime()) {} +AotOptions::AotOptions() : graph_execution_options(nullptr) {} Status AotCompileSavedModel(absl::string_view input_model_dir, const AotOptions& aot_options, diff --git a/tensorflow/core/tfrt/saved_model/saved_model_aot_compile.h b/tensorflow/core/tfrt/saved_model/saved_model_aot_compile.h index a3cffc385b8d7e..5547d3506e0ace 100644 --- a/tensorflow/core/tfrt/saved_model/saved_model_aot_compile.h +++ b/tensorflow/core/tfrt/saved_model/saved_model_aot_compile.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_TFRT_SAVED_MODEL_SAVED_MODEL_AOT_COMPILE_H_ #define TENSORFLOW_CORE_TFRT_SAVED_MODEL_SAVED_MODEL_AOT_COMPILE_H_ +#include #include #include "tensorflow/compiler/xla/service/compiler.h" @@ -23,8 +24,9 @@ limitations under the License. namespace tensorflow::tfrt_stub { struct AotOptions { - GraphExecutionOptions graph_execution_options; AotOptions(); + + std::unique_ptr graph_execution_options; }; // AOT Compiles saved_model in input_model_dir, writing output From 604c024c2c70352947f839d7e089a7dd91ca4f37 Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Mon, 10 Jul 2023 08:30:57 -0700 Subject: [PATCH 048/376] [XLA:GPU] Fix propagation of HLO module debug options in IrEmitterUnnested. PiperOrigin-RevId: 546880494 --- tensorflow/compiler/xla/service/gpu/BUILD | 2 ++ .../xla/service/gpu/ir_emitter_triton_test.cc | 32 +++++++++++++++++++ .../xla/service/gpu/ir_emitter_unnested.cc | 1 + 3 files changed, 35 insertions(+) diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 619cccc3c88c4f..a7ed398aace948 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -498,6 +498,8 @@ xla_test( "//tensorflow/compiler/xla/stream_executor/cuda:cublas_plugin", "//tensorflow/compiler/xla/tests:verified_hlo_module", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep + "//tensorflow/tsl/lib/core:status_test_util", + "//tensorflow/tsl/platform:path", "//tensorflow/tsl/platform:status_matchers", "//tensorflow/tsl/platform:statusor", "//tensorflow/tsl/platform:tensor_float_32_hdr_lib", diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_triton_test.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_triton_test.cc index 3b38a412b8b27b..fc4bb7204c1632 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_triton_test.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_triton_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include "llvm/IR/LLVMContext.h" #include "mlir/IR/MLIRContext.h" // from @llvm-project @@ -31,6 +32,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" #include "tensorflow/compiler/xla/stream_executor/device_description.h" #include "tensorflow/compiler/xla/tests/verified_hlo_module.h" +#include "tensorflow/tsl/lib/core/status_test_util.h" +#include "tensorflow/tsl/platform/path.h" #include "tensorflow/tsl/platform/status_matchers.h" #include "tensorflow/tsl/platform/statusor.h" #include "tensorflow/tsl/platform/tensor_float_32_utils.h" @@ -92,6 +95,35 @@ class TritonGemmTest : public GpuCodegenTest { } }; +TEST_F(TritonGemmTest, DebugOptionsArePropagated) { + const std::string kHloText = R"( +ENTRY e { + p0 = f16[30,30] parameter(0) + p1 = s8[30,30] parameter(1) + cp1 = f16[30,30] convert(p1) + ROOT _ = f16[30,30] dot(p0, cp1), + lhs_contracting_dims={0}, rhs_contracting_dims={1} +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr verified_module, + ParseAndReturnVerifiedModule(kHloText)); + std::string output_directory; + if (!tsl::io::GetTestUndeclaredOutputsDir(&output_directory)) { + output_directory = tsl::testing::TmpDir(); + } + DebugOptions debug_options = verified_module->config().debug_options(); + debug_options.set_xla_dump_to(output_directory); + debug_options.set_xla_gpu_dump_llvmir(true); + verified_module->config().set_debug_options(debug_options); + + EXPECT_TRUE(RunAndCompare(std::move(verified_module), + ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); + + std::vector paths; + TF_EXPECT_OK(tsl::Env::Default()->GetMatchingPaths( + tsl::io::JoinPath(output_directory, "*.triton-passes.log"), &paths)); + EXPECT_EQ(paths.size(), 1); +} + TEST_F(TritonGemmTest, UseTensorCoresForF32OnAmpere) { const std::string kHloText = R"( HloModule t, is_scheduled=true diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 47d706fbbd7fb5..a399816dc67739 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -2725,6 +2725,7 @@ IrEmitterUnnested::GetOrCreateSubComputationFromRegion(mlir::Region* region, TF_ASSIGN_OR_RETURN( module, HloModule::CreateFromProto(xla_computation.proto(), HloModuleConfig(program_shape))); + module->config().set_debug_options(hlo_module_config_.debug_options()); if (is_fusion) { HloComputation* fused_computation = module->entry_computation(); From c6b558b47190329d8ab9834a0102759ce52e8703 Mon Sep 17 00:00:00 2001 From: Bixia Zheng Date: Mon, 10 Jul 2023 09:33:22 -0700 Subject: [PATCH 049/376] [xla] Preserve meta data. Copy the meta data from the CollectivePermute operation to the Send and Recv operations. Modify a test. PiperOrigin-RevId: 546896375 --- .../xla/service/collective_permute_decomposer.cc | 3 +++ .../xla/service/collective_permute_decomposer_test.cc | 11 ++++++++++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/xla/service/collective_permute_decomposer.cc b/tensorflow/compiler/xla/service/collective_permute_decomposer.cc index 0359dc6d8bfd68..d17a2af5c1576a 100644 --- a/tensorflow/compiler/xla/service/collective_permute_decomposer.cc +++ b/tensorflow/compiler/xla/service/collective_permute_decomposer.cc @@ -98,6 +98,7 @@ Status DecomposeCollectivePermute( int64_t channel_id = collective_permute->channel_id().value_or(0); HloInstruction* data = collective_permute->mutable_operand(0); const Shape& data_shape = data->shape(); + const OpMetadata& metadata = collective_permute->metadata(); xla::FrontendAttributes attributes; std::string source_target_pairs_string = @@ -121,10 +122,12 @@ Status DecomposeCollectivePermute( HloInstruction* recv = computation->AddInstruction( HloInstruction::CreateRecv(data_shape, after_all, channel_id)); recv->set_frontend_attributes(attributes); + recv->set_metadata(metadata); HloInstruction* send = computation->AddInstruction( HloInstruction::CreateSend(data, after_all, channel_id)); send->set_frontend_attributes(attributes); + send->set_metadata(metadata); // We want the Recv to be scheduled before the Send, enforce this with a // control dependency. TF_RETURN_IF_ERROR(recv->AddControlDependencyTo(send)); diff --git a/tensorflow/compiler/xla/service/collective_permute_decomposer_test.cc b/tensorflow/compiler/xla/service/collective_permute_decomposer_test.cc index 20d01ed2e00f49..0aec6e6ceba85b 100644 --- a/tensorflow/compiler/xla/service/collective_permute_decomposer_test.cc +++ b/tensorflow/compiler/xla/service/collective_permute_decomposer_test.cc @@ -94,7 +94,8 @@ TEST_F(CollectivePermuteDecomposerTest, TransformedDefaultChannelId) { ENTRY test_computation { p = u32[] replica-id() start = (u32[], u32[]) collective-permute-start(p), - source_target_pairs={{0,1}, {1,2}, {2,3}, {3,4}} + source_target_pairs={{0,1}, {1,2}, {2,3}, {3,4}}, + metadata={op_name="op1/op2/add" source_file="foo/bar/mysource.py" source_line=35} ROOT done = u32[] collective-permute-done(start) } )"; @@ -105,6 +106,12 @@ TEST_F(CollectivePermuteDecomposerTest, TransformedDefaultChannelId) { TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get())); EXPECT_TRUE(changed); + auto check_metadata = [](const HloInstruction* inst) { + EXPECT_EQ(inst->metadata().op_name(), "op1/op2/add"); + EXPECT_EQ(inst->metadata().source_file(), "foo/bar/mysource.py"); + EXPECT_EQ(inst->metadata().source_line(), 35); + }; + HloInstruction* after_all = FindInstruction(module.get(), "after-all"); HloInstruction* recv = FindInstruction(module.get(), "recv"); EXPECT_EQ(recv->operand(0), after_all); @@ -113,6 +120,7 @@ TEST_F(CollectivePermuteDecomposerTest, TransformedDefaultChannelId) { recv->ToString(), HasSubstr( "_xla_send_recv_source_target_pairs=\"{{0,1},{1,2},{2,3},{3,4}}\"")); + check_metadata(recv); HloInstruction* recv_done = FindInstruction(module.get(), "recv-done"); EXPECT_EQ(recv_done->operand(0), recv); @@ -124,6 +132,7 @@ TEST_F(CollectivePermuteDecomposerTest, TransformedDefaultChannelId) { send->ToString(), HasSubstr( "_xla_send_recv_source_target_pairs=\"{{0,1},{1,2},{2,3},{3,4}}\"")); + check_metadata(send); HloInstruction* send_done = FindInstruction(module.get(), "send-done"); EXPECT_EQ(send_done->operand(0), send); EXPECT_EQ(send_done->control_predecessors()[0], recv_done); From a369961f8120bdb240a4bb9c70ed98bb98621b45 Mon Sep 17 00:00:00 2001 From: Junwhan Ahn Date: Mon, 10 Jul 2023 09:36:07 -0700 Subject: [PATCH 050/376] Add a proto version of `xla::ExecuteOptions` `xla::CompileOptions` already has a proto version of it, but `xla::ExecuteOptions` does not. This CL adds the same `ToProto()`/`FromProto()` utility function for `xla::ExecuteOptions`. It is worth noting that not all execute option fields are serializable (e.g., opaque `context`). For now, `ExecuteOptions::ToProto` returns an error if non-serializable options are set. In future, we can consider redesigning them to be serializable if needed. PiperOrigin-RevId: 546897094 --- tensorflow/compiler/xla/pjrt/BUILD | 9 ++ .../compiler/xla/pjrt/execute_options.proto | 20 +++ tensorflow/compiler/xla/pjrt/pjrt_client.cc | 1 - tensorflow/compiler/xla/pjrt/pjrt_client.h | 106 ---------------- .../compiler/xla/pjrt/pjrt_executable.cc | 73 +++++++++++ .../compiler/xla/pjrt/pjrt_executable.h | 117 ++++++++++++++++++ .../compiler/xla/pjrt/pjrt_executable_test.cc | 46 ++++++- 7 files changed, 261 insertions(+), 111 deletions(-) create mode 100644 tensorflow/compiler/xla/pjrt/execute_options.proto diff --git a/tensorflow/compiler/xla/pjrt/BUILD b/tensorflow/compiler/xla/pjrt/BUILD index d352faee3674c3..46fbf312d3aa1d 100644 --- a/tensorflow/compiler/xla/pjrt/BUILD +++ b/tensorflow/compiler/xla/pjrt/BUILD @@ -222,6 +222,7 @@ cc_library( hdrs = ["pjrt_executable.h"], visibility = [":friends"], deps = [ + ":execute_options_proto_cc", ":pjrt_common", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", @@ -231,6 +232,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_cost_analysis", "//tensorflow/compiler/xla/service:hlo_proto_cc", "//tensorflow/tsl/platform:statusor", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", ], ) @@ -244,6 +246,7 @@ xla_cc_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/client:executable_build_options", + "//tensorflow/tsl/platform:status_matchers", "@com_google_googletest//:gtest_main", ], ) @@ -795,3 +798,9 @@ tf_proto_library( # deps = [":compile_options_proto"], # ) # copybara:uncomment_end + +tf_proto_library( + name = "execute_options_proto", + srcs = ["execute_options.proto"], + visibility = ["//visibility:public"], +) diff --git a/tensorflow/compiler/xla/pjrt/execute_options.proto b/tensorflow/compiler/xla/pjrt/execute_options.proto new file mode 100644 index 00000000000000..af9558200fb907 --- /dev/null +++ b/tensorflow/compiler/xla/pjrt/execute_options.proto @@ -0,0 +1,20 @@ +syntax = "proto3"; + +package xla; + +enum ExecutionModeProto { + EXECUTION_MODE_UNSPECIFIED = 0; + EXECUTION_MODE_DEFAULT = 1; + EXECUTION_MODE_SYNCHRONOUS = 2; + EXECUTION_MODE_ASYNCHRONOUS = 3; +} + +// Mirrors `xla::ExecuteOptions`. +message ExecuteOptionsProto { + bool arguments_are_tupled = 1; + bool untuple_result = 2; + int32 launch_id = 3; + bool strict_shape_checking = 4; + ExecutionModeProto execution_mode = 6; + repeated int32 non_donatable_input_indices = 7; +} diff --git a/tensorflow/compiler/xla/pjrt/pjrt_client.cc b/tensorflow/compiler/xla/pjrt/pjrt_client.cc index ba6e5856042207..ceff30f1825e43 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_client.cc +++ b/tensorflow/compiler/xla/pjrt/pjrt_client.cc @@ -60,7 +60,6 @@ PjRtFuture PjRtBuffer::CopyRawToHostFuture( return PjRtFuture(std::move(promise)); } -MultiSliceConfig::~MultiSliceConfig() = default; std::string CompiledMemoryStats::DebugString() const { return absl::Substitute( diff --git a/tensorflow/compiler/xla/pjrt/pjrt_client.h b/tensorflow/compiler/xla/pjrt/pjrt_client.h index a59502f33167fd..ee2715e9743753 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_client.h +++ b/tensorflow/compiler/xla/pjrt/pjrt_client.h @@ -1199,112 +1199,6 @@ class PjRtBuffer { virtual bool IsOnCpu() const = 0; }; -class ExecuteContext { - public: - virtual ~ExecuteContext() = default; -}; - -struct PjRtTransferMetadata { - // May be invalid if - // ExecuteOptions::use_major_to_minor_data_layout_for_callbacks is true for - // this execution. - Shape device_shape; -}; - -struct SendCallback { - int64_t channel_id; - // The callback for retrieving the send value. It will be invoked once for - // each invocation of the corresponding Send op in the HLO program (So it can - // be invoked multiple times if it is in a loop). Currently there is no - // guarantee that the callback here will be invoked in the same order as their - // corresponding HLO Send ops. The callback can also return errors to indicate - // the execution should fail. - // - // IMPORTANT: the implementation might NOT signal the error to the execution, - // and the execution will run to completion with UNDEFINED DATA returned by - // the callback. If there is any potential control flow that depends on the - // value of the returned data, an error return is unsafe. - // - // TODO(chky): Currently the callback invocation order may not be consistent - // with the HLO send op invocation order, due to limitations in some PjRt - // implementation. Consider making it strictly the same order as HLO program. - std::function - callback; -}; - -struct RecvCallback { - int64_t channel_id; - // The callback for feeding the recv value. It will be invoked once for each - // invocation of the corresponding Recv op in the HLO program (So it can be - // invoked multiple times if it is in a loop). Currently there is no - // guarantee that the callback here will be invoked in the same order as their - // corresponding HLO Recv ops. - std::function stream)> - callback; -}; - -struct ExecuteOptions { - // If true, the client must pass a single PjRtBuffer which contains all of - // the arguments as a single XLA tuple, otherwise each argument must be - // passed in its own PjRtBuffer. May only be true if the executable was - // compiled with parameter_is_tupled_arguments==true. - bool arguments_are_tupled = false; - // If true, the computation must return a tuple, which will be destructured - // into its elements. - bool untuple_result = false; - // If non-zero, identifies this execution as part of a potentially - // multi-device launch. This can be used to detect scheduling errors, e.g. if - // multi-host programs are launched in different orders on different hosts, - // the launch IDs may be used by the runtime to detect the mismatch. - int32_t launch_id = 0; - // If non-null, an opaque context passed to an execution that may be used to - // supply additional arguments to a derived class of PjRtExecutable. - const ExecuteContext* context = nullptr; - // If true, check that the PjRtBuffer argument shapes match the compiled - // shapes. Otherwise, any shape with the right size on device may be passed. - bool strict_shape_checking = true; - - // Set multi_slice_config when the computation spans multiple slices. The - // config should match what was used during compilation to generate this - // executable. - const MultiSliceConfig* multi_slice_config = nullptr; - - // The send/recv callbacks for PjRt execution. The first level span is for - // multi-device parallel execution, the second level vector contains the - // callbacks for all send/recv ops in the executable. These callbacks can be - // stateful and the user code is responsible for managing the states here. - // These callbacks must outlive the execution. - absl::Span> send_callbacks; - absl::Span> recv_callbacks; - - // If true, send callbacks are passed PjRtChunks in major-to-minor layout, and - // recv functions should pass major-to-minor chunks to - // CopyToDeviceStream::AddChunk. - // - // If false, send callbacks are passed PjRtChunks in the on-device layout - // specified in the PjRtTransferMetadata, and recv functions should similarly - // pass device-layout chunks to CopyToDeviceStream::AddChunk. - bool use_major_to_minor_data_layout_for_callbacks = false; - - // The `execution_mode` decides whether the execution will be invoked in the - // caller thread or launched to a separate thread. By default, the - // implementation may choose either strategy or use a heuristic to decide. - // Currently it is only applied to CPU implementations - enum class ExecutionMode { kDefault = 0, kSynchronous, kAsynchronous }; - ExecutionMode execution_mode = ExecutionMode::kDefault; - - // A set of indices denoting the input buffers that should not be donated. - // An input buffer may be non-donable, for example, if it is referenced more - // than once. Since such runtime information is not available at compile time, - // the compiler might mark the input as `may-alias`, which could lead PjRt to - // donate the input buffer when it should not. By defining this set of - // indices, a higher-level PjRt caller can instruct PjRtClient not to donate - // specific input buffers. - absl::flat_hash_set non_donatable_input_indices; -}; - // Represents a compiled computation that can be executed given handles to // device-allocated literals. If any input/output alias has been specified in // the computation, the parameter containing the input buffer will be donated diff --git a/tensorflow/compiler/xla/pjrt/pjrt_executable.cc b/tensorflow/compiler/xla/pjrt/pjrt_executable.cc index b4ac9732815130..f087e103003bc1 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_executable.cc +++ b/tensorflow/compiler/xla/pjrt/pjrt_executable.cc @@ -24,6 +24,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/executable_build_options.h" +#include "tensorflow/compiler/xla/pjrt/execute_options.pb.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/tsl/platform/statusor.h" @@ -108,6 +109,78 @@ StatusOr CompileOptions::FromProto( return output; } +MultiSliceConfig::~MultiSliceConfig() = default; + +absl::StatusOr ExecuteOptions::ToProto() const { + ExecuteOptionsProto proto; + + proto.set_arguments_are_tupled(arguments_are_tupled); + proto.set_untuple_result(untuple_result); + proto.set_launch_id(launch_id); + if (context != nullptr) { + return absl::UnimplementedError( + "ExecuteOptions with non-nullptr context is not serializable"); + } + proto.set_strict_shape_checking(strict_shape_checking); + + if (multi_slice_config != nullptr) { + return absl::UnimplementedError( + "ExecuteOptions with multi-slice config is not serializable"); + } + + if (!send_callbacks.empty() || !recv_callbacks.empty()) { + return absl::UnimplementedError( + "ExecuteOptions with send/recv calbacks is not serializable"); + } + + switch (execution_mode) { + case ExecutionMode::kDefault: + proto.set_execution_mode(EXECUTION_MODE_DEFAULT); + break; + case ExecutionMode::kSynchronous: + proto.set_execution_mode(EXECUTION_MODE_SYNCHRONOUS); + break; + case ExecutionMode::kAsynchronous: + proto.set_execution_mode(EXECUTION_MODE_ASYNCHRONOUS); + break; + } + + proto.mutable_non_donatable_input_indices()->Add( + non_donatable_input_indices.begin(), non_donatable_input_indices.end()); + + return proto; +} + +absl::StatusOr ExecuteOptions::FromProto( + const ExecuteOptionsProto& proto) { + ExecuteOptions options; + + options.arguments_are_tupled = proto.arguments_are_tupled(); + options.untuple_result = proto.untuple_result(); + options.launch_id = proto.launch_id(); + + switch (proto.execution_mode()) { + case EXECUTION_MODE_DEFAULT: + options.execution_mode = ExecutionMode::kDefault; + break; + case EXECUTION_MODE_SYNCHRONOUS: + options.execution_mode = ExecutionMode::kSynchronous; + break; + case EXECUTION_MODE_ASYNCHRONOUS: + options.execution_mode = ExecutionMode::kAsynchronous; + break; + default: + return absl::UnimplementedError( + absl::StrCat("Unknown execution mode: ", proto.execution_mode())); + } + + options.non_donatable_input_indices.insert( + proto.non_donatable_input_indices().begin(), + proto.non_donatable_input_indices().end()); + + return options; +} + void GetOpSharding(std::vector& out, const OpSharding& sharding) { if (sharding.type() == OpSharding::TUPLE) { for (const OpSharding& s : sharding.tuple_shardings()) { diff --git a/tensorflow/compiler/xla/pjrt/pjrt_executable.h b/tensorflow/compiler/xla/pjrt/pjrt_executable.h index 3dfbb6c038442e..96610182284b14 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_executable.h +++ b/tensorflow/compiler/xla/pjrt/pjrt_executable.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_PJRT_PJRT_EXECUTABLE_H_ #include +#include #include #include #include @@ -24,9 +25,11 @@ limitations under the License. #include #include +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_module.h" +#include "tensorflow/compiler/xla/pjrt/execute_options.pb.h" #include "tensorflow/compiler/xla/pjrt/pjrt_common.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" @@ -120,6 +123,120 @@ struct LoadOptions { const MultiSliceConfig* multi_slice_config = nullptr; }; +class ExecuteContext { + public: + virtual ~ExecuteContext() = default; +}; + +struct PjRtTransferMetadata { + // May be invalid if + // ExecuteOptions::use_major_to_minor_data_layout_for_callbacks is true for + // this execution. + Shape device_shape; +}; + +class PjRtChunk; +class PjRtTransferMetadata; +class CopyToDeviceStream; + +struct SendCallback { + int64_t channel_id; + // The callback for retrieving the send value. It will be invoked once for + // each invocation of the corresponding Send op in the HLO program (So it can + // be invoked multiple times if it is in a loop). Currently there is no + // guarantee that the callback here will be invoked in the same order as their + // corresponding HLO Send ops. The callback can also return errors to indicate + // the execution should fail. + // + // IMPORTANT: the implementation might NOT signal the error to the execution, + // and the execution will run to completion with UNDEFINED DATA returned by + // the callback. If there is any potential control flow that depends on the + // value of the returned data, an error return is unsafe. + // + // TODO(chky): Currently the callback invocation order may not be consistent + // with the HLO send op invocation order, due to limitations in some PjRt + // implementation. Consider making it strictly the same order as HLO program. + std::function + callback; +}; + +struct RecvCallback { + int64_t channel_id; + // The callback for feeding the recv value. It will be invoked once for each + // invocation of the corresponding Recv op in the HLO program (So it can be + // invoked multiple times if it is in a loop). Currently there is no + // guarantee that the callback here will be invoked in the same order as their + // corresponding HLO Recv ops. + std::function stream)> + callback; +}; + +struct ExecuteOptions { + // If true, the client must pass a single PjRtBuffer which contains all of + // the arguments as a single XLA tuple, otherwise each argument must be + // passed in its own PjRtBuffer. May only be true if the executable was + // compiled with parameter_is_tupled_arguments==true. + bool arguments_are_tupled = false; + // If true, the computation must return a tuple, which will be destructured + // into its elements. + bool untuple_result = false; + // If non-zero, identifies this execution as part of a potentially + // multi-device launch. This can be used to detect scheduling errors, e.g. if + // multi-host programs are launched in different orders on different hosts, + // the launch IDs may be used by the runtime to detect the mismatch. + int32_t launch_id = 0; + // If non-null, an opaque context passed to an execution that may be used to + // supply additional arguments to a derived class of PjRtExecutable. + const ExecuteContext* context = nullptr; + // If true, check that the PjRtBuffer argument shapes match the compiled + // shapes. Otherwise, any shape with the right size on device may be passed. + bool strict_shape_checking = true; + + // Set multi_slice_config when the computation spans multiple slices. The + // config should match what was used during compilation to generate this + // executable. + const MultiSliceConfig* multi_slice_config = nullptr; + + // The send/recv callbacks for PjRt execution. The first level span is for + // multi-device parallel execution, the second level vector contains the + // callbacks for all send/recv ops in the executable. These callbacks can be + // stateful and the user code is responsible for managing the states here. + // These callbacks must outlive the execution. + absl::Span> send_callbacks; + absl::Span> recv_callbacks; + + // If true, send callbacks are passed PjRtChunks in major-to-minor layout, and + // recv functions should pass major-to-minor chunks to + // CopyToDeviceStream::AddChunk. + // + // If false, send callbacks are passed PjRtChunks in the on-device layout + // specified in the PjRtTransferMetadata, and recv functions should similarly + // pass device-layout chunks to CopyToDeviceStream::AddChunk. + bool use_major_to_minor_data_layout_for_callbacks = false; + + // The `execution_mode` decides whether the execution will be invoked in the + // caller thread or launched to a separate thread. By default, the + // implementation may choose either strategy or use a heuristic to decide. + // Currently it is only applied to CPU implementations + enum class ExecutionMode { kDefault = 0, kSynchronous, kAsynchronous }; + ExecutionMode execution_mode = ExecutionMode::kDefault; + + // A set of indices denoting the input buffers that should not be donated. + // An input buffer may be non-donable, for example, if it is referenced more + // than once. Since such runtime information is not available at compile time, + // the compiler might mark the input as `may-alias`, which could lead PjRt to + // donate the input buffer when it should not. By defining this set of + // indices, a higher-level PjRt caller can instruct PjRtClient not to donate + // specific input buffers. + absl::flat_hash_set non_donatable_input_indices; + + absl::StatusOr ToProto() const; + static absl::StatusOr FromProto( + const ExecuteOptionsProto& proto); +}; + // Static device memory usage for a compiled program. // The on-device memory needed to run an executable is at least // generated_code_size_in_bytes diff --git a/tensorflow/compiler/xla/pjrt/pjrt_executable_test.cc b/tensorflow/compiler/xla/pjrt/pjrt_executable_test.cc index 747755a05bd529..a480172117d5f7 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_executable_test.cc +++ b/tensorflow/compiler/xla/pjrt/pjrt_executable_test.cc @@ -14,15 +14,21 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/pjrt/pjrt_executable.h" +#include + +#include #include #include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/pjrt/compile_options.pb.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/tsl/platform/status_matchers.h" namespace xla { namespace { +using ::tsl::testing::StatusIs; + TEST(CompileOptionsTest, Serialization) { CompileOptions src; src.compile_portable_executable = true; @@ -41,15 +47,47 @@ TEST(CompileOptionsTest, Serialization) { EXPECT_EQ(proto.SerializeAsString(), output_proto.SerializeAsString()); } -TEST(FromProtoTest, MultiSliceConfigNotSupported) { +TEST(CompileOptionsTest, MultiSliceConfigNotSupported) { CompileOptionsProto proto; *proto.mutable_serialized_multi_slice_config() = "multi_size_config"; auto option = CompileOptions::FromProto(proto); - EXPECT_EQ(option.status().code(), tensorflow::error::UNIMPLEMENTED); - EXPECT_EQ(option.status().message(), - "multi_slice_config not supported in CompileOptions::FromProto."); + EXPECT_THAT( + option.status(), + StatusIs( + absl::StatusCode::kUnimplemented, + "multi_slice_config not supported in CompileOptions::FromProto.")); +} + +TEST(ExecuteOptionsTest, Serialization) { + ExecuteOptions src; + src.arguments_are_tupled = true; + src.untuple_result = false; + src.launch_id = 1234; + src.strict_shape_checking = true; + src.execution_mode = ExecuteOptions::ExecutionMode::kAsynchronous; + src.non_donatable_input_indices = {2, 3}; + + TF_ASSERT_OK_AND_ASSIGN(ExecuteOptionsProto proto, src.ToProto()); + TF_ASSERT_OK_AND_ASSIGN(ExecuteOptions output, + ExecuteOptions::FromProto(proto)); + TF_ASSERT_OK_AND_ASSIGN(ExecuteOptionsProto output_proto, src.ToProto()); + + EXPECT_EQ(proto.SerializeAsString(), output_proto.SerializeAsString()); +} + +TEST(ExecuteOptionsTest, SendRecvNotSupported) { + ExecuteOptions options; + std::vector> send_callbacks(1); + options.send_callbacks = send_callbacks; + std::vector> recv_callbacks(1); + options.recv_callbacks = recv_callbacks; + + EXPECT_THAT( + options.ToProto(), + StatusIs(absl::StatusCode::kUnimplemented, + "ExecuteOptions with send/recv calbacks is not serializable")); } } // namespace From 9d4e5e2139cf77ca7dfac25665c664b76adc5816 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 10 Jul 2023 09:40:21 -0700 Subject: [PATCH 051/376] Implement `TpuOpExecutable::fingerprint()` functionality via existing C APIs PiperOrigin-RevId: 546898278 --- .../compiler/xla/stream_executor/tpu/tpu_op_executable.cc | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/tensorflow/compiler/xla/stream_executor/tpu/tpu_op_executable.cc b/tensorflow/compiler/xla/stream_executor/tpu/tpu_op_executable.cc index 429dc6a1ce3756..35b75ac9e95d1a 100644 --- a/tensorflow/compiler/xla/stream_executor/tpu/tpu_op_executable.cc +++ b/tensorflow/compiler/xla/stream_executor/tpu/tpu_op_executable.cc @@ -15,8 +15,6 @@ limitations under the License. #include "tensorflow/compiler/xla/stream_executor/tpu/tpu_op_executable.h" -#include -#include #include #include "tensorflow/compiler/xla/status.h" @@ -24,7 +22,6 @@ limitations under the License. #include "tensorflow/compiler/xla/stream_executor/tpu/proto_helper.h" #include "tensorflow/compiler/xla/stream_executor/tpu/status_helper.h" #include "tensorflow/compiler/xla/stream_executor/tpu/tpu_api.h" -#include "tensorflow/compiler/xla/stream_executor/tpu/tpu_ops_c_api.h" #include "tensorflow/compiler/xla/stream_executor/tpu/tpu_platform.h" #include "tensorflow/compiler/xla/stream_executor/tpu/tpu_platform_interface.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -111,8 +108,8 @@ xla::Status TpuOpExecutable::LoadProgramAndEnqueueToStream( } absl::string_view TpuOpExecutable::fingerprint() const { - TpuProgramFingerprint fingerprint = TpuProgram_GetFingerprint(core_program_); - return absl::string_view(fingerprint.bytes, fingerprint.size); + // TODO(skye): the fingerprint can be plumbed through via core_program_ + LOG(FATAL) << "TpuOpExecutable::fingerprint() unimplemented"; } } // namespace tensorflow From 38b006d10700ab58dc9da0f7870d6981313f4bc2 Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Mon, 10 Jul 2023 09:48:09 -0700 Subject: [PATCH 052/376] [XLA:GPU] Use the new runtime in autotuning. The old runtime is not regularly tested anymore and contains at least one known cuBLAS-related bug. PiperOrigin-RevId: 546900287 --- tensorflow/compiler/xla/service/gpu/autotuner_compile_util.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.cc b/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.cc index 7555b6fcf69e3f..e5ee60c7f73860 100644 --- a/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.cc +++ b/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.cc @@ -214,7 +214,6 @@ StatusOr> AutotunerCompileUtil::RunBackend( options.set_xla_gpu_dump_llvmir(false); // Avoid using another thread pool. options.set_xla_gpu_force_compilation_parallelism(1); - options.set_xla_gpu_enable_xla_runtime_executable(false); module->config().set_debug_options(options); StatusOr> out = compiler_->RunBackend(std::move(module), &stream_executor_, &allocator_); From 33f241d20d5b0a86d80e16098fc2716c2e9662ef Mon Sep 17 00:00:00 2001 From: Terry Heo Date: Mon, 10 Jul 2023 09:59:37 -0700 Subject: [PATCH 053/376] Disable distributed_save_ft_test in macos The test is flaky in MacOs. PiperOrigin-RevId: 546903479 --- tensorflow/python/data/experimental/kernel_tests/service/BUILD | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tensorflow/python/data/experimental/kernel_tests/service/BUILD b/tensorflow/python/data/experimental/kernel_tests/service/BUILD index 8628abfb0972d5..8d36d2e3637093 100644 --- a/tensorflow/python/data/experimental/kernel_tests/service/BUILD +++ b/tensorflow/python/data/experimental/kernel_tests/service/BUILD @@ -245,6 +245,9 @@ tf_py_strict_test( size = "medium", srcs = ["distributed_save_ft_test.py"], shard_count = 17, + tags = [ + "no_mac", # TODO(b/290355883): Fix the flakyness in macos + ], deps = [ ":test_base", "//tensorflow/python/data/experimental/ops:distributed_save_op", From b23e0dbda8c250e4ac0dc13114f0b9d8da06e357 Mon Sep 17 00:00:00 2001 From: Austin Anderson Date: Mon, 10 Jul 2023 10:07:16 -0700 Subject: [PATCH 054/376] Fix missed if statement --- ci/official/any.sh | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ci/official/any.sh b/ci/official/any.sh index 0548e4e9e1bccd..ea00c6b9c32717 100755 --- a/ci/official/any.sh +++ b/ci/official/any.sh @@ -23,7 +23,9 @@ fi if [[ "${PIP_WHEEL}" -eq "1" ]]; then # Update the version numbers to build a "nightly" package - [[ "$TFCI_NIGHTLY_UPDATE_VERSION_ENABLE" == 1 ]] && tfrun python3 tensorflow/tools/ci_build/update_version.py --nightly + if [[ "$TFCI_NIGHTLY_UPDATE_VERSION_ENABLE" == 1 ]]; then + tfrun python3 tensorflow/tools/ci_build/update_version.py --nightly + fi tfrun bazel "${TFCI_BAZEL_BAZELRC_ARGS[@]}" build "${TFCI_BAZEL_COMMON_ARGS[@]}" tensorflow/tools/pip_package:build_pip_package tfrun ./bazel-bin/tensorflow/tools/pip_package/build_pip_package build "${TFCI_BUILD_PIP_PACKAGE_ARGS[@]}" From 4ed5214abf93ff0324185aad5cd49c1db3cc838d Mon Sep 17 00:00:00 2001 From: Juan Martinez Castellanos Date: Mon, 10 Jul 2023 10:26:46 -0700 Subject: [PATCH 055/376] Make all Python targets under tensorflow/examples/custom_ops_doc/multiplex_1/ have strict dependencies. PiperOrigin-RevId: 546911576 --- tensorflow/examples/custom_ops_doc/multiplex_1/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/examples/custom_ops_doc/multiplex_1/README.md b/tensorflow/examples/custom_ops_doc/multiplex_1/README.md index e9800a5f4afd0f..15cac59ad9c28d 100644 --- a/tensorflow/examples/custom_ops_doc/multiplex_1/README.md +++ b/tensorflow/examples/custom_ops_doc/multiplex_1/README.md @@ -338,7 +338,7 @@ py_strict_library( ], ) -tf_py_test( +tf_py_strict_test( name = "multiplex_1_test", size = "small", srcs = ["multiplex_1_test.py"], @@ -399,5 +399,5 @@ Op components | Build rule | Build target Kernels (C++) | `tf_custom_op_library` | `multiplex_1_kernel` | `multiplex_1_kernel.cc`, `multiplex_1_op.cc` Wrapper (automatically generated) | N/A | `gen_multiplex_1_op` | N/A Wrapper (with public API and docstring) | `py_strict_library` | `multiplex_1_op` | `multiplex_1_op.py` -Tests | `tf_py_test` | `multiplex_1_test` | `multiplex_1_test.py` +Tests | `tf_py_strict_test` | `multiplex_1_test` | `multiplex_1_test.py` From d44691333d3f6ebbe92dbe328356ed17dae9e553 Mon Sep 17 00:00:00 2001 From: Juan Martinez Castellanos Date: Mon, 10 Jul 2023 10:26:50 -0700 Subject: [PATCH 056/376] Make all Python targets under tensorflow/compiler/xla/service/ have strict dependencies. PiperOrigin-RevId: 546911604 --- tensorflow/compiler/xla/service/BUILD | 7 ++++--- tensorflow/compiler/xla/strict.default.bzl | 13 +++++++++++++ 2 files changed, 17 insertions(+), 3 deletions(-) create mode 100644 tensorflow/compiler/xla/strict.default.bzl diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 695bb3a0e0327c..41ea31d95a685b 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -1,6 +1,7 @@ # Description: # XLA service implementation. +load("//tensorflow:strict.default.bzl", "py_strict_library", "py_strict_test") load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") load( @@ -6424,19 +6425,19 @@ xla_py_proto_library( deps = [":hlo_proto"], ) -py_library( +py_strict_library( name = "generate_test_hlo_checks", srcs = ["generate_test_hlo_checks.py"], srcs_version = "PY3", ) -py_test( +py_strict_test( name = "generate_test_hlo_checks_test", srcs = ["generate_test_hlo_checks_test.py"], python_version = "PY3", - # TODO(b/200806426): Test fails in OSS. tags = [ "no_oss", + "nopip", ], deps = [ ":generate_test_hlo_checks", diff --git a/tensorflow/compiler/xla/strict.default.bzl b/tensorflow/compiler/xla/strict.default.bzl new file mode 100644 index 00000000000000..2042d4a98d05fb --- /dev/null +++ b/tensorflow/compiler/xla/strict.default.bzl @@ -0,0 +1,13 @@ +"""Default (OSS) build versions of Python strict rules.""" + +# Placeholder to use until bazel supports py_strict_binary. +def py_strict_binary(name, **kwargs): + native.py_binary(name = name, **kwargs) + +# Placeholder to use until bazel supports py_strict_library. +def py_strict_library(name, **kwargs): + native.py_library(name = name, **kwargs) + +# Placeholder to use until bazel supports py_strict_test. +def py_strict_test(name, **kwargs): + native.py_test(name = name, **kwargs) From bc308afbc8485f528b42e24c4657f4d7690bc5c7 Mon Sep 17 00:00:00 2001 From: Fiona Lang Date: Mon, 10 Jul 2023 10:27:27 -0700 Subject: [PATCH 057/376] Remove legacy references to tensor.disable_tensor_equality and tensor.enable_tensor_equality. PiperOrigin-RevId: 546911771 --- tensorflow/python/compat/BUILD | 1 + tensorflow/python/compat/v2_compat.py | 5 +- tensorflow/python/eager/BUILD | 1 + tensorflow/python/eager/core_test.py | 49 ++++++++++--------- tensorflow/python/framework/ops.py | 2 - tensorflow/python/ops/numpy_ops/BUILD | 1 + .../python/ops/numpy_ops/np_math_ops_test.py | 3 +- tensorflow/python/ops/parallel_for/BUILD | 1 + .../python/ops/parallel_for/math_test.py | 7 +-- tensorflow/python/util/BUILD | 2 +- tensorflow/python/util/deprecation_test.py | 6 +-- 11 files changed, 42 insertions(+), 36 deletions(-) diff --git a/tensorflow/python/compat/BUILD b/tensorflow/python/compat/BUILD index 5d6100d4f01210..404927a33501fe 100644 --- a/tensorflow/python/compat/BUILD +++ b/tensorflow/python/compat/BUILD @@ -21,6 +21,7 @@ py_strict_library( "//tensorflow/python/data/ops:readers", "//tensorflow/python/eager:monitoring", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/ops:control_flow_v2_toggles", "//tensorflow/python/ops:variable_scope", diff --git a/tensorflow/python/compat/v2_compat.py b/tensorflow/python/compat/v2_compat.py index 179f64008961dc..481c3c69e34855 100644 --- a/tensorflow/python/compat/v2_compat.py +++ b/tensorflow/python/compat/v2_compat.py @@ -23,6 +23,7 @@ from tensorflow.python.data.ops import readers from tensorflow.python.eager import monitoring from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import control_flow_v2_toggles from tensorflow.python.ops import variable_scope @@ -60,7 +61,7 @@ def enable_v2_behavior(): ops.enable_eager_execution() tensor_shape.enable_v2_tensorshape() # Also switched by tf2 variable_scope.enable_resource_variables() - ops.enable_tensor_equality() + tensor.enable_tensor_equality() # Enables TensorArrayV2 and control flow V2. control_flow_v2_toggles.enable_control_flow_v2() # Make sure internal uses of tf.data symbols map to V2 versions. @@ -105,7 +106,7 @@ def disable_v2_behavior(): ops.disable_eager_execution() tensor_shape.disable_v2_tensorshape() # Also switched by tf2 variable_scope.disable_resource_variables() - ops.disable_tensor_equality() + tensor.disable_tensor_equality() # Disables TensorArrayV2 and control flow V2. control_flow_v2_toggles.disable_control_flow_v2() # Make sure internal uses of tf.data symbols map to V1 versions. diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index dc80c625633100..9ad4d8c0d58c9a 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -574,6 +574,7 @@ cuda_py_strict_test( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:errors", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:nn_ops", diff --git a/tensorflow/python/eager/core_test.py b/tensorflow/python/eager/core_test.py index 64a2a6ad06d6ae..6e919d6deab965 100644 --- a/tensorflow/python/eager/core_test.py +++ b/tensorflow/python/eager/core_test.py @@ -35,6 +35,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_resource_variable_ops @@ -91,7 +92,7 @@ def _test_hashable(self, a, b, hashable): set([a, b]) def testEquality(self): - default = ops.Tensor._USE_EQUALITY + default = tensor_lib.Tensor._USE_EQUALITY try: def _v1_check(a, b): @@ -113,20 +114,20 @@ def _v2_check(a, b): constant_a = constant_op.constant(1.0) constant_b = constant_op.constant(1.0) - ops.disable_tensor_equality() + tensor_lib.disable_tensor_equality() self._test_hashable(constant_a, constant_b, False) _v1_check(constant_a, constant_b) - ops.enable_tensor_equality() + tensor_lib.enable_tensor_equality() _v2_check(constant_a, constant_b) self._test_hashable(constant_a, constant_b, False) variable_a = variables.Variable(1.0) variable_b = variables.Variable(1.0) - ops.disable_tensor_equality() + tensor_lib.disable_tensor_equality() _v1_check(variable_a, variable_b) self._test_hashable(variable_a, variable_b, True) - ops.enable_tensor_equality() + tensor_lib.enable_tensor_equality() _v2_check(variable_a, variable_b) self._test_hashable(variable_a, variable_b, False) @@ -137,12 +138,12 @@ def _v2_check(a, b): self._test_hashable(numpy_a, numpy_b, False) finally: if default: - ops.enable_tensor_equality() + tensor_lib.enable_tensor_equality() else: - ops.disable_tensor_equality() + tensor_lib.disable_tensor_equality() def testEqualityNan(self): - default = ops.Tensor._USE_EQUALITY + default = tensor_lib.Tensor._USE_EQUALITY try: def _v1_check(a, b): @@ -164,20 +165,20 @@ def _v2_check(a, b): constant_a = constant_op.constant(float('nan')) constant_b = constant_op.constant(float('nan')) - ops.disable_tensor_equality() + tensor_lib.disable_tensor_equality() self._test_hashable(constant_a, constant_b, False) _v1_check(constant_a, constant_b) - ops.enable_tensor_equality() + tensor_lib.enable_tensor_equality() _v2_check(constant_a, constant_b) self._test_hashable(constant_a, constant_b, False) variable_a = variables.Variable(float('nan')) variable_b = variables.Variable(float('nan')) - ops.disable_tensor_equality() + tensor_lib.disable_tensor_equality() _v1_check(variable_a, variable_b) self._test_hashable(variable_a, variable_b, True) - ops.enable_tensor_equality() + tensor_lib.enable_tensor_equality() _v2_check(variable_a, variable_b) self._test_hashable(variable_a, variable_b, False) @@ -187,12 +188,12 @@ def _v2_check(a, b): self._test_hashable(numpy_a, numpy_b, False) finally: if default: - ops.enable_tensor_equality() + tensor_lib.enable_tensor_equality() else: - ops.disable_tensor_equality() + tensor_lib.disable_tensor_equality() def testEqualityCompare(self): - default = ops.Tensor._USE_EQUALITY + default = tensor_lib.Tensor._USE_EQUALITY try: tf_a = constant_op.constant([1, 2]) @@ -202,7 +203,7 @@ def testEqualityCompare(self): np_b = np.array([1, 2]) np_c = np.array([1, 1]) - ops.disable_tensor_equality() + tensor_lib.disable_tensor_equality() # We don't do element-wise comparison self.assertNotEqual(tf_a, tf_b) self.assertNotEqual(tf_a, tf_c) @@ -216,7 +217,7 @@ def testEqualityCompare(self): self.assertIn(tf_a, [tf_b, tf_a]) self.assertNotIn(tf_a, [tf_b, tf_c]) - ops.enable_tensor_equality() + tensor_lib.enable_tensor_equality() # We do element-wise comparison but can't convert results array to bool with self.assertRaises(ValueError): bool(tf_a == tf_b) @@ -266,12 +267,12 @@ def testEqualityCompare(self): self.assertAllEqual(np.array(1) == np.array(2), False) finally: if default: - ops.enable_tensor_equality() + tensor_lib.enable_tensor_equality() else: - ops.disable_tensor_equality() + tensor_lib.disable_tensor_equality() def testEqualityBroadcast(self): - default = ops.Tensor._USE_EQUALITY + default = tensor_lib.Tensor._USE_EQUALITY try: tf_a = constant_op.constant([1, 1]) @@ -285,13 +286,13 @@ def testEqualityBroadcast(self): np_d = np.array([[1, 2], [1, 2]]) np_e = np.array([1, 1, 1]) - ops.disable_tensor_equality() + tensor_lib.disable_tensor_equality() # We don't do element-wise comparison self.assertNotEqual(tf_a, tf_b) self.assertNotEqual(tf_a, tf_c) self.assertNotEqual(tf_a, tf_d) - ops.enable_tensor_equality() + tensor_lib.enable_tensor_equality() # We do element-wise comparison but can't convert results array to bool with self.assertRaises(ValueError): bool(tf_a == tf_b) @@ -322,9 +323,9 @@ def testEqualityBroadcast(self): self.assertNotAllEqual(np_a, np_e) finally: if default: - ops.enable_tensor_equality() + tensor_lib.enable_tensor_equality() else: - ops.disable_tensor_equality() + tensor_lib.disable_tensor_equality() @test_util.disable_tfrt('Get execution mode not supported in TFRT.') def testContext(self): diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index de09610cce1ded..5f23e0654496f3 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -234,8 +234,6 @@ def value_text(tensor, is_repr=False): return text -enable_tensor_equality = tensor_lib.enable_tensor_equality -disable_tensor_equality = tensor_lib.disable_tensor_equality Tensor = tensor_lib.Tensor diff --git a/tensorflow/python/ops/numpy_ops/BUILD b/tensorflow/python/ops/numpy_ops/BUILD index fc6e6d7b420141..73b52041d31d9a 100644 --- a/tensorflow/python/ops/numpy_ops/BUILD +++ b/tensorflow/python/ops/numpy_ops/BUILD @@ -231,6 +231,7 @@ cuda_py_strict_test( ":np_math_ops", "//tensorflow/python/framework:errors", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/platform:client_testlib", "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", diff --git a/tensorflow/python/ops/numpy_ops/np_math_ops_test.py b/tensorflow/python/ops/numpy_ops/np_math_ops_test.py index 534a1dc9335f39..2a6b6368e8fb14 100644 --- a/tensorflow/python/ops/numpy_ops/np_math_ops_test.py +++ b/tensorflow/python/ops/numpy_ops/np_math_ops_test.py @@ -20,6 +20,7 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.ops.numpy_ops import np_array_ops from tensorflow.python.ops.numpy_ops import np_arrays from tensorflow.python.ops.numpy_ops import np_math_ops @@ -377,7 +378,7 @@ def testIsInf(self): self.assertFalse(np_math_ops.isneginf(x2)) if __name__ == '__main__': - ops.enable_tensor_equality() + tensor.enable_tensor_equality() ops.enable_eager_execution() ops.set_dtype_conversion_mode('legacy') np_math_ops.enable_numpy_methods_on_tensor() diff --git a/tensorflow/python/ops/parallel_for/BUILD b/tensorflow/python/ops/parallel_for/BUILD index 2e3a3ad2f00fe7..5deb99e8a1af21 100644 --- a/tensorflow/python/ops/parallel_for/BUILD +++ b/tensorflow/python/ops/parallel_for/BUILD @@ -256,6 +256,7 @@ cuda_py_strict_test( "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:clip_ops", diff --git a/tensorflow/python/ops/parallel_for/math_test.py b/tensorflow/python/ops/parallel_for/math_test.py index 932e07bde749aa..240c94fbdd4077 100644 --- a/tensorflow/python/ops/parallel_for/math_test.py +++ b/tensorflow/python/ops/parallel_for/math_test.py @@ -20,6 +20,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops as framework_ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import clip_ops @@ -149,8 +150,8 @@ def loop_fn(i): def test_binary_cwise_ops(self): # Enable tensor equality to test `equal` and `not_equal` ops below. - default_equality = framework_ops.Tensor._USE_EQUALITY - framework_ops.enable_tensor_equality() + default_equality = tensor.Tensor._USE_EQUALITY + tensor.enable_tensor_equality() try: logical_ops = [ math_ops.logical_and, math_ops.logical_or, math_ops.logical_xor @@ -225,7 +226,7 @@ def loop_fn(i): self._test_loop_fn(loop_fn, 3) finally: if not default_equality: - framework_ops.disable_tensor_equality() + tensor.disable_tensor_equality() def test_approximate_equal(self): x = random_ops.random_uniform([3, 5]) diff --git a/tensorflow/python/util/BUILD b/tensorflow/python/util/BUILD index fa421e03215c13..61b5377ecc8f41 100644 --- a/tensorflow/python/util/BUILD +++ b/tensorflow/python/util/BUILD @@ -253,7 +253,7 @@ tf_py_strict_test( ":tf_inspect", "//tensorflow/python/eager:context", "//tensorflow/python/framework:constant_op", - "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:variables", "//tensorflow/python/platform:client_testlib", diff --git a/tensorflow/python/util/deprecation_test.py b/tensorflow/python/util/deprecation_test.py index 024ef220260417..898af79480875e 100644 --- a/tensorflow/python/util/deprecation_test.py +++ b/tensorflow/python/util/deprecation_test.py @@ -23,7 +23,7 @@ from tensorflow.python.eager import context from tensorflow.python.framework import constant_op -from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -1062,14 +1062,14 @@ def test_deprecated_arg_values_when_value_is_none(self, mock_warning): def _fn(arg0): # pylint: disable=unused-argument pass - ops.enable_tensor_equality() + tensor.enable_tensor_equality() initial_count = mock_warning.call_count # Check that we avoid error from explicit `var == None` check. _fn(arg0=variables.Variable(0)) self.assertEqual(initial_count, mock_warning.call_count) _fn(arg0=None) self.assertEqual(initial_count + 1, mock_warning.call_count) - ops.disable_tensor_equality() + tensor.disable_tensor_equality() class DeprecationArgumentsTest(test.TestCase): From 1cd73838f75d83b773bc8bb94960ac0c45ec6f16 Mon Sep 17 00:00:00 2001 From: Juan Martinez Castellanos Date: Mon, 10 Jul 2023 10:30:35 -0700 Subject: [PATCH 058/376] Make all Python targets under tensorflow/compiler/jit/* have strict dependencies. PiperOrigin-RevId: 546912688 --- tensorflow/compiler/jit/BUILD | 4 ++-- tensorflow/compiler/jit/ops/BUILD | 10 +++++++++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 0d06315e3b4c57..ab84540ec8c683 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -1,7 +1,7 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.bzl", "if_libtpu", "if_with_tpu_support", "tf_cc_test", "tf_copts", "tf_cuda_cc_test", "tf_cuda_only_cc_test") load("//tensorflow/compiler/xla/stream_executor:build_defs.bzl", "if_cuda_or_rocm") -load("//tensorflow:tensorflow.default.bzl", "cc_header_only_library", "filegroup", "tf_custom_op_py_library", "tf_jit_compilation_passes_extra_deps") +load("//tensorflow:tensorflow.default.bzl", "cc_header_only_library", "filegroup", "tf_custom_op_py_strict_library", "tf_jit_compilation_passes_extra_deps") load("//tensorflow/core/platform:build_config.bzl", "tf_additional_all_protos", "tf_proto_library") load( "//tensorflow/core/platform:build_config_root.bzl", @@ -1351,7 +1351,7 @@ tf_cc_test( ], ) -tf_custom_op_py_library( +tf_custom_op_py_strict_library( name = "xla_ops_py", kernels = ["//tensorflow/compiler/jit/ops:xla_ops"], visibility = [ diff --git a/tensorflow/compiler/jit/ops/BUILD b/tensorflow/compiler/jit/ops/BUILD index 1059a263d57f43..a2c4bbd466848c 100644 --- a/tensorflow/compiler/jit/ops/BUILD +++ b/tensorflow/compiler/jit/ops/BUILD @@ -1,3 +1,4 @@ +load("//tensorflow:strict.default.bzl", "py_strict_library") load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") @@ -17,10 +18,17 @@ cc_library( tf_gen_op_wrapper_py( name = "xla_ops_wrapper_py", out = "xla_ops.py", + extra_py_deps = [ + "//tensorflow/python:pywrap_tfe", + "//tensorflow/python/util:dispatch", + "//tensorflow/python/util:deprecation", + "//tensorflow/python/util:tf_export", + ], + py_lib_rule = py_strict_library, deps = ["//tensorflow/compiler/jit/ops:xla_ops"], ) -py_library( +py_strict_library( name = "xla_ops_grad", srcs = ["xla_ops_grad.py"], srcs_version = "PY3", From 2ca49454875b328e1f74b3c9051103f75ad619d7 Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Mon, 10 Jul 2023 10:32:55 -0700 Subject: [PATCH 059/376] [XLA:GPU] Require sm80 for cudnn_fused_conv_rewriter_test. This test skips some testcases when running on pre-Ampere GPUs. To ensure that it's not flaky, we need to mark it as requiring Ampere. We also change the test so it doesn't skip tests it doesn't need to. In particular, if you're just *compiling* but not *running* the operation, then you shouldn't care what GPU is in the machine. PiperOrigin-RevId: 546913352 --- tensorflow/compiler/xla/service/gpu/BUILD | 3 ++- .../xla/service/gpu/cudnn_fused_conv_rewriter_test.cc | 8 ++------ 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index a7ed398aace948..56295b88908337 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -3171,7 +3171,8 @@ xla_cc_test( "no_oss", "noasan", "nomsan", - "requires-gpu-sm70", + # This test runs some fusions that are only supported on Ampere+. + "requires-gpu-sm80", ], deps = [ ":backend_configs_cc", diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc index 1ce89568c4948e..3ef05e5fd399ea 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc @@ -931,11 +931,6 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseReluIfMultipleUses) { } TEST_F(CudnnFusedConvRewriterHloTest, FuseElu) { - if (!GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << "Conv-Bias-Elu fusion is supported and recommended with " - "the Nvidia Ampere+ GPUs."; - } const std::string module_str = R"( HloModule Test @@ -958,7 +953,8 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseElu) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability()}; + // elu fusion is only active on Ampere+. + CudnnFusedConvRewriter fuser{se::CudaComputeCapability(8, 0)}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); From 791fb52376b9e177a1283c9e407ae61aed7ccf52 Mon Sep 17 00:00:00 2001 From: Kuangyuan Chen Date: Mon, 10 Jul 2023 10:35:43 -0700 Subject: [PATCH 060/376] Add tf.PwStreamResults op to tensorflow for general streaming support PiperOrigin-RevId: 546914204 --- .../compiler/mlir/tensorflow/ir/tfrt_ops.cc | 13 ++++++++ .../compiler/mlir/tensorflow/ir/tfrt_ops.td | 33 +++++++++++++++++++ .../mlir/tensorflow/tests/tfrt_ops.mlir | 15 +++++++++ .../transforms/legalization_op_config_test.cc | 2 +- 4 files changed, 62 insertions(+), 1 deletion(-) create mode 100644 tensorflow/compiler/mlir/tensorflow/tests/tfrt_ops.mlir diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.cc index 01c6c6c68b38ad..8cce823ae5233c 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.cc @@ -71,6 +71,19 @@ LogicalResult _TfrtGetResourceOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// PwStreamResults +//===----------------------------------------------------------------------===// + +mlir::LogicalResult PwStreamResultsOp::verify() { + if (getArgs().size() != getNames().size()) { + return emitOpError() + << "has a mismatch between the number of arguments and their names (" + << getArgs().size() << " vs. " << getNames().size() << ")"; + } + return mlir::success(); +} + } // namespace TF } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.td index bd6f35db525ffa..a0e2935255e95a 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.td @@ -63,4 +63,37 @@ def TF__TfrtGetResourceOp : TF_Op<"_TfrtGetResource", let hasVerifier = 1; } +// TODO(chky): Consider adding this op to tensorflow core ops. +def TF_PwStreamResultsOp : TF_Op<"PwStreamResults"> { + let summary = "Streams results back to the controller"; + + let description = [{ + This op is a TensorFlow op that represents "streamed outputs", where + intermediate results can be returned immediately without waiting for the + entire signature computation to complete. + + This op takes `args` with their `names` (their cardinality must match) and + sends the given argument tensors back to the serving controller. This + triggers a controller-side stream callback (see `ScopedStreamCallback`). + + In addition to the listed attributes, this op has two "hidden" attributes + that do not exist in SavedModel but are dynamically populated by the serving + runtime: + + * `_controller_address`: Address of the remote instance to which tensors + will be sent via e.g. RPC. + * `_callback_id`: Identifier for the callback to be called from the + controller. See `ScopedStreamCallback`. + }]; + + let arguments = (ins + Variadic : $args, + StrArrayAttr : $names + ); + + TF_DerivedOperandTypeListAttr T = TF_DerivedOperandTypeListAttr<0>; + + let hasVerifier = 1; +} + #endif // TFRT_OPS diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tfrt_ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tfrt_ops.mlir new file mode 100644 index 00000000000000..3fb11e56172276 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/tfrt_ops.mlir @@ -0,0 +1,15 @@ +// RUN: tf-opt %s -split-input-file -verify-diagnostics | FileCheck %s + +// Tests for TensorFlow TFRT ops with custom verifiers. + +//===--------------------------------------------------------------------===// +// Test TF operations (tf.*) +//===--------------------------------------------------------------------===// + +// CHECK-LABEL: func @testPwStreamResults +func.func @testPwStreamResults(%arg0: tensor, %arg1: tensor) { + "tf.PwStreamResults"(%arg0, %arg1) {names = ["foo", "bar"]} : (tensor, tensor) -> () + return +} + +// ----- diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc index 8130518f2aef73..79220db6087a14 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc @@ -131,7 +131,7 @@ TEST_F(LegalizationOpConfigTest, CountLoweringsSet) { // a new op, we should expect these to change too. EXPECT_EQ(mlir_lowering_count, 71); EXPECT_EQ(tf2xla_fallback_count, 295); - EXPECT_EQ(non_categorized_count, 418); + EXPECT_EQ(non_categorized_count, 419); } // Just a counter test to see which ops have duplicate lowerings. This isn't a From 9091bf73b3ed05d2d121870e160ec8ef810998d6 Mon Sep 17 00:00:00 2001 From: Juan Martinez Castellanos Date: Mon, 10 Jul 2023 10:39:29 -0700 Subject: [PATCH 061/376] Make all targets under tensorflow/lite/toco/python/ have strict dependencies. PiperOrigin-RevId: 546915251 --- tensorflow/lite/toco/python/BUILD | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/tensorflow/lite/toco/python/BUILD b/tensorflow/lite/toco/python/BUILD index 774f849c702b68..3a172ea0613691 100644 --- a/tensorflow/lite/toco/python/BUILD +++ b/tensorflow/lite/toco/python/BUILD @@ -1,5 +1,4 @@ -load("//tensorflow:tensorflow.default.bzl", "tf_py_strict_test") -load("//tensorflow:strict.default.bzl", "py_strict_binary", "py_strict_library") +load("//tensorflow:tensorflow.bzl", "py_binary", "tf_py_test") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -78,7 +77,7 @@ cc_library( ) # Compatibility stub. Remove when internal customers moved. -py_strict_library( +py_library( name = "tensorflow_wrap_toco", srcs = ["tensorflow_wrap_toco.py"], srcs_version = "PY3", @@ -92,7 +91,7 @@ py_strict_library( ], ) -py_strict_binary( +py_binary( name = "toco_from_protos", srcs = ["toco_from_protos.py"], python_version = "PY3", @@ -106,7 +105,7 @@ py_strict_binary( ], ) -tf_py_strict_test( +tf_py_test( name = "toco_from_protos_test", srcs = ["toco_from_protos_test.py"], python_version = "PY3", @@ -114,8 +113,7 @@ tf_py_strict_test( "no_oss", ], deps = [ - "//tensorflow:tensorflow_py_no_contrib", - "//tensorflow/core:protos_all_py", + "//tensorflow:tensorflow_py", "//tensorflow/lite/toco:model_flags_proto_py", "//tensorflow/lite/toco:toco_flags_proto_py", "//tensorflow/python/platform:resource_loader", From 31125b6ccbb5e6ce17b06a76ef40bf700b85ee8b Mon Sep 17 00:00:00 2001 From: Russell Power Date: Mon, 10 Jul 2023 10:46:46 -0700 Subject: [PATCH 062/376] Cache computed control dependencies during control-dep optimization. PiperOrigin-RevId: 546917420 --- .../transforms/update_control_dependencies.cc | 126 +++++++++++++++--- 1 file changed, 104 insertions(+), 22 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/update_control_dependencies.cc b/tensorflow/compiler/mlir/tensorflow/transforms/update_control_dependencies.cc index d4a05aae890ecd..484523dfd7c912 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/update_control_dependencies.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/update_control_dependencies.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include -#include "absl/container/btree_set.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "llvm/ADT/SmallVector.h" @@ -55,6 +54,36 @@ using OpToParallelIdsMap = using OpToOpsMap = absl::flat_hash_map>; +// Many operations have the same dependency and parallel id set. We cache the +// processed result of these operations to speed execution. +struct OpCacheEntry { + Operation* template_op; + llvm::SmallVector preds_in_reverse_program_order; +}; + +struct OpCacheKey { + const llvm::SmallVector deps; + const GroupIdToBranchIdMap& group_id_to_branch_id_map; + + template + friend H AbslHashValue(H h, const OpCacheKey& c) { + for (Operation* dep : c.deps) { + h = H::combine(std::move(h), dep); + } + for (auto [group_id, branch_id] : c.group_id_to_branch_id_map) { + h = H::combine(std::move(h), group_id, branch_id); + } + return h; + } + + bool operator==(const OpCacheKey& other) const { + return deps == other.deps && + group_id_to_branch_id_map == other.group_id_to_branch_id_map; + } +}; + +using OpCache = absl::flat_hash_map; + #define GEN_PASS_DEF_EXECUTORUPDATECONTROLDEPENDENCIESPASS #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.h.inc" @@ -169,27 +198,63 @@ LogicalResult FillOpToParallelIdsMap( // Computes and sets direct control inputs for `op`. Also fills // `active_transitive_preds` and `inactive_transitive_preds` for `op`. -void -UpdateControlDependenciesForOp( +// +// `active_transitive_preds` are those dominated by `op`: taking a dependency +// on `op` will also ensure all `active_transitive_preds[op]` are waited +// for. +// +// `inactive_transitive_preds` are transitive dependencies of op in the original +// graph but are not dominated by `op`. (They run in a different parallel +// execution group). They must be separately considered when processing +// successor operations. +void UpdateControlDependenciesForOp( Operation* op, const TF::SideEffectAnalysis::Info& analysis_for_func, const OpToParallelIdsMap& op_to_parallel_ids_map, + OpCache& op_cache, OpToOpsMap& active_transitive_preds, OpToOpsMap& inactive_transitive_preds, int& num_control_inputs_removed, int& num_control_inputs_added, int& num_invalid_dependencies) { + auto& op_inactive = inactive_transitive_preds[op]; + auto& op_active = active_transitive_preds[op]; + + llvm::SmallVector control_deps = + analysis_for_func.DirectControlPredecessors(op); + OpCacheKey key = { + control_deps, + GetGroupIdToBranchIdMap(op, op_to_parallel_ids_map) + }; + + // We matched with another op in the cache. We will have the same active and + // inactive dependency sets and control inputs, except we swap out our current + // op for the template op in the active set. + if (op_cache.contains(key)) { + auto& entry = op_cache[key]; + op_active = active_transitive_preds[entry.template_op]; + op_active.insert(op); + op_active.erase(entry.template_op); + + op_inactive = inactive_transitive_preds[entry.template_op]; + ClearControlInputs(op, num_control_inputs_removed); + SetControlInputs(op, entry.preds_in_reverse_program_order, + num_control_inputs_added); + return; + } + + op_active.insert(op); + + // First iterate over all direct control dependencies and collect the set of + // potential active dependencies. absl::flat_hash_set pred_set; - active_transitive_preds[op].insert(op); - for (Operation* pred : analysis_for_func.DirectControlPredecessors(op)) { - // Propagate inactive transitive dependencies from `pred` to `op`. - inactive_transitive_preds[op].insert( - inactive_transitive_preds[pred].begin(), - inactive_transitive_preds[pred].end()); + for (Operation* pred : control_deps) { // Inactive transitive predecessors of `pred` are potential direct // predecessors of `op` (they are not tracked by `pred`). for (Operation* transitive_pred : inactive_transitive_preds[pred]) { pred_set.insert(transitive_pred); + op_inactive.insert(transitive_pred); } + if (IsValidDependency(pred, op, op_to_parallel_ids_map)) { // We know that any active transitive predecessors will still be covered // by (pred, op), so we don't have to add them to `potential_preds`. @@ -197,40 +262,55 @@ UpdateControlDependenciesForOp( } else { // Active transitive predecessors will not be covered by (pred, op) // anymore, so add them all as candidates. - for (Operation* transitive_pred : active_transitive_preds[pred]) { - pred_set.insert(transitive_pred); - } + pred_set.insert( + active_transitive_preds[pred].begin(), + active_transitive_preds[pred].end()); ++num_invalid_dependencies; } } - std::vector potential_preds(pred_set.begin(), pred_set.end()); - std::sort(potential_preds.begin(), potential_preds.end(), IsAfterInBlock()); + // Now collect a list of valid dependencies and sort them in program order. + std::vector potential_preds; + potential_preds.reserve(pred_set.size()); - llvm::SmallVector preds_in_reverse_program_order; - for (Operation* potential_pred : potential_preds) { - bool is_valid = - IsValidDependency(potential_pred, op, op_to_parallel_ids_map); - if (!is_valid) { + for (Operation* potential_pred : pred_set) { + if (IsValidDependency(potential_pred, op, op_to_parallel_ids_map)) { + potential_preds.push_back(potential_pred); + } else { // We don't keep the (pred, op) dependency, so all active transitive // dependencies become inactive. - inactive_transitive_preds[op].insert( + op_inactive.insert( active_transitive_preds[potential_pred].begin(), active_transitive_preds[potential_pred].end()); - } else if (!active_transitive_preds[op].contains(potential_pred)) { + } + } + std::sort(potential_preds.begin(), potential_preds.end(), IsAfterInBlock()); + + // Finally, accumulate dependencies until we have coverage over all active + // dependencies. + llvm::SmallVector preds_in_reverse_program_order; + for (Operation* potential_pred : potential_preds) { + if (!op_active.contains(potential_pred)) { // `potential_pred` is not an active transitive predecessor of `op` yet, // so we must add it as a direct predecessor. preds_in_reverse_program_order.push_back(potential_pred); // We keep the (pred, op) dependency, so all active transitive // dependencies stay active. - active_transitive_preds[op].insert( + op_active.insert( active_transitive_preds[potential_pred].begin(), active_transitive_preds[potential_pred].end()); } } + + for (Operation* pred : op_active) { + op_inactive.erase(pred); + } + ClearControlInputs(op, num_control_inputs_removed); SetControlInputs(op, preds_in_reverse_program_order, num_control_inputs_added); + + op_cache[key] = {op, preds_in_reverse_program_order}; } // This function updates all control dependencies in `func`, represented as @@ -259,6 +339,7 @@ LogicalResult UpdateAllControlDependencies( // Maps island ops to parallel IDs of the wrapped ops. OpToParallelIdsMap op_to_parallel_ids_map; + OpCache op_cache; OpToOpsMap active_transitive_preds, inactive_transitive_preds; // We call `VerifyExportSuitable` in the beginning of the pass, so every @@ -275,6 +356,7 @@ LogicalResult UpdateAllControlDependencies( op, analysis_for_func, op_to_parallel_ids_map, + op_cache, active_transitive_preds, inactive_transitive_preds, num_control_inputs_removed, From f3f1f2a6e7eddb17a9dc2dbbe80ac899a28e1e09 Mon Sep 17 00:00:00 2001 From: Juan Martinez Castellanos Date: Mon, 10 Jul 2023 10:51:33 -0700 Subject: [PATCH 063/376] Make all Python targets under tensorflow/cc/saved_model/* have strict dependencies. PiperOrigin-RevId: 546919496 --- tensorflow/cc/saved_model/BUILD | 11 ++++++++--- .../cc/saved_model/testdata/generate_saved_models.py | 2 +- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD index de7b00e2eaeb6e..e1bc27d3edc7cb 100644 --- a/tensorflow/cc/saved_model/BUILD +++ b/tensorflow/cc/saved_model/BUILD @@ -2,6 +2,7 @@ #Description: # TensorFlow SavedModel. +load("//tensorflow:strict.default.bzl", "py_strict_binary") load("//tensorflow:tensorflow.default.bzl", "filegroup") load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load( @@ -249,7 +250,7 @@ tf_cc_test( ) # A subset of the TF2 saved models can be generated with this tool. -py_binary( +py_strict_binary( name = "testdata/generate_saved_models", srcs = ["testdata/generate_saved_models.py"], data = [ @@ -259,12 +260,14 @@ py_binary( python_version = "PY3", srcs_version = "PY3", deps = [ + "//tensorflow/python/client:session", "//tensorflow/python/compat:v2_compat", "//tensorflow/python/eager:def_function", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:tensor_spec", "//tensorflow/python/module", + "//tensorflow/python/ops:io_ops", "//tensorflow/python/ops:lookup_ops", "//tensorflow/python/ops:variables", "//tensorflow/python/platform:client_testlib", @@ -277,17 +280,18 @@ py_binary( # copybara:uncomment_begin(google-only) # -# py_binary( +# py_strict_binary( # name = "testdata/generate_chunked_models", # srcs = ["testdata/generate_chunked_models.py"], # python_version = "PY3", # srcs_version = "PY3", # deps = [ +# "//third_party/py/numpy", # "//tensorflow/python/compat:v2_compat", # "//tensorflow/python/eager:def_function", # "//tensorflow/python/framework:constant_op", +# "//tensorflow/python/lib/io:lib", # "//tensorflow/python/module", -# "//tensorflow/python/platform:client_testlib", # "//tensorflow/python/saved_model:loader", # "//tensorflow/python/saved_model:save", # "//tensorflow/python/saved_model:save_options", @@ -295,6 +299,7 @@ py_binary( # "//tensorflow/tools/proto_splitter:constants", # "//tensorflow/tools/proto_splitter/python:saved_model", # "@absl_py//absl:app", +# "@absl_py//absl/flags", # ], # ) # diff --git a/tensorflow/cc/saved_model/testdata/generate_saved_models.py b/tensorflow/cc/saved_model/testdata/generate_saved_models.py index 5644feaaeea5da..5b2e458bbb64c6 100644 --- a/tensorflow/cc/saved_model/testdata/generate_saved_models.py +++ b/tensorflow/cc/saved_model/testdata/generate_saved_models.py @@ -17,7 +17,7 @@ import os from absl import app -from keras.optimizers.optimizers_v2 import adam +from keras.optimizers.legacy import adam from tensorflow.python.client import session as session_lib from tensorflow.python.compat import v2_compat From d14f8deaa802bd859bc7bae5dd27ef41e445f1a4 Mon Sep 17 00:00:00 2001 From: Haibo Huang Date: Mon, 10 Jul 2023 10:59:52 -0700 Subject: [PATCH 064/376] Returns error if the requested path is a bucket and doesn't exist In the current implementation, if fname is a bucket that doesn't exist, we will fall-through to StatForObject() below. But StatForObject() doesn't allow empty object. PiperOrigin-RevId: 546922445 --- tensorflow/tsl/platform/cloud/BUILD | 4 ++-- tensorflow/tsl/platform/cloud/gcs_file_system.cc | 4 ++++ tensorflow/tsl/platform/cloud/gcs_file_system_test.cc | 6 ++---- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/tensorflow/tsl/platform/cloud/BUILD b/tensorflow/tsl/platform/cloud/BUILD index a8bf3e30fee028..6266747a67248a 100644 --- a/tensorflow/tsl/platform/cloud/BUILD +++ b/tensorflow/tsl/platform/cloud/BUILD @@ -118,7 +118,6 @@ cc_library( ":http_request", ":ram_file_block_cache", ":time_util", - "//tensorflow/tsl/lib/gtl:map_util", "//tensorflow/tsl/platform:env", "//tensorflow/tsl/platform:errors", "//tensorflow/tsl/platform:file_statistics", @@ -136,6 +135,7 @@ cc_library( "//tensorflow/tsl/profiler/lib:traceme", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", "@jsoncpp_git//:jsoncpp", ], alwayslink = 1, @@ -162,7 +162,6 @@ cc_library( ":http_request", ":ram_file_block_cache", ":time_util", - "//tensorflow/tsl/lib/gtl:map_util", "//tensorflow/tsl/platform:env", "//tensorflow/tsl/platform:errors", "//tensorflow/tsl/platform:file_statistics", @@ -180,6 +179,7 @@ cc_library( "//tensorflow/tsl/profiler/lib:traceme", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", "@jsoncpp_git//:jsoncpp", ], alwayslink = 1, diff --git a/tensorflow/tsl/platform/cloud/gcs_file_system.cc b/tensorflow/tsl/platform/cloud/gcs_file_system.cc index f451279053e74c..9baeb4d8266aa4 100644 --- a/tensorflow/tsl/platform/cloud/gcs_file_system.cc +++ b/tensorflow/tsl/platform/cloud/gcs_file_system.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "absl/status/status.h" +#include "absl/strings/str_cat.h" #ifndef _WIN32 #include @@ -1459,6 +1460,9 @@ Status GcsFileSystem::FileExists(const string& fname, TransactionToken* token) { TF_RETURN_IF_ERROR(BucketExists(bucket, &result)); if (result) { return OkStatus(); + } else { + return absl::NotFoundError( + absl::StrCat("The specified bucket ", fname, " was not found.")); } } diff --git a/tensorflow/tsl/platform/cloud/gcs_file_system_test.cc b/tensorflow/tsl/platform/cloud/gcs_file_system_test.cc index 61a588eea96b2a..d9bea9b43d54e8 100644 --- a/tensorflow/tsl/platform/cloud/gcs_file_system_test.cc +++ b/tensorflow/tsl/platform/cloud/gcs_file_system_test.cc @@ -1649,10 +1649,8 @@ TEST(GcsFileSystemTest, FileExists_NotAsBucket) { 0 /* matching paths cache max entries */, kTestRetryConfig, kTestTimeoutConfig, *kAllowedLocationsDefault, nullptr /* gcs additional header */, false /* compose append */); - EXPECT_TRUE( - errors::IsInvalidArgument(fs.FileExists("gs://bucket2/", nullptr))); - EXPECT_TRUE( - errors::IsInvalidArgument(fs.FileExists("gs://bucket2", nullptr))); + EXPECT_TRUE(absl::IsNotFound(fs.FileExists("gs://bucket2/", nullptr))); + EXPECT_TRUE(absl::IsNotFound(fs.FileExists("gs://bucket2", nullptr))); } TEST(GcsFileSystemTest, FileExists_StatCache) { From 6b478d0e91c089a024f04225f437a88c9dd1be9d Mon Sep 17 00:00:00 2001 From: Kuangyuan Chen Date: Mon, 10 Jul 2023 11:06:52 -0700 Subject: [PATCH 065/376] Adding stream callback support in TFRT PiperOrigin-RevId: 546924901 --- tensorflow/core/tfrt/runtime/BUILD | 60 +++++ tensorflow/core/tfrt/runtime/channel.h | 79 +++++++ tensorflow/core/tfrt/runtime/channel_test.cc | 137 ++++++++++++ tensorflow/core/tfrt/runtime/stream.cc | 211 ++++++++++++++++++ tensorflow/core/tfrt/runtime/stream.h | 221 +++++++++++++++++++ tensorflow/core/tfrt/runtime/stream_test.cc | 133 +++++++++++ tensorflow/core/tfrt/saved_model/BUILD | 5 +- 7 files changed, 842 insertions(+), 4 deletions(-) create mode 100644 tensorflow/core/tfrt/runtime/channel.h create mode 100644 tensorflow/core/tfrt/runtime/channel_test.cc create mode 100644 tensorflow/core/tfrt/runtime/stream.cc create mode 100644 tensorflow/core/tfrt/runtime/stream.h create mode 100644 tensorflow/core/tfrt/runtime/stream_test.cc diff --git a/tensorflow/core/tfrt/runtime/BUILD b/tensorflow/core/tfrt/runtime/BUILD index 16978e7ce9306c..c4e50acd172bb5 100644 --- a/tensorflow/core/tfrt/runtime/BUILD +++ b/tensorflow/core/tfrt/runtime/BUILD @@ -19,6 +19,7 @@ package_group( # copybara:uncomment "//learning/brain/experimental/tfrt/...", # copybara:uncomment "//learning/brain/tfrt/...", # copybara:uncomment "//learning/infra/mira/...", + # copybara:uncomment "//learning/pathways/serving/...", # copybara:uncomment "//learning/serving/...", # copybara:uncomment "//quality/webanswers/servo2/...", ], @@ -101,6 +102,54 @@ cc_library( ], ) +cc_library( + name = "stream", + srcs = ["stream.cc"], + hdrs = ["stream.h"], + deps = [ + ":channel", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/core/framework:tensor", + "//tensorflow/core/framework:tensor_proto_cc", + "//tensorflow/tsl/platform:env", + "//tensorflow/tsl/platform:random", + "//tensorflow/tsl/profiler/lib:traceme", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@com_google_absl//absl/utility", + "@llvm-project//mlir:IR", + ], +) + +cc_library( + name = "channel", + hdrs = ["channel.h"], + deps = [ + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/synchronization", + ], +) + +tf_cc_test( + name = "stream_test", + srcs = ["stream_test.cc"], + deps = [ + ":stream", + "//tensorflow/core/framework:tensor_testutil", + "//tensorflow/core/tfrt/saved_model:saved_model_testutil", + "//tensorflow/tsl/platform:env", + "@com_google_googletest//:gtest_main", + ], +) + tf_cc_test( name = "tf_threadpool_concurrent_work_queue_test", srcs = ["tf_threadpool_concurrent_work_queue_test.cc"], @@ -120,3 +169,14 @@ tf_cc_test( "@tf_runtime//:support", ], ) + +tf_cc_test( + name = "channel_test", + srcs = ["channel_test.cc"], + deps = [ + ":channel", + "//tensorflow/tsl/platform:env", + "@com_google_absl//absl/synchronization", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/tensorflow/core/tfrt/runtime/channel.h b/tensorflow/core/tfrt/runtime/channel.h new file mode 100644 index 00000000000000..5a01e78677064f --- /dev/null +++ b/tensorflow/core/tfrt/runtime/channel.h @@ -0,0 +1,79 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_RUNTIME_CHANNEL_H_ +#define TENSORFLOW_CORE_TFRT_RUNTIME_CHANNEL_H_ + +#include + +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/synchronization/mutex.h" + +namespace tensorflow { +namespace tfrt_stub { + +// An unbounded queue for communicating between threads. This class is +// thread-safe. +template +class UnboundedChannel { + public: + absl::Status Write(T value) { + absl::MutexLock lock(&mu_); + + if (closed_) { + return absl::InternalError( + "Failed to write to the UnboundedChannel that is closed."); + } + + channel_.push(std::move(value)); + + return absl::OkStatus(); + } + + bool Read(T& value) { + absl::MutexLock lock(&mu_); + + mu_.Await(absl::Condition( + +[](UnboundedChannel* channel) ABSL_SHARED_LOCKS_REQUIRED(mu_) { + return !channel->channel_.empty() || channel->closed_; + }, + this)); + + if (!channel_.empty()) { + value = std::move(channel_.front()); + channel_.pop(); + return true; + } + + // If channel_ is empty, then it must be closed at this point. + DCHECK(closed_); + return false; + } + + void Close() { + absl::MutexLock lock(&mu_); + closed_ = true; + } + + private: + absl::Mutex mu_; + std::queue channel_ ABSL_GUARDED_BY(mu_); + bool closed_ ABSL_GUARDED_BY(mu_) = false; +}; + +} // namespace tfrt_stub +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_RUNTIME_CHANNEL_H_ diff --git a/tensorflow/core/tfrt/runtime/channel_test.cc b/tensorflow/core/tfrt/runtime/channel_test.cc new file mode 100644 index 00000000000000..ed30b88720d479 --- /dev/null +++ b/tensorflow/core/tfrt/runtime/channel_test.cc @@ -0,0 +1,137 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/tfrt/runtime/channel.h" + +#include +#include + +#include +#include +#include "absl/synchronization/blocking_counter.h" +#include "tensorflow/tsl/platform/env.h" + +namespace tensorflow { +namespace tfrt_stub { +namespace { + +using ::testing::ElementsAreArray; +using ::testing::UnorderedElementsAreArray; +using ::testing::status::StatusIs; + +TEST(ChannelTest, Basic) { + UnboundedChannel channel; + + std::vector expected(100); + std::iota(expected.begin(), expected.end(), 0); + + tsl::Env::Default()->SchedClosure([&]() { + for (int v : expected) { + CHECK_OK(channel.Write(v)); + } + channel.Close(); + }); + + std::vector outputs; + int v = -1; + while (channel.Read(v)) { + outputs.push_back(v); + } + + EXPECT_THAT(outputs, ElementsAreArray(expected)); + + EXPECT_THAT(channel.Write(100), StatusIs(absl::StatusCode::kInternal)); +} + +TEST(ChannelTest, MultipleWriters) { + UnboundedChannel channel; + + std::vector expected(100); + std::iota(expected.begin(), expected.end(), 0); + + tsl::Env::Default()->SchedClosure([&]() { + absl::BlockingCounter bcount(expected.size()); + for (int v : expected) { + tsl::Env::Default()->SchedClosure([&, v]() { + CHECK_OK(channel.Write(v)); + bcount.DecrementCount(); + }); + } + bcount.Wait(); + channel.Close(); + }); + + std::vector outputs; + int v = 0; + while (channel.Read(v)) { + outputs.push_back(v); + } + + EXPECT_THAT(outputs, UnorderedElementsAreArray(expected)); +} + +TEST(ChannelTest, MultipleReaders) { + UnboundedChannel channel; + + std::vector expected(100); + std::iota(expected.begin(), expected.end(), 0); + + absl::Mutex mu; + std::vector outputs; + + int num_readers = 200; + absl::BlockingCounter bcount(num_readers); + for (int i = 0; i < num_readers; ++i) { + tsl::Env::Default()->SchedClosure([&]() { + int v = 0; + while (channel.Read(v)) { + absl::MutexLock lock(&mu); + outputs.push_back(v); + } + bcount.DecrementCount(); + }); + } + + for (int v : expected) { + CHECK_OK(channel.Write(v)); + } + channel.Close(); + + bcount.Wait(); + EXPECT_THAT(outputs, UnorderedElementsAreArray(expected)); +} + +TEST(ChannelTest, FullyBuffered) { + UnboundedChannel channel; + + std::vector expected(100); + std::iota(expected.begin(), expected.end(), 0); + + for (int v : expected) { + CHECK_OK(channel.Write(v)); + } + channel.Close(); + + std::vector outputs; + int v = -1; + while (channel.Read(v)) { + outputs.push_back(v); + } + + EXPECT_THAT(outputs, ElementsAreArray(expected)); +} + +} // namespace +} // namespace tfrt_stub +} // namespace tensorflow diff --git a/tensorflow/core/tfrt/runtime/stream.cc b/tensorflow/core/tfrt/runtime/stream.cc new file mode 100644 index 00000000000000..01ed8ad9e6824f --- /dev/null +++ b/tensorflow/core/tfrt/runtime/stream.cc @@ -0,0 +1,211 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/tfrt/runtime/stream.h" + +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "absl/utility/utility.h" +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/tsl/platform/random.h" +#include "tensorflow/tsl/profiler/lib/traceme.h" + +namespace tensorflow { +namespace tfrt_stub { + +absl::StatusOr> CreateStreamCallbackId( + absl::string_view model_name, mlir::ModuleOp module) { + mlir::Builder builder(module.getContext()); + + // Inject information about the callback to `tf.PwStreamResults` ops. The + // attribute names must match `PwStreamResult` op's implementation. + + std::vector ops; + module->walk([&](mlir::TF::PwStreamResultsOp op) { ops.push_back(op); }); + + if (ops.empty()) { + return std::nullopt; + } + + auto& stream_interface = GetGlobalStreamCallbackRegistry().stream_interface(); + + auto controller_address = stream_interface.controller_address(); + auto controller_address_attr = builder.getStringAttr(controller_address); + + auto model_name_attr = builder.getStringAttr(model_name); + + // We use int64_t instead of uint64_t returned by `New64()` because + // TensorFlow doesn't support uint64 attributes. + const StreamCallbackId callback_id( + static_cast(tsl::random::New64())); + auto callback_id_attr = builder.getI64IntegerAttr(callback_id.id); + + for (auto op : ops) { + op->setAttr("_controller_address", controller_address_attr); + op->setAttr("_model_name", model_name_attr); + op->setAttr("_callback_id", callback_id_attr); + } + + return callback_id; +} + +absl::StatusOr StreamCallbackRegistry::Register( + absl::string_view model_name, StreamCallbackId callback_id, StepId step_id, + absl::AnyInvocable< + void(absl::flat_hash_map)> + callback) { + absl::MutexLock l(&mu_); + + const auto [it, inserted] = + stream_callbacks_.insert({std::make_pair(callback_id, step_id), nullptr}); + if (!inserted) { + return absl::AlreadyExistsError(absl::StrCat( + "Stream callback ", callback_id, " @ ", step_id, " already exists")); + } + + it->second = std::make_unique(); + it->second->thread = absl::WrapUnique(tsl::Env::Default()->StartThread( + tensorflow::ThreadOptions(), + /*name=*/absl::StrCat("stream_handler_", callback_id, "_", step_id), + [model_name = std::string(model_name), callback_id, step_id, + callback = std::move(callback), state = it->second.get(), + this]() mutable { + StreamedResult result; + while (state->channel.Read(result)) { + absl::Duration dequeue_latency = absl::Now() - result.enqueued_time; + interface_->RecordDequeueLatency(model_name, dequeue_latency); + + tsl::profiler::TraceMe trace_me("StreamCallbackInvocation"); + trace_me.AppendMetadata([&]() { + return tsl::profiler::TraceMeEncode({ + {"callback_id", callback_id.id}, + {"step_id", step_id.id}, + }); + }); + + absl::Time start_time = absl::Now(); + callback(std::move(result.tensors)); + interface_->RecordCallbackLatency(model_name, + absl::Now() - start_time); + } + })); + + return ScopedStreamCallback(this, callback_id, step_id); +} + +absl::Status StreamCallbackRegistry::Write(StreamCallbackId callback_id, + StepId step_id, + StreamedResult result) { + absl::MutexLock lock(&mu_); + auto iter = stream_callbacks_.find({callback_id, step_id}); + if (iter == stream_callbacks_.end()) { + return absl::NotFoundError(absl::StrCat( + "Stream callback ", callback_id, " @ ", step_id, + " does not exist; this usually indicates that a streaming signature " + "was called by a non-streaming request")); + } + + auto* state = iter->second.get(); + DCHECK(state); + return state->channel.Write(std::move(result)); +} + +std::unique_ptr +StreamCallbackRegistry::Unregister(StreamCallbackId callback_id, + StepId step_id) { + absl::MutexLock l(&mu_); + const auto it = stream_callbacks_.find({callback_id, step_id}); + if (it == stream_callbacks_.end()) { + return nullptr; + } + auto state = std::move(it->second); + stream_callbacks_.erase(it); + return state; +} + +ScopedStreamCallback::ScopedStreamCallback(ScopedStreamCallback&& other) + : registry_(other.registry_), + callback_id_(other.callback_id_), + step_id_(other.step_id_) { + other.callback_id_ = std::nullopt; + other.step_id_ = StepId::GetInvalidStepId(); +} + +ScopedStreamCallback& ScopedStreamCallback::operator=( + ScopedStreamCallback&& other) { + Unregister(); + + registry_ = other.registry_; + callback_id_ = other.callback_id_; + step_id_ = other.step_id_; + other.callback_id_ = std::nullopt; + other.step_id_ = StepId::GetInvalidStepId(); + + return *this; +} + +void ScopedStreamCallback::Unregister() { + if (!callback_id_.has_value()) { + return; + } + + tsl::profiler::TraceMe trace_me("ScopedStreamCallback::Unregister"); + trace_me.AppendMetadata([&]() { + return tsl::profiler::TraceMeEncode({ + {"callback_id", callback_id_->id}, + {"step_id", step_id_.id}, + }); + }); + + DCHECK(registry_); + auto state = registry_->Unregister(*callback_id_, step_id_); + DCHECK(state); + + // At this point, it is safe to close the channel. + state->channel.Close(); + + // Wait until the stream handler finishes. + state->thread.reset(); + + callback_id_.reset(); +} + +StreamInterfaceFactory& GetGlobalStreamInterfaceFactory() { + static auto* stream_interface_factory = new StreamInterfaceFactory; + return *stream_interface_factory; +} + +StreamCallbackRegistry& GetGlobalStreamCallbackRegistry() { + static auto* stream_callback_registry = new StreamCallbackRegistry( + GetGlobalStreamInterfaceFactory().CreateStreamInterface().value()); + return *stream_callback_registry; +} + +} // namespace tfrt_stub +} // namespace tensorflow diff --git a/tensorflow/core/tfrt/runtime/stream.h b/tensorflow/core/tfrt/runtime/stream.h new file mode 100644 index 00000000000000..7fea8ffe88c01b --- /dev/null +++ b/tensorflow/core/tfrt/runtime/stream.h @@ -0,0 +1,221 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See +the License for the specific language governing permissions and limitations +under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_RUNTIME_STREAM_H_ +#define TENSORFLOW_CORE_TFRT_RUNTIME_STREAM_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/functional/any_invocable.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/tfrt/runtime/channel.h" +#include "tensorflow/tsl/platform/env.h" + +namespace tensorflow { +namespace tfrt_stub { + +template +struct SafeId { + SafeId() : id(0) {} + explicit constexpr SafeId(int64_t id) : id(id) {} + + using Base = SafeId; + + int64_t id; + + friend bool operator==(const Derived& x, const Derived& y) { + return x.id == y.id; + } + + template + friend void AbslStringify(Sink& sink, const Derived& x) { + absl::Format(&sink, "%d", x.id); + } + + template + friend H AbslHashValue(H h, const Derived& x) { + return H::combine(std::move(h), x.id); + } +}; + +struct StreamedResult { + absl::flat_hash_map tensors; + absl::Time enqueued_time; +}; + +struct StreamCallbackId : SafeId { + using Base::Base; +}; + +struct StepId : SafeId { + using Base::Base; + + bool valid() const { return id != 0; } + static constexpr StepId GetInvalidStepId() { return StepId(0); } +}; + +class StreamInterface { + public: + explicit StreamInterface(std::string controller_address) + : controller_address_(std::move(controller_address)) {} + virtual ~StreamInterface() = default; + + absl::string_view controller_address() const { return controller_address_; } + + virtual void RecordDequeueLatency(absl::string_view model_name, + absl::Duration latency) {} + + virtual void RecordCallbackLatency(absl::string_view model_name, + absl::Duration latency) {} + + private: + std::string controller_address_; +}; + +class ScopedStreamCallback; + +class StreamInterfaceFactory { + public: + void Register(absl::AnyInvocable< + absl::StatusOr>() const> + interface_factory) { + absl::MutexLock lock(&mu_); + interface_factory_ = std::move(interface_factory); + } + + absl::StatusOr> CreateStreamInterface() + const { + absl::MutexLock lock(&mu_); + return interface_factory_(); + } + + private: + mutable absl::Mutex mu_; + absl::AnyInvocable>() const> + interface_factory_ ABSL_GUARDED_BY(mu_) = []() { + return absl::InternalError( + "The factory for StreamInterface is not registered."); + }; +}; + +// Returns the global factory for the stream interface. The factory for the +// stream interface must be registered first before calling +// GetGlobalStreamCallbackRegistry(). +StreamInterfaceFactory& GetGlobalStreamInterfaceFactory(); + +// Mapping from tuples of (callback_id, step_id) to callback states. The mapping +// is stored in a global variable so that it can be shared between +// `ScopedStreamCallback` and `InvokeStreamCallbackOp`. +// +// This class is thread-safe. +class StreamCallbackRegistry { + public: + explicit StreamCallbackRegistry(std::unique_ptr interface) + : interface_(std::move(interface)) { + DCHECK(interface_); + } + + // Registers a callback under the given id. A stream callback is uniquely + // identified by a tuple of a callback id (unique to each executable) and a + // step id (unique to each invocation of a given executable). Returns an RAII + // object that removes the callback from the registry on its deallocation, or + // an error if the id already exists in the registry. + // + // If a program runs `tf.PwStreamResults` with a matching callback/step id, + // `callback` will be called with the arguments of `tf.PwStreamResults`. + // + // All invocations to `callback` are handled serially by a single thread, so + // `callback` doesn't need to be thread-safe even if multiple + // `tf.PwStreamResults` ops may run concurrently. + absl::StatusOr Register( + absl::string_view model_name, StreamCallbackId callback_id, + StepId step_id, + absl::AnyInvocable< + void(absl::flat_hash_map)> + callback); + + absl::Status Write(StreamCallbackId callback_id, StepId step_id, + StreamedResult result); + + StreamInterface& stream_interface() const { return *interface_; } + + private: + friend class ScopedStreamCallback; + + struct CallbackState { + std::unique_ptr thread; + UnboundedChannel channel; + }; + + std::unique_ptr Unregister(StreamCallbackId callback_id, + StepId step_id); + + std::unique_ptr interface_; + + mutable absl::Mutex mu_; + absl::flat_hash_map, + std::unique_ptr> + stream_callbacks_ ABSL_GUARDED_BY(mu_); +}; + +// Returns the global registry for the stream callbacks. The stream interface +// must have been registered through GetGlobalStreamInterfaceFactory() before +// calling this function. +StreamCallbackRegistry& GetGlobalStreamCallbackRegistry(); + +// Creates a new stream callback id and rewrites the given module with +// information required to trigger this callback remotely. Returns the callback +// id, or `std::nullopt` if the module has no stream outputs. +absl::StatusOr> CreateStreamCallbackId( + absl::string_view model_name, mlir::ModuleOp module); + +// Implements an RAII object that registers a callback to be called on receiving +// streamed tensors. +class ScopedStreamCallback { + public: + ScopedStreamCallback() = default; + + // Moveable but not copyable. + ScopedStreamCallback(ScopedStreamCallback&& other); + ScopedStreamCallback& operator=(ScopedStreamCallback&& other); + + ~ScopedStreamCallback() { Unregister(); } + + private: + friend class StreamCallbackRegistry; + + explicit ScopedStreamCallback(StreamCallbackRegistry* registry, + StreamCallbackId callback_id, StepId step_id) + : registry_(registry), callback_id_(callback_id), step_id_(step_id) {} + + void Unregister(); + + StreamCallbackRegistry* registry_ = nullptr; + std::optional callback_id_; + StepId step_id_ = StepId::GetInvalidStepId(); +}; + +} // namespace tfrt_stub +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_RUNTIME_STREAM_H_ diff --git a/tensorflow/core/tfrt/runtime/stream_test.cc b/tensorflow/core/tfrt/runtime/stream_test.cc new file mode 100644 index 00000000000000..50fc07e79c1571 --- /dev/null +++ b/tensorflow/core/tfrt/runtime/stream_test.cc @@ -0,0 +1,133 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/tfrt/runtime/stream.h" + +#include +#include +#include + +#include +#include +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/tfrt/saved_model/saved_model_testutil.h" +#include "tensorflow/tsl/platform/env.h" + +namespace tensorflow { +namespace tfrt_stub { +namespace { + +using ::tensorflow::test::AsTensor; +using ::testing::AnyOf; +using ::testing::ElementsAreArray; +using ::testing::Pair; +using ::testing::UnorderedElementsAre; + +class TestStreamInterface : public StreamInterface { + public: + TestStreamInterface() : StreamInterface("test_address") {} +}; + +const bool kUnused = []() { + GetGlobalStreamInterfaceFactory().Register( + []() { return std::make_unique(); }); + return true; +}(); + +TEST(StreamTest, Simple) { + StreamCallbackId callback_id(1234); + StepId step_id(5678); + + std::vector> outputs; + + { + ASSERT_OK_AND_ASSIGN( + auto scoped_stream_callback, + GetGlobalStreamCallbackRegistry().Register( + "test_model", callback_id, step_id, + [&](absl::flat_hash_map arg) { + outputs.push_back(std::move(arg)); + })); + + std::vector> expected = + {{{"a", AsTensor({100})}, {"b", AsTensor({200})}}, + {{"c", AsTensor({300})}}}; + auto thread = absl::WrapUnique(tsl::Env::Default()->StartThread( + tsl::ThreadOptions(), "fake_stream_client", [&]() { + for (const auto& map : expected) { + CHECK_OK(GetGlobalStreamCallbackRegistry().Write( + callback_id, step_id, {map, absl::Now()})); + } + })); + } + + EXPECT_EQ(outputs.size(), 2); + EXPECT_THAT(GetTfTensorData(outputs[0]["a"]), + ElementsAreArray({100})); + EXPECT_THAT(GetTfTensorData(outputs[0]["b"]), + ElementsAreArray({200})); + EXPECT_THAT(GetTfTensorData(outputs[1]["c"]), + ElementsAreArray({300})); +} + +TEST(StreamTest, MultipleWriters) { + StreamCallbackId callback_id(1234); + StepId step_id(5678); + + std::vector>> outputs; + + { + ASSERT_OK_AND_ASSIGN( + auto scoped_stream_callback, + GetGlobalStreamCallbackRegistry().Register( + "test_model", callback_id, step_id, + [&](absl::flat_hash_map arg) { + absl::flat_hash_map> out; + for (const auto& p : arg) { + out[p.first] = GetTfTensorData(p.second); + } + outputs.push_back(std::move(out)); + })); + + std::vector> expected = + {{{"a", AsTensor({100})}, {"b", AsTensor({200})}}, + {{"c", AsTensor({300})}}}; + + for (const auto& p : expected) { + tsl::Env::Default()->SchedClosure([&, p]() { + // The stream callback may be dropped early, and in that case we ignore + // the error. + GetGlobalStreamCallbackRegistry() + .Write(callback_id, step_id, {p, absl::Now()}) + .IgnoreError(); + }); + } + + absl::SleepFor(absl::Microseconds(100)); + } + + LOG(INFO) << "StreamCallback receives " << outputs.size() << " outputs."; + + for (const auto& output : outputs) { + EXPECT_THAT( + output, + AnyOf(UnorderedElementsAre(Pair("a", ElementsAreArray({100})), + Pair("b", ElementsAreArray({200}))), + UnorderedElementsAre(Pair("c", ElementsAreArray({300}))))); + } +} + +} // namespace +} // namespace tfrt_stub +} // namespace tensorflow diff --git a/tensorflow/core/tfrt/saved_model/BUILD b/tensorflow/core/tfrt/saved_model/BUILD index 98eb3e205c53d5..699b84ce7c5812 100644 --- a/tensorflow/core/tfrt/saved_model/BUILD +++ b/tensorflow/core/tfrt/saved_model/BUILD @@ -16,10 +16,7 @@ package_group( # copybara:uncomment "//learning/serving/...", "//tensorflow/core/runtime_fallback/...", "//tensorflow/core/tfrt/mlrt/application/tensorflow/tests/...", - "//tensorflow/core/tfrt/saved_model/tests/...", - "//tensorflow/core/tfrt/graph_executor/...", - "//tensorflow/core/tfrt/tfrt_session/...", - "//tensorflow/core/tfrt/utils/debug/...", + "//tensorflow/core/tfrt/...", "//tensorflow_serving/...", "//tensorflow/core/tfrt/saved_model/python/...", # copybara:uncomment "//platforms/xla/tests/saved_models/...", From d2f3214bd6909e2e6b6b48e4ca696ba1bd88893b Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Mon, 10 Jul 2023 11:16:20 -0700 Subject: [PATCH 066/376] [XLA:GPU] Roll-forward cl/543680393: Fuse more inputs into Triton GEMMs. - Let the GEMM rewriter do more complex traversals of inputs and fuse elementwise operations and broadcasts of scalar constants. - Limit the number of parameters per fusion. - Reorganize GPU compiler pipeline: bf16 float normalization is now required both before and after Triton GEMM fusion. - Remove an autotuner config that for unknown reasons fails on Volta with new fusions. One problem with the original CL was fixed in cl/544612599. Other ones are fixed in this one and are covered by the new tests GemmRewriterTritonTest.DoNotFuseIncompatibleDimOrders and DoNotFuseTooManyParameters. PiperOrigin-RevId: 546928177 --- tensorflow/compiler/xla/service/gpu/BUILD | 7 + .../xla/service/gpu/gemm_rewriter_triton.cc | 438 ++++++++++++------ .../xla/service/gpu/gemm_rewriter_triton.h | 49 +- .../service/gpu/gemm_rewriter_triton_test.cc | 147 +++++- .../compiler/xla/service/gpu/gpu_compiler.cc | 37 +- .../xla/service/gpu/ir_emitter_triton.cc | 2 +- .../xla/service/gpu/ir_emitter_triton_test.cc | 152 ++++++ .../xla/service/gpu/triton_autotuner.cc | 11 +- 8 files changed, 672 insertions(+), 171 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 56295b88908337..316330826a364b 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -438,6 +438,7 @@ cc_library( "//tensorflow/tsl/platform:errors", "//tensorflow/tsl/platform:logging", "//tensorflow/tsl/platform:path", + "//tensorflow/tsl/platform:statusor", "//tensorflow/tsl/platform:tensor_float_32_hdr_lib", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", @@ -493,6 +494,8 @@ xla_test( "//tensorflow/compiler/xla:autotuning_proto_cc", "//tensorflow/compiler/xla:error_spec", "//tensorflow/compiler/xla/hlo/ir:hlo", + "//tensorflow/compiler/xla/service:pattern_matcher", + "//tensorflow/compiler/xla/service:pattern_matcher_gmock", "//tensorflow/compiler/xla/service/gpu/tests:gpu_codegen_test", "//tensorflow/compiler/xla/stream_executor:device_description", "//tensorflow/compiler/xla/stream_executor/cuda:cublas_plugin", @@ -1154,18 +1157,22 @@ cc_library( "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/hlo/ir:hlo", "//tensorflow/compiler/xla/hlo/utils:hlo_query", "//tensorflow/compiler/xla/service:hlo_creation_utils", "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/compiler/xla/service:instruction_fusion", + "//tensorflow/compiler/xla/stream_executor:device_description", "//tensorflow/tsl/platform:errors", "//tensorflow/tsl/platform:status", "//tensorflow/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], diff --git a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.cc b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.cc index 6b28352ccd61ab..1f862d0bb5851b 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.cc @@ -22,12 +22,15 @@ limitations under the License. #include #include #include +#include #include #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/autotuning.pb.h" @@ -37,6 +40,7 @@ limitations under the License. #include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_instructions.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_opcode.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_schedule.h" #include "tensorflow/compiler/xla/hlo/utils/hlo_query.h" #include "tensorflow/compiler/xla/layout.h" #include "tensorflow/compiler/xla/literal_util.h" @@ -46,9 +50,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/gpu/matmul_utils.h" #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" +#include "tensorflow/compiler/xla/service/instruction_fusion.h" #include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/stream_executor/device_description.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/tsl/platform/errors.h" @@ -57,6 +64,25 @@ limitations under the License. namespace xla { namespace gpu { + +bool TensorIterationSpec::operator==(const TensorIterationSpec& other) const { + for (int dim = 0; dim < TensorIterationSpec::kMaxDimsPerTensor; ++dim) { + if (dim_iteration_specs_[dim].size() != other[dim].size()) { + return false; + } + for (int fragment = 0; fragment < dim_iteration_specs_[dim].size(); + ++fragment) { + if (dim_iteration_specs_[dim][fragment].stride != + other[dim][fragment].stride || + dim_iteration_specs_[dim][fragment].count != + other[dim][fragment].count) { + return false; + } + } + } + return true; +} + namespace { // Batch dimensions of an operand of a dot instruction. @@ -95,10 +121,10 @@ int64_t NonContractingDimensionIndex(const HloInstruction& dot, } // Data types that are tested to work in the triton GEMM emitter. -bool IsSupportedDataType(PrimitiveType t, GpuVersion gpu_version) { +bool IsSupportedDataType(PrimitiveType type, GpuVersion gpu_version) { auto cuda_compute_capability = std::get(gpu_version); - switch (t) { + switch (type) { case PRED: case S8: case S16: @@ -114,21 +140,19 @@ bool IsSupportedDataType(PrimitiveType t, GpuVersion gpu_version) { } } -Status RequireTritonFusibleConvert(const HloInstruction* input, - GpuVersion gpu_version) { - if (!IsSupportedDataType(input->operand(0)->shape().element_type(), - gpu_version)) { - return Unimplemented("unsupported data type"); +// Let input and output data volumes of a fusion grow by small amounts. +constexpr int64_t kIoToleranceBytes = 1024; + +// Difference of input and output data volumes of an instruction. +int64_t InputMinusOutputBytes(const HloInstruction& hlo) { + CHECK(!hlo.shape().IsTuple()); + int64_t output_size = ShapeUtil::ByteSizeOf(hlo.shape()); + int64_t input_size = 0; + for (const HloInstruction* operand : hlo.operands()) { + CHECK(!operand->shape().IsTuple()); + input_size += ShapeUtil::ByteSizeOf(operand->shape()); } - // TODO(b/266862494): Can pick up almost any - // convert, but if it's reducing the data volume it should rather be fused - // to the output of the producer kernel. However not all operations support - // output fusion - then it should be fused here anyway! - if (ShapeUtil::ByteSizeOf(input->operand(0)->shape()) > - ShapeUtil::ByteSizeOf(input->shape())) { - return FailedPrecondition("narrowing conversion"); - } - return OkStatus(); + return input_size - output_size; } // Handles numbers of dimensions of a target HLO instruction @@ -142,6 +166,13 @@ class DimensionOrder { int64_t target_dim_number; int subdim_number; int64_t size; + bool operator==(const DimDescription& other) const { + return target_dim_number == other.target_dim_number && + subdim_number == other.subdim_number && size == other.size; + } + std::string ToString() const { + return absl::StrCat(target_dim_number, ":", subdim_number, ":", size); + } }; // Sequence describing all dimensions of HLO's output shape // in layout minor-to-major (physical) order. @@ -171,34 +202,35 @@ class DimensionOrder { // Transforms the DimensionOrder so that from a description of the output // of `hlo` it becomes a description of the input of `hlo`. - Status HandleInstruction(const HloInstruction* hlo) { + FusionDecision HandleInstruction(const HloInstruction* hlo) { VLOG(7) << hlo->ToString(); - if (hlo->opcode() == HloOpcode::kParameter) { - return OkStatus(); + if (hlo->opcode() == HloOpcode::kParameter || + hlo->opcode() == HloOpcode::kConstant) { + return FusionDecision{}; } else if (hlo->opcode() == HloOpcode::kTranspose || hlo->opcode() == HloOpcode::kCopy) { return HandleCopyOrTranspose(hlo); } else if (hlo->operand_count() > 0 && IsTritonSupportedElementwise( hlo->opcode(), hlo->operand(0)->shape().element_type())) { - return OkStatus(); + return FusionDecision{}; } else if (hlo->opcode() == HloOpcode::kBitcast) { return HandleBitcast(hlo); } else if (hlo->opcode() == HloOpcode::kReshape) { if (!ShapeUtil::ReshapeIsBitcast(hlo->operand(0)->shape(), hlo->shape())) { - return Unimplemented("Non-bitcast reshape."); + return "Non-bitcast reshape."; } return HandleBitcast(hlo); } else if (hlo_query::IsScalarConstant(hlo) || hlo_query::IsBroadcastOfScalarConstant(*hlo)) { // Dimension order collapses on a scalar, for simplicity leave it equal // to the output one for now. - return OkStatus(); + return FusionDecision{}; } else { - return Unimplemented("Instruction: %s", hlo->ToString()); + return "Unimplemented instruction."; } - return OkStatus(); + return FusionDecision{}; } // Get the raw data of the dimension order. @@ -210,20 +242,32 @@ class DimensionOrder { return splittable_dimension_index_; } + // Tells that two dimension orders describe the same tensor physical layout. + bool IsPhysicallyEquivalent(const DimensionOrder& other) const; + + std::string ToString() const { + return absl::StrJoin(dim_order_, "-", + [](std::string* out, const DimDescription& d) { + absl::StrAppend(out, d.ToString()); + }); + } + private: // See HandleInstruction() for the general description of Handle*(). - Status HandleBitcast(const HloInstruction* hlo); - Status HandleCopyOrTranspose(const HloInstruction* hlo); + FusionDecision HandleBitcast(const HloInstruction* hlo); + FusionDecision HandleCopyOrTranspose(const HloInstruction* hlo); DimOrderVector dim_order_; - int64_t splittable_dimension_index_; + const int64_t splittable_dimension_index_; }; -DotFusionAnalysis::TensorIterationSpec DimensionOrderToTensorIterationSpec( +using DimIterationSpec = TensorIterationSpec::DimIterationSpec; + +TensorIterationSpec DimensionOrderToTensorIterationSpec( const DimensionOrder& order) { const DimensionOrder::DimOrderVector& dim_order_vector = order.GetDimOrderVector(); - DotFusionAnalysis::TensorIterationSpec tensor_spec; + TensorIterationSpec tensor_spec; int64_t accumulated_stride = 1; for (int dim_order_index = 0; dim_order_index < dim_order_vector.size(); ++dim_order_index) { @@ -236,8 +280,7 @@ DotFusionAnalysis::TensorIterationSpec DimensionOrderToTensorIterationSpec( continue; } - DotFusionAnalysis::DimIterationSpec& dim_spec = - tensor_spec[dim.target_dim_number]; + DimIterationSpec& dim_spec = tensor_spec[dim.target_dim_number]; if (dim_order_index > 0 && dim_order_vector[dim_order_index - 1].target_dim_number == dim.target_dim_number) { @@ -257,7 +300,7 @@ DotFusionAnalysis::TensorIterationSpec DimensionOrderToTensorIterationSpec( accumulated_stride *= dim.size; } // Create all absent dimensions as degenerate ones to simplify later queries. - for (DotFusionAnalysis::DimIterationSpec& dim_spec : tensor_spec) { + for (DimIterationSpec& dim_spec : tensor_spec) { if (dim_spec.empty()) { dim_spec.push_back({/*stride=*/0, /*count=*/1, /*subfragments=*/{1}}); } @@ -265,6 +308,11 @@ DotFusionAnalysis::TensorIterationSpec DimensionOrderToTensorIterationSpec( return tensor_spec; } +bool DimensionOrder::IsPhysicallyEquivalent(const DimensionOrder& other) const { + return DimensionOrderToTensorIterationSpec(*this) == + DimensionOrderToTensorIterationSpec(other); +} + DimensionOrder DimensionOrder::FromDotOperand(const HloInstruction& dot, const int operand_number, const int64_t split_k) { @@ -287,7 +335,7 @@ DimensionOrder DimensionOrder::FromDotOutput(const HloInstruction& dot) { return DimensionOrder(&dot); } -Status DimensionOrder::HandleBitcast(const HloInstruction* hlo) { +FusionDecision DimensionOrder::HandleBitcast(const HloInstruction* hlo) { const Shape& operand_shape = hlo->operand(0)->shape(); DimOrderVector operand_dim_order; operand_dim_order.reserve(dim_order_.size()); @@ -301,7 +349,7 @@ Status DimensionOrder::HandleBitcast(const HloInstruction* hlo) { ++out_dim) { if (operand_remaining_size >= out_dim->size) { if (operand_remaining_size % out_dim->size) { - return Unimplemented("Unsupported bitcast: %s", hlo->ToString()); + return "Unsupported bitcast"; } // Output dimension fragment completely fits into the operand one: // just copy it as is. @@ -319,7 +367,7 @@ Status DimensionOrder::HandleBitcast(const HloInstruction* hlo) { // If there is a remaining fragment of a previous operand dimension // assign it first. if (out_remaining_size % operand_remaining_size) { - return Unimplemented("Unsupported bitcast: %s", hlo->ToString()); + return "Unsupported bitcast"; } operand_dim_order.push_back( {out_dim->target_dim_number, subdim_index, operand_remaining_size}); @@ -337,7 +385,7 @@ Status DimensionOrder::HandleBitcast(const HloInstruction* hlo) { // assign the remainder of the output and carry over the remainder // of the operand. if (operand_dim_size % out_remaining_size) { - return Unimplemented("Unsupported bitcast: %s", hlo->ToString()); + return "Unsupported bitcast"; } operand_remaining_size = operand_dim_size / out_remaining_size; new_fragment_size = out_remaining_size; @@ -358,7 +406,7 @@ Status DimensionOrder::HandleBitcast(const HloInstruction* hlo) { int subdim_index = operand_dim_order.back().subdim_number + 1; while (operand_dim_iter != operand_shape.layout().minor_to_major().cend()) { if (operand_shape.dimensions(*operand_dim_iter) != 1) { - return Unimplemented("Unsupported bitcast: %s", hlo->ToString()); + return "Unsupported bitcast"; } operand_dim_order.push_back( {operand_dim_order.back().target_dim_number, subdim_index, 1}); @@ -367,10 +415,11 @@ Status DimensionOrder::HandleBitcast(const HloInstruction* hlo) { } dim_order_ = operand_dim_order; - return OkStatus(); + return FusionDecision{}; } -Status DimensionOrder::HandleCopyOrTranspose(const HloInstruction* hlo) { +FusionDecision DimensionOrder::HandleCopyOrTranspose( + const HloInstruction* hlo) { // Every HLO dimension can correspond to a group of subdimensions in // dim_order_. For the easier handling of permutations: group dim_order_ by // dimension, apply permutations, then finally remove the grouping. @@ -419,25 +468,25 @@ Status DimensionOrder::HandleCopyOrTranspose(const HloInstruction* hlo) { dim_order_.push_back(subdim); } } - return OkStatus(); + return FusionDecision{}; } // Tells if the dimension order is supported by the triton GEMM emitter. // Only the dimension indicated by SplittableDimensionIndex() can be split // physically once by other dimensions. Other ones can be only split logically. // All subdimensions within a dimension have to be ordered. -Status RequireTritonGemmSupportedDimOrder(const DimensionOrder& order) { - std::array subdim_counters = { +FusionDecision RequireTritonGemmSupportedDimOrder(const DimensionOrder& order) { + std::array subdim_counters = { -1, -1, -1, -1}; - std::array split_counters = { + std::array split_counters = { -1, -1, -1, -1}; const DimensionOrder::DimOrderVector& dim_order_vector = order.GetDimOrderVector(); + VLOG(8) << order.ToString(); for (int i = 0; i < dim_order_vector.size(); i++) { const auto [dim_number, subdim_number, size] = dim_order_vector[i]; - VLOG(8) << dim_number << "\t" << subdim_number << "\t" << size; if (subdim_counters[dim_number] != subdim_number - 1) { - return Unimplemented("Transpose within a dimension."); + return "Transpose within a dimension."; } ++subdim_counters[dim_number]; if (size == 1) { @@ -447,31 +496,179 @@ Status RequireTritonGemmSupportedDimOrder(const DimensionOrder& order) { ++split_counters[dim_number]; if (dim_number == order.SplittableDimensionIndex()) { if (split_counters[dim_number] > 1) { - return Unimplemented("2nd split of a splittable dimension."); + return "2nd split of a splittable dimension."; } } else if (split_counters[dim_number] > 0) { - return Unimplemented("Split of a non-splittable dimension."); + return "Split of a non-splittable dimension."; } } } - return OkStatus(); + return FusionDecision{}; } -// Transforms dim_order describing the output of `hlo` into a +// Tells if an instruction has no input into which it could be fused. +// More cases should be added here. +bool CanNotBeFusedIntoAProducer(const HloInstruction& hlo) { + return hlo_query::AllOperandsAreParametersOrConstants(hlo); +} + +// Tells that fusing an instruction is efficient. +bool IsInputWorthFusing(const HloInstruction& hlo) { + return hlo_query::AllOperandsAreParametersOrConstants(hlo) || + InputMinusOutputBytes(hlo) < kIoToleranceBytes; +} + +// Checks if the instruction is possible and profitable to fuse. +// If so tries to transform dim_order describing output of `hlo` into a // description of its input if it is supported by the triton GEMM emitter. -Status CanFuse(const HloInstruction* hlo, DimensionOrder& dim_order, - const GpuVersion gpu_version) { - if (hlo->opcode() == HloOpcode::kConvert) { - return RequireTritonFusibleConvert(hlo, gpu_version); - } else if (hlo->IsElementwise() && hlo->opcode() != HloOpcode::kCopy) { - // Temporarily forbid fusing elementwise operations - // other than copy and convert. - return Unimplemented("Unsupported elementwise operation"); +FusionDecision CanFuse(const HloInstruction& hlo, DimensionOrder& dim_order, + const GpuVersion gpu_version) { + if (hlo.opcode() == HloOpcode::kTuple || + hlo.opcode() == HloOpcode::kGetTupleElement) { + return "Unsupported instruction."; + } + for (const HloInstruction* operand : hlo.operands()) { + if (!IsSupportedDataType(operand->shape().element_type(), gpu_version)) { + return "Unsupported input data type."; + } + } + if (!IsSupportedDataType(hlo.shape().element_type(), gpu_version)) { + return "Unsupported output data type."; + } + if (hlo.IsConstant()) { + return "Not fusing a constant."; + } + if (hlo.opcode() == HloOpcode::kBroadcast) { + return "Not fusing a broadcast."; + } + if (!CanNotBeFusedIntoAProducer(hlo) && !IsInputWorthFusing(hlo)) { + return "Not obviously profitable to fuse as input."; + } + if (FusionDecision decision = dim_order.HandleInstruction(&hlo); !decision) { + return decision; } - TF_RETURN_IF_ERROR(dim_order.HandleInstruction(hlo)); return RequireTritonGemmSupportedDimOrder(dim_order); } +// Clone an instruction into the fusion. +void Fuse(HloInstruction& hlo, + absl::flat_hash_map& + old_to_new_mapping, + std::vector& call_operands, + HloComputation::Builder& builder) { + if (old_to_new_mapping.contains(&hlo)) { + return; + } + VLOG(3) << "Fusing " << hlo.ToString(); + auto get_or_add_parameter = [&](HloInstruction& instr) { + if (auto it = old_to_new_mapping.find(&instr); + it != old_to_new_mapping.end()) { + return it->second; + } + call_operands.push_back(&instr); + return old_to_new_mapping + .insert({&instr, + builder.AddInstruction(HloInstruction::CreateParameter( + call_operands.size() - 1, instr.shape(), + absl::StrCat("parameter_", call_operands.size() - 1)))}) + .first->second; + }; + if (hlo.opcode() == HloOpcode::kParameter || + hlo.opcode() == HloOpcode::kGetTupleElement) { + get_or_add_parameter(hlo); + } else { + std::vector hlo_new_operands; + for (HloInstruction* operand : hlo.operands()) { + hlo_new_operands.push_back(get_or_add_parameter(*operand)); + } + old_to_new_mapping[&hlo] = builder.AddInstruction( + hlo.CloneWithNewOperands(hlo.shape(), hlo_new_operands)); + } +} + +// Tells how many new parameters does a fusion gain by fusing the operation as +// an input. +int64_t NumAddedParameters(const HloInstruction& hlo) { + // Non-scalar constant is equivalent to a parameter: one input, one output. + if (hlo.opcode() == HloOpcode::kConstant && + !ShapeUtil::IsScalar(hlo.shape())) { + return 0; + } + // All other instructions add all own inputs and remove own single output. + return hlo.operand_count() - 1; +} + +// Fuse an instruction with all its fusible inputs. +// If an input is not fusible stop there and make a parameter of the new +// fusion, otherwise put it onto stack and check its own inputs first. +void FuseWithInputsRecursively( + HloInstruction* root, DimensionOrder root_dim_order, + // Dimension orders describing inputs of corresponding instructions. + absl::flat_hash_map& dim_orders, + const GpuVersion gpu_version, + absl::flat_hash_map& + old_to_new_mapping, + std::vector& call_operands, + HloComputation::Builder& builder) { + absl::flat_hash_set visited; + std::stack to_fuse; + // Instructions at the edge 'to_fuse' that can either get fused too or + // become parameters of the fusion. Used to track the number of parameters + // of the fusion. + absl::flat_hash_set inputs; + // Currently only one physically unique dim order per scope is supported. + // Let it change while the scope has one input; afterwards require all + // of them to be physically compatible. + const HloInstruction* reference_dim_order_hlo = nullptr; + if (CanFuse(*root, root_dim_order, gpu_version)) { + to_fuse.push(root); + inputs.insert(root->operands().begin(), root->operands().end()); + // root_dim_order went through output -> input transformation here. + CHECK(dim_orders.insert({root, root_dim_order}).second) << root->ToString(); + } + visited.insert(root); + while (!to_fuse.empty()) { + bool top_is_ready_to_fuse = true; + HloInstruction* hlo = to_fuse.top(); + if (reference_dim_order_hlo == nullptr && hlo->operand_count() > 1) { + reference_dim_order_hlo = hlo; + } + for (HloInstruction* operand : hlo->mutable_operands()) { + if (visited.insert(operand).second) { + // Stop adding new parameters. + if (inputs.size() >= DotFusionAnalysis::kMaxParameterPerScope && + NumAddedParameters(*operand) > 0) { + continue; + } + // Operand's output is described by its consumer's input. + DimensionOrder operand_dim_order(dim_orders.at(hlo)); + // CanFuse() makes output -> input transformation of + // operand_dim_order if succeeds. + if (CanFuse(*operand, operand_dim_order, gpu_version)) { + if (reference_dim_order_hlo != nullptr && + !operand_dim_order.IsPhysicallyEquivalent( + dim_orders.at(reference_dim_order_hlo))) { + continue; + } + to_fuse.push(operand); + if (operand->opcode() != HloOpcode::kParameter) { + inputs.erase(operand); + } + inputs.insert(operand->operands().begin(), operand->operands().end()); + // Save the dimension order description of operand's input. + CHECK(dim_orders.insert({operand, operand_dim_order}).second) + << operand->ToString(); + top_is_ready_to_fuse = false; + } + } + } + if (top_is_ready_to_fuse) { + Fuse(*hlo, old_to_new_mapping, call_operands, builder); + to_fuse.pop(); + } + } +} + // Extracts into fused computations parts of HLO graph including dot() // operations that can target the triton GEMM emitter. class GemmRewriterTritonVisitor : public DfsHloRewriteVisitor { @@ -483,8 +680,9 @@ class GemmRewriterTritonVisitor : public DfsHloRewriteVisitor { // and replaces the original dot() with a call to the computation. Status HandleDot(HloInstruction* dot) override { VLOG(5) << dot->ToString(); - - if (!CanTritonHandleGEMM(*dot, gpu_version_)) { + FusionDecision can_handle = CanTritonHandleGEMM(*dot, gpu_version_); + if (!can_handle) { + VLOG(3) << can_handle.Explain(); return OkStatus(); } @@ -503,72 +701,28 @@ class GemmRewriterTritonVisitor : public DfsHloRewriteVisitor { std::string suggested_name = absl::StrCat("triton_gemm_", dot->name()); HloComputation::Builder builder( absl::StrCat(suggested_name, "_computation")); + std::vector call_operands; // Original instruction -> fused one. absl::flat_hash_map old_to_new_mapping; - absl::flat_hash_set visited; - std::vector call_operands; - // Traverse and fuse dot() inputs bottom-up starting from direct operands. - // If an input is not fusible stop there and make it a parameter of the new - // fusion, otherwise put it onto stack and check its own inputs first. - std::stack to_fuse; - // Dimension orders describing inputs of corresponding instructions. - absl::flat_hash_map dim_orders; - to_fuse.push(dot); - while (!to_fuse.empty()) { - bool top_is_ready_to_fuse = true; - HloInstruction* hlo = to_fuse.top(); - for (HloInstruction* operand : hlo->mutable_operands()) { - if (visited.insert(operand).second) { - DimensionOrder operand_dim_order = [&] { - // Direct dot inputs are described by default dimension orders. - if (operand == dot->operand(0)) { - return DimensionOrder::FromDotOperand(*dot, 0); - } else if (operand == dot->operand(1)) { - return DimensionOrder::FromDotOperand(*dot, 1); - } - // Otherwise operand's output is described by its consumer's input. - return DimensionOrder(dim_orders.at(hlo)); - }(); - // CanFuse() makes output -> input transformation of - // operand_dim_order if succeeds. - if (CanFuse(operand, operand_dim_order, gpu_version_).ok()) { - VLOG(3) << "Fusing " << operand->ToString(); - to_fuse.push(operand); - // Save the dimension order description of operand's input. - dim_orders.insert({operand, operand_dim_order}); - top_is_ready_to_fuse = false; - } - } - } - if (top_is_ready_to_fuse) { - if (hlo->opcode() == HloOpcode::kParameter || - hlo->opcode() == HloOpcode::kGetTupleElement) { - old_to_new_mapping[hlo] = - builder.AddInstruction(HloInstruction::CreateParameter( - call_operands.size(), hlo->shape(), - absl::StrCat("parameter_", call_operands.size()))); - call_operands.push_back(hlo); - } else { - std::vector hlo_new_operands; - for (HloInstruction* operand : hlo->operands()) { - const auto iter = old_to_new_mapping.find(operand); - if (iter != old_to_new_mapping.end()) { - hlo_new_operands.push_back(iter->second); - } else { - hlo_new_operands.push_back( - builder.AddInstruction(HloInstruction::CreateParameter( - call_operands.size(), operand->shape(), - absl::StrCat("parameter_", call_operands.size())))); - call_operands.push_back(operand); - } - } - old_to_new_mapping[hlo] = builder.AddInstruction( - hlo->CloneWithNewOperands(hlo->shape(), hlo_new_operands)); - } - to_fuse.pop(); - } - } + + auto fuse_inputs = [&](int operand_number) { + absl::flat_hash_map dim_orders; + int operand_count_before = call_operands.size(); + // Direct dot inputs have well defined dimension orders. + FuseWithInputsRecursively( + dot->mutable_operand(operand_number), + DimensionOrder::FromDotOperand(*dot, operand_number), dim_orders, + gpu_version_, old_to_new_mapping, call_operands, builder); + return call_operands.size() - operand_count_before; + }; + // Separate traversal from LHS and RHS inputs of the dot: they use + // differently shaped tiles but may go through same HLO graph nodes. + TF_RET_CHECK(fuse_inputs(0) <= DotFusionAnalysis::kMaxParameterPerScope); + TF_RET_CHECK(fuse_inputs(1) <= DotFusionAnalysis::kMaxParameterPerScope); + + Fuse(*dot, old_to_new_mapping, call_operands, builder); + HloComputation* computation = dot->GetModule()->AddComputationAndUnifyNamesAndIds(builder.Build(), /*is_entry=*/false); @@ -592,7 +746,7 @@ class GemmRewriterTritonVisitor : public DfsHloRewriteVisitor { } else { TF_RETURN_IF_ERROR(ReplaceInstruction(dot, dot_fusion)); } - VLOG(5) << computation->ToString(); + XLA_VLOG_LINES(5, computation->ToString()); return OkStatus(); } @@ -643,7 +797,7 @@ StatusOr MakeSplitKOperand( for (const HloInstruction* param : analysis.ScopeParameters(scope)) { // If an operand of dot does not read any parameters its K dimension // does not need analysis for fragmentation. - const DotFusionAnalysis::DimIterationSpec* spec = + const DimIterationSpec* spec = analysis.IterSpec(scope, param, contracting_dim_idx); // Split contracting dimension is not implemented yet. CHECK_EQ(spec->size(), 1); @@ -885,8 +1039,8 @@ DotFusionAnalysis::DotFusionAnalysis(const HloComputation* dot_computation, absl::flat_hash_map dim_orders; DimensionOrder dot_operand_dim_order = DimensionOrder::FromDotOperand(*dot, operand_number, split_k); - TF_CHECK_OK(dot_operand_dim_order.HandleInstruction(dot_operand)); - TF_CHECK_OK(RequireTritonGemmSupportedDimOrder(dot_operand_dim_order)) + CHECK(dot_operand_dim_order.HandleInstruction(dot_operand)); + CHECK(RequireTritonGemmSupportedDimOrder(dot_operand_dim_order)) << dot_computation->ToString(); dim_orders.insert({dot_operand, dot_operand_dim_order}); visited.insert(dot_operand); @@ -907,14 +1061,18 @@ DotFusionAnalysis::DotFusionAnalysis(const HloComputation* dot_computation, {hlo_operand, DimensionOrder(dim_orders.at(hlo))}); CHECK(inserted); DimensionOrder& hlo_operand_dim_order = it->second; - TF_CHECK_OK(hlo_operand_dim_order.HandleInstruction(hlo_operand)); - TF_CHECK_OK(RequireTritonGemmSupportedDimOrder(hlo_operand_dim_order)) + CHECK(hlo_operand_dim_order.HandleInstruction(hlo_operand)); + CHECK(RequireTritonGemmSupportedDimOrder(hlo_operand_dim_order)) << " " << dot_computation->ToString(); to_process.push(hlo_operand); } } + // For now all parameters of one scope have to use the same tiling. for (const HloInstruction* parameter : parameters_[scope]) { + CHECK(dim_orders.at(parameter).IsPhysicallyEquivalent( + dim_orders.at(*parameters_[scope].cbegin()))) + << dot_computation->ToString(); iter_specs_[scope][parameter] = DimensionOrderToTensorIterationSpec(dim_orders.at(parameter)); } @@ -926,22 +1084,22 @@ DotFusionAnalysis::DotFusionAnalysis(const HloComputation* dot_computation, .second); } -const DotFusionAnalysis::DimIterationSpec* DotFusionAnalysis::IterSpec( +const DimIterationSpec* DotFusionAnalysis::IterSpec( const DotFusionAnalysis::Scope scope, const HloInstruction* hlo, const int dimension) const { auto ret = iter_specs_.at(scope).find(hlo); if (ret != iter_specs_.at(scope).end()) { - return &ret->second.at(dimension); + return &ret->second[dimension]; } return nullptr; } -bool CanTritonHandleGEMM(const HloInstruction& dot, - const GpuVersion gpu_version) { +FusionDecision CanTritonHandleGEMM(const HloInstruction& dot, + const GpuVersion gpu_version) { if (dot.opcode() != HloOpcode::kDot || absl::c_any_of(dot.precision_config().operand_precision(), [](int x) { return x != PrecisionConfig::DEFAULT; })) { - return false; + return "Non-default precision."; } auto supported_output_type = [&](const PrimitiveType t) { @@ -961,21 +1119,21 @@ bool CanTritonHandleGEMM(const HloInstruction& dot, // TODO(b/266862493): Support more output types. if (!supported_output_type(dot.shape().element_type())) { - return false; + return "Unsupported output data type."; } if (!IsSupportedDataType(dot.operand(0)->shape().element_type(), gpu_version) || !IsSupportedDataType(dot.operand(1)->shape().element_type(), gpu_version)) { - return false; + return "Unsupported input data type."; } const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers(); // TODO(b/269580541): support multiple batch dimensions. if (dim_numbers.lhs_batch_dimensions().size() > 1) { - return false; + return "Multiple batch dimensions."; } // Cases where lhs or rhs have no non-contracting dims are not handled. @@ -985,10 +1143,10 @@ bool CanTritonHandleGEMM(const HloInstruction& dot, dim_numbers.rhs_batch_dimensions().size() + dim_numbers.rhs_contracting_dimensions().size() == dot.operand(1)->shape().rank()) { - return false; + return "No non-contracting dimensions."; } - return true; + return FusionDecision{}; } bool ShouldTritonHandleGEMM(const HloInstruction& dot, @@ -1008,7 +1166,7 @@ bool ShouldTritonHandleGEMM(const HloInstruction& dot, while (!queue.empty()) { const HloInstruction* current = queue.front(); queue.pop(); - if (!CanFuse(current, dim_order, gpu_version).ok()) { + if (!CanFuse(*current, dim_order, gpu_version)) { continue; } // Stop as soon as a profitable operation is fused. diff --git a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.h b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.h index 715c79d9114659..0afc939b43ede2 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.h +++ b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.h @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/hlo/ir/hlo_opcode.h" #include "tensorflow/compiler/xla/service/gpu/gpu_types.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/service/instruction_fusion.h" namespace xla { namespace gpu { @@ -52,13 +53,13 @@ Status MakeDotSplitKBatch(HloInstruction* dot_fusion, const AutotuneResult::TritonGemmKey& tiling); // Filters GEMMs which can be handled using Triton. -bool CanTritonHandleGEMM(const HloInstruction&, GpuVersion gpu_version); +FusionDecision CanTritonHandleGEMM(const HloInstruction&, + GpuVersion gpu_version); // Filters GEMMs which are better to handle using Triton. bool ShouldTritonHandleGEMM(const HloInstruction&, GpuVersion gpu_version); -// Analysis of iteration of HLO shapes within a fusion around dot(). -class DotFusionAnalysis { +class TensorIterationSpec { public: // Description of basic iteration: `count` elements separated by `stride`. struct IterationSpecFragment { @@ -68,16 +69,42 @@ class DotFusionAnalysis { // of several HLO dimensions. Product of subfragments equals `count`. std::vector subfragments; }; - // Description of complex iteration over a sequence of several strides. // Describes a logically contiguous dimension of a tensor physically // separated into multiple fragments by other dimensions. using DimIterationSpec = std::vector; // At most: contracting, non-contracting, split-K, another batch. - static const int kMaxDimsPerTensor = 4; - using TensorIterationSpec = std::array; + static constexpr int kMaxDimsPerTensor = 4; + using StorageType = std::array; + + const DimIterationSpec& operator[](int dimension) const { + return dim_iteration_specs_[dimension]; + } + + DimIterationSpec& operator[](int dimension) { + return dim_iteration_specs_[dimension]; + } + + // Compares physical layouts of tensors ignoring subfragments of dimensions. + bool operator==(const TensorIterationSpec& other) const; + + StorageType::iterator begin() { return dim_iteration_specs_.begin(); } + StorageType::iterator end() { return dim_iteration_specs_.end(); } + StorageType::const_iterator cbegin() const { + return dim_iteration_specs_.cbegin(); + } + StorageType::const_iterator cend() const { + return dim_iteration_specs_.cend(); + } + + private: + StorageType dim_iteration_specs_; +}; +// Analysis of iteration of HLO shapes within a fusion around dot(). +class DotFusionAnalysis { + public: // Execute analysis of dot fusion computation. // split_k indicates whether this operation was converted to the split-K // form and tells the analysis how to interpret the batch dimensions. @@ -88,9 +115,15 @@ class DotFusionAnalysis { // defined by left operand, right operand and output. enum class Scope { LHS = 0, RHS = 1, OUTPUT = 2 }; + // Every parameter requires a separate piece of shared memory for asynchronous + // loads. Multiple parameters are approximately equivalent to multiple + // pipeline stages. + static constexpr int kMaxParameterPerScope = 4; + // Scope -> HLO -> dot dimension number -> iteration spec at the HLO's output. - const DimIterationSpec* IterSpec(Scope scope, const HloInstruction*, - int dimension) const; + const TensorIterationSpec::DimIterationSpec* IterSpec(Scope scope, + const HloInstruction*, + int dimension) const; // Parameter HLO instructions used in a scope of `dot`. const absl::flat_hash_set& ScopeParameters( const Scope scope) const { diff --git a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton_test.cc b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton_test.cc index d02faa5b3abdc9..b154efabe1ef0a 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton_test.cc @@ -94,7 +94,7 @@ ENTRY e { GmockMatch(m::Fusion(m::Parameter(), m::Parameter()))); } -TEST_F(GemmRewriterTritonTest, DoNotFuseConstant) { +TEST_F(GemmRewriterTritonTest, DoNotFuseConstants) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(R"( HloModule m @@ -102,14 +102,14 @@ HloModule m ENTRY e { p0 = s8[60,5] parameter(0) c0 = f16[60,5] convert(p0) - cst1 = f16[600] constant({...}) - r1 = f16[5,120] reshape(cst1) + cst1 = f16[] constant(1234) + r1 = f16[5,120] broadcast(cst1) ROOT d = f16[60,120] dot(c0, r1), lhs_contracting_dims={1}, rhs_contracting_dims={0} })")); EXPECT_TRUE(GemmRewriterTriton(gpu_version_).Run(module.get()).value()); EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch(m::Fusion(m::Constant(), m::Parameter()))); + GmockMatch(m::Fusion(m::Parameter(), m::Broadcast()))); } using TritonDotAnalysisTest = HloTestBase; @@ -793,6 +793,145 @@ ENTRY e { EXPECT_TRUE(GemmRewriterTriton(cc).Run(module.get()).value()); } +TEST_F(GemmRewriterTritonTest, DoNotFuseIncompatibleDimOrders) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +HloModule m + +ENTRY e { + p0 = f16[5,3] parameter(0) + p1 = f16[5,7] parameter(1) + p2 = f16[7,5] parameter(2) + t = f16[5,7] transpose(p2), dimensions={1,0} + a = f16[5,7] add(t, p1) + ROOT d = f16[3,7] dot(p0, a), + lhs_contracting_dims={0}, rhs_contracting_dims={0} +})")); + + EXPECT_TRUE(GemmRewriterTriton(gpu_version_).Run(module.get()).value()); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch(m::Fusion(m::Parameter(), m::Parameter(), m::Transpose()))); +} + +TEST_F(GemmRewriterTritonTest, DoNotFuseTooManyParameters) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +ENTRY e { + tmp_0 = f32[] constant(1) + tmp_1 = f32[3,49]{1,0} broadcast(tmp_0), dimensions={} + tmp_2 = f32[3,49]{1,0} parameter(6) + tmp_3 = f32[] constant(0) + tmp_4 = f32[3,49]{1,0} broadcast(tmp_3), dimensions={} + tmp_5 = pred[3,49]{1,0} compare(tmp_2, tmp_4), direction=GT + tmp_6 = f32[3,49]{1,0} convert(tmp_5) + tmp_7 = f32[3,49]{1,0} subtract(tmp_1, tmp_6) + tmp_8 = s32[] parameter(13) + tmp_9 = f32[] convert(tmp_8) + tmp_10 = f32[] maximum(tmp_9, tmp_0) + tmp_11 = f32[] divide(tmp_3, tmp_10) + tmp_12 = f32[3,49]{1,0} broadcast(tmp_11), dimensions={} + tmp_13 = pred[3,49]{1,0} parameter(7) + tmp_14 = pred[3,49]{1,0} parameter(10) + tmp_15 = pred[3,49]{1,0} and(tmp_13, tmp_14) + tmp_16 = f32[3,49]{1,0} convert(tmp_15) + tmp_17 = f32[3,49]{1,0} multiply(tmp_12, tmp_16) + tmp_18 = f32[3,49]{1,0} negate(tmp_17) + tmp_19 = f32[3,49]{1,0} multiply(tmp_7, tmp_18) + tmp_20 = f32[3,49]{1,0} parameter(19) + tmp_21 = f32[3,49]{1,0} subtract(tmp_1, tmp_20) + tmp_22 = f32[3,49]{1,0} divide(tmp_19, tmp_21) + tmp_23 = f32[3,49]{1,0} negate(tmp_22) + tmp_24 = f32[3,49]{1,0} negate(tmp_6) + tmp_25 = f32[3,49]{1,0} multiply(tmp_24, tmp_17) + tmp_26 = f32[3,49]{1,0} divide(tmp_25, tmp_20) + tmp_27 = f32[3,49]{1,0} add(tmp_23, tmp_26) + tmp_28 = f32[3,49]{1,0} parameter(18) + tmp_29 = f32[3,49]{1,0} multiply(tmp_27, tmp_28) + tmp_30 = f32[3,49]{1,0} parameter(17) + tmp_31 = f32[3,49]{1,0} multiply(tmp_29, tmp_30) + tmp_32 = f32[3,49]{1,0} parameter(16) + tmp_33 = f32[3,49]{1,0} multiply(tmp_31, tmp_32) + tmp_34 = f32[3,49]{1,0} parameter(15) + tmp_35 = f32[3,49]{1,0} add(tmp_33, tmp_34) + tmp_36 = f32[3,49]{1,0} parameter(14) + tmp_37 = f32[3,49]{1,0} add(tmp_35, tmp_36) + tmp_38 = f32[1,1]{1,0} constant({ {0} }) + tmp_39 = f32[1,1]{1,0} broadcast(tmp_38), dimensions={0,1} + tmp_40 = f32[] reshape(tmp_39) + tmp_41 = f32[3,32]{1,0} broadcast(tmp_40), dimensions={} + tmp_42 = u32[48]{0} parameter(11) + tmp_43 = u32[48]{0} parameter(5) + tmp_44 = u32[96]{0} concatenate(tmp_42, tmp_43), dimensions={0} + tmp_45 = u32[3,32]{1,0} reshape(tmp_44) + tmp_46 = u32[96]{0} reshape(tmp_45) + tmp_47 = u32[] constant(1) + tmp_48 = u32[3,32]{1,0} broadcast(tmp_47), dimensions={} + tmp_49 = u32[96]{0} reshape(tmp_48) + tmp_50 = u32[96]{0} shift-right-logical(tmp_46, tmp_49) + tmp_51 = u32[3,32]{1,0} reshape(tmp_50) + tmp_52 = u32[3,32]{1,0} or(tmp_51, tmp_48) + tmp_53 = f32[3,32]{1,0} bitcast-convert(tmp_52) + tmp_54 = f32[3,32]{1,0} broadcast(tmp_0), dimensions={} + tmp_55 = f32[3,32]{1,0} subtract(tmp_53, tmp_54) + tmp_56 = f32[1,1]{1,0} constant({ {1} }) + tmp_57 = f32[1,1]{1,0} broadcast(tmp_56), dimensions={0,1} + tmp_58 = f32[] reshape(tmp_57) + tmp_59 = f32[3,32]{1,0} broadcast(tmp_58), dimensions={} + tmp_60 = f32[3,32]{1,0} multiply(tmp_55, tmp_59) + tmp_61 = f32[3,32]{1,0} add(tmp_60, tmp_41) + tmp_62 = f32[3,32]{1,0} maximum(tmp_41, tmp_61) + tmp_63 = f32[3,32]{1,0} broadcast(tmp_3), dimensions={} + tmp_64 = pred[3,32]{1,0} compare(tmp_62, tmp_63), direction=LT + tmp_65 = f32[3,32]{1,0} convert(tmp_64) + tmp_66 = f32[3,49]{1,0} parameter(9) + tmp_67 = f32[49]{0} parameter(4) + tmp_68 = f32[3,49]{1,0} broadcast(tmp_67), dimensions={1} + tmp_69 = f32[3,49]{1,0} add(tmp_66, tmp_68) + tmp_70 = f32[1,49]{1,0} parameter(12) + tmp_71 = f32[1,49]{1,0} broadcast(tmp_0), dimensions={} + tmp_72 = f32[1,49]{1,0} divide(tmp_70, tmp_71) + tmp_73 = f32[1,49]{1,0} broadcast(tmp_72), dimensions={0,1} + tmp_74 = f32[49]{0} reshape(tmp_73) + tmp_75 = f32[3,49]{1,0} broadcast(tmp_74), dimensions={1} + tmp_76 = f32[3,49]{1,0} subtract(tmp_69, tmp_75) + tmp_77 = f32[1,49]{1,0} parameter(3) + tmp_78 = f32[1,49]{1,0} parameter(8) + tmp_79 = f32[1,49]{1,0} divide(tmp_78, tmp_71) + tmp_80 = f32[1,49]{1,0} multiply(tmp_72, tmp_72) + tmp_81 = f32[1,49]{1,0} subtract(tmp_79, tmp_80) + tmp_82 = f32[1,49]{1,0} add(tmp_81, tmp_71) + tmp_83 = f32[1,49]{1,0} rsqrt(tmp_82) + tmp_84 = f32[1,49]{1,0} multiply(tmp_77, tmp_83) + tmp_85 = f32[1,49]{1,0} broadcast(tmp_84), dimensions={0,1} + tmp_86 = f32[49]{0} reshape(tmp_85) + tmp_87 = f32[3,49]{1,0} broadcast(tmp_86), dimensions={1} + tmp_88 = f32[3,49]{1,0} multiply(tmp_76, tmp_87) + tmp_89 = f32[1,49]{1,0} parameter(2) + tmp_90 = f32[1,49]{1,0} broadcast(tmp_89), dimensions={0,1} + tmp_91 = f32[49]{0} reshape(tmp_90) + tmp_92 = f32[3,49]{1,0} broadcast(tmp_91), dimensions={1} + tmp_93 = f32[3,49]{1,0} add(tmp_88, tmp_92) + tmp_94 = f32[49,32]{1,0} parameter(1) + tmp_95 = f32[3,32]{1,0} dot(tmp_93, tmp_94), lhs_contracting_dims={1}, rhs_contracting_dims={0} + tmp_96 = f32[32]{0} parameter(0) + tmp_97 = f32[3,32]{1,0} broadcast(tmp_96), dimensions={1} + tmp_98 = f32[3,32]{1,0} add(tmp_95, tmp_97) + tmp_99 = f32[3,32]{1,0} multiply(tmp_65, tmp_98) + tmp_100 = f32[3,32]{1,0} divide(tmp_99, tmp_63) + tmp_101 = f32[3,32]{1,0} maximum(tmp_100, tmp_63) + ROOT tmp_102 = f32[49,32]{1,0} dot(tmp_37, tmp_101), lhs_contracting_dims={0}, rhs_contracting_dims={0} +})")); + + EXPECT_TRUE(GemmRewriterTriton(gpu_version_).Run(module.get()).value()); + EXPECT_EQ(module->entry_computation()->root_instruction()->opcode(), + HloOpcode::kFusion); + EXPECT_EQ(module->entry_computation()->root_instruction()->fusion_kind(), + HloInstruction::FusionKind::kCustom); + EXPECT_LE(module->entry_computation()->root_instruction()->operand_count(), + DotFusionAnalysis::kMaxParameterPerScope * 2); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index f490f9b127e21a..b3944952ac68da 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -973,6 +973,29 @@ Status GpuCompiler::OptimizeHloPostLayoutAssignment( }); } + GpuFloatSupport bf16_support(BF16); + GpuFloatSupport f8e5m2_support(F8E5M2); + GpuFloatSupport f8e4m3fn_support(F8E4M3FN); + FloatSupport f8e4m3b11fnuz_support(F8E4M3B11FNUZ); + FloatSupport f8e5m2fnuz_support(F8E5M2FNUZ); + FloatSupport f8e4m3fnuz_support(F8E4M3FNUZ); + + auto add_float_normalization = [&](HloPassPipeline& pipeline) { + auto& sub_pipeline = + pipeline.AddPass("float_normalization"); + sub_pipeline.AddPass(&bf16_support); + sub_pipeline.AddPass(&f8e5m2_support); + sub_pipeline.AddPass(&f8e4m3fn_support); + sub_pipeline.AddPass(&f8e4m3b11fnuz_support); + sub_pipeline.AddPass(&f8e5m2fnuz_support); + sub_pipeline.AddPass(&f8e4m3fnuz_support); + // Remove `f32 -> bf16 -> f32` casts inserted by bf16 normalization. + if (debug_options.xla_gpu_simplify_all_fp_conversions()) { + sub_pipeline.AddPass(); + } + }; + add_float_normalization(pipeline); + // By default use an externally provided thread pool. tsl::thread::ThreadPool* thread_pool = options.thread_pool; std::optional overriding_thread_pool; @@ -994,18 +1017,8 @@ Status GpuCompiler::OptimizeHloPostLayoutAssignment( &pipeline, hlo_module, stream_exec, debug_options, options, gpu_target_config, autotune_results, thread_pool)); - GpuFloatSupport bf16_support(BF16); - pipeline.AddPass(&bf16_support); - GpuFloatSupport f8e5m2_support(F8E5M2); - pipeline.AddPass(&f8e5m2_support); - GpuFloatSupport f8e4m3fn_support(F8E4M3FN); - pipeline.AddPass(&f8e4m3fn_support); - FloatSupport f8e4m3b11fnuz_support(F8E4M3B11FNUZ); - pipeline.AddPass(&f8e4m3b11fnuz_support); - FloatSupport f8e5m2fnuz_support(F8E5M2FNUZ); - pipeline.AddPass(&f8e5m2fnuz_support); - FloatSupport f8e4m3fnuz_support(F8E4M3FNUZ); - pipeline.AddPass(&f8e4m3fnuz_support); + // The Triton autotuner can insert new reductions. + add_float_normalization(pipeline); // Remove `f32 -> bf16 -> f32` casts inserted by bf16 normalization. if (debug_options.xla_gpu_simplify_all_fp_conversions()) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_triton.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_triton.cc index 709f3e40b52c3f..7c9cd87953a848 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_triton.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_triton.cc @@ -792,7 +792,7 @@ StatusOr MatMulImpl( if (!analysis.ScopeParameters(DotFusionAnalysis::Scope::LHS).empty()) { const HloInstruction* lhs_param0 = *analysis.ScopeParameters(DotFusionAnalysis::Scope::LHS).begin(); - const DotFusionAnalysis::DimIterationSpec* lhs_nc_iter_spec = + const TensorIterationSpec::DimIterationSpec* lhs_nc_iter_spec = analysis.IterSpec(DotFusionAnalysis::Scope::LHS, lhs_param0, lhs_noncontracting_dim_idx); lhs_nc_split = lhs_nc_iter_spec->size() > 1; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_triton_test.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_triton_test.cc index fc4bb7204c1632..86d7209de81114 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_triton_test.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_triton_test.cc @@ -25,11 +25,14 @@ limitations under the License. #include "tensorflow/compiler/xla/autotuning.pb.h" #include "tensorflow/compiler/xla/error_spec.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_computation.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/gpu/gpu_device_info_for_tests.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h" #include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h" #include "tensorflow/compiler/xla/stream_executor/device_description.h" #include "tensorflow/compiler/xla/tests/verified_hlo_module.h" #include "tensorflow/tsl/lib/core/status_test_util.h" @@ -42,6 +45,8 @@ namespace xla { namespace gpu { namespace { +namespace m = ::xla::match; + class TritonGemmNoTF32Test : public GpuCodegenTest { public: void SetUp() override { @@ -715,6 +720,153 @@ ENTRY e { EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-6, /*arel=*/1e-6})); } +TEST_F(TritonGemmTest, BinaryOperationWithSmallInputsIsFused) { + const std::string kHloText = R"( +HloModule m + +ENTRY e { + p0 = s8[7,3] parameter(0) + p1 = f32[3,16] parameter(1) + p2 = f32[3,16] parameter(2) + e = f32[3,16] exponential(p1) + a = f32[3,16] add(e, p2) + c = f32[7,3] convert(p0) + ROOT d = f32[7,16] dot(c, a), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetOptimizedModule(kHloText)); + + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch(m::Fusion(m::Parameter(), m::Parameter(), m::Parameter()) + .WithFusionKind(HloInstruction::FusionKind::kCustom))); + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-1, /*arel=*/1e-3})); +} + +TEST_F(TritonGemmTest, BinaryOperationWithLargeInputsIsNotFused) { + const std::string kHloText = R"( +HloModule m + +ENTRY e { + p0 = f16[333,1000] parameter(0) + p1 = f32[1000,333] parameter(1) + p1n = f32[1000,333] negate(p1) + p2 = f32[1000,333] parameter(2) + p2n = f32[1000,333] negate(p2) + s = f32[1000,333] subtract(p1n, p2n) + c = f32[333,1000] convert(p0) + ROOT d = f32[1000,1000] dot(s, c), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"; + + MatchOptimizedHlo(kHloText, R"( +; CHECK: fused_computation +; CHECK: negate +; CHECK: negate +; CHECK: ROOT +; CHECK-SAME: subtract +; CHECK: ENTRY +; CHECK: kLoop +; CHECK: kCustom +)"); + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-1, /*arel=*/1e-3})); +} + +TEST_F(TritonGemmTest, BinaryOperationOnLargeParametersIsFused) { + const std::string kHloText = R"( +HloModule m + +ENTRY e { + p0 = f16[1000,111] parameter(0) + p1 = f32[111,10000] parameter(1) + p2 = f32[111,10000] parameter(2) + s = f32[111,10000] subtract(p1, p2) + c = f32[1000,111] convert(p0) + ROOT d = f32[10000,1000] dot(s, c), + lhs_contracting_dims={0}, rhs_contracting_dims={1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetOptimizedModule(kHloText)); + + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch(m::Fusion(m::Parameter(), m::Parameter(), m::Parameter()) + .WithFusionKind(HloInstruction::FusionKind::kCustom))); + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-1, /*arel=*/1e-3})); +} + +TEST_F(TritonGemmTest, LinkingLibdeviceTwiceWorks) { + const std::string kHloText = R"( +HloModule m + +ENTRY e { + p0 = s8[7,3] parameter(0) + c0 = f32[7,3] convert(p0) + e0 = f32[7,3] exponential(c0) + p1 = f32[3,16] parameter(1) + e1 = f32[3,16] exponential(p1) + d0 = f32[7,16] dot(c0, e1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + d1 = f32[7,16] dot(e0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT a = f32[7,16] add(d0, d1) +})"; + + MatchOptimizedHlo(kHloText, R"( +; CHECK: ENTRY +; CHECK-NEXT: parameter +; CHECK-NEXT: parameter +; CHECK-NEXT: kCustom +; CHECK-NEXT: kCustom +; CHECK-NEXT: ROOT +; CHECK-SAME: add +)"); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetOptimizedModule(kHloText)); + + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Add( + m::Fusion(m::Parameter(), m::Parameter()) + .WithFusionKind(HloInstruction::FusionKind::kCustom), + m::Fusion(m::Parameter(), m::Parameter()) + .WithFusionKind(HloInstruction::FusionKind::kCustom)))); + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2})); +} + +TEST_F(TritonGemmTest, BroadcastOfConstantIsNotFused) { + const std::string kHloText = R"( +HloModule m + +ENTRY e { + p0 = f16[70,30] parameter(0) + p0c = f32[70,30] convert(p0) + constant_3663 = f32[] constant(4321) + bc0 = f32[30,5] broadcast(constant_3663) + p1 = f32[30,5] parameter(1) + a = f32[30,5] add(p1, bc0) + ROOT d = f32[70,5] dot(p0c, a), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"; + + MatchOptimizedHlo(kHloText, R"( +; CHECK: ENTRY +; CHECK: constant +; CHECK: broadcast +; CHECK: fusion +; CHECK-SAME: kind=kCustom +)"); + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/2e-3, /*arel=*/2e-3})); +} + TEST_F(TritonGemmTest, Naming) { const char* hlo_text = R"( HloModule t diff --git a/tensorflow/compiler/xla/service/gpu/triton_autotuner.cc b/tensorflow/compiler/xla/service/gpu/triton_autotuner.cc index 440a9611a8fe27..b8b8b5f6719931 100644 --- a/tensorflow/compiler/xla/service/gpu/triton_autotuner.cc +++ b/tensorflow/compiler/xla/service/gpu/triton_autotuner.cc @@ -418,12 +418,11 @@ std::vector GetExhaustiveMatmulAutotuneConfigs( std::vector GetFixedMatmulAutotuneConfigs( const se::CudaComputeCapability compute_capability) { std::vector configs = { - GemmKey(32, 32, 256, 1, 1, 4), GemmKey(64, 32, 32, 16, 1, 4), - GemmKey(32, 64, 64, 4, 1, 4), GemmKey(128, 128, 64, 4, 1, 4), - GemmKey(16, 16, 256, 1, 1, 4), GemmKey(16, 128, 32, 16, 1, 4), - GemmKey(16, 64, 128, 1, 1, 4), GemmKey(16, 128, 32, 8, 1, 4), - GemmKey(16, 16, 512, 1, 1, 4), GemmKey(32, 16, 512, 1, 1, 4), - GemmKey(64, 32, 64, 1, 2, 8)}; + GemmKey(32, 32, 256, 1, 1, 4), GemmKey(64, 32, 32, 16, 1, 4), + GemmKey(32, 64, 64, 4, 1, 4), GemmKey(16, 16, 256, 1, 1, 4), + GemmKey(16, 128, 32, 16, 1, 4), GemmKey(16, 64, 128, 1, 1, 4), + GemmKey(16, 128, 32, 8, 1, 4), GemmKey(16, 16, 512, 1, 1, 4), + GemmKey(32, 16, 512, 1, 1, 4), GemmKey(64, 32, 64, 1, 2, 8)}; if (compute_capability.IsAtLeast(se::CudaComputeCapability::AMPERE)) { absl::c_copy( std::vector{ From 9e9b04efbc49b3f57da5d60fce7f20a1d875dbe7 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 10 Jul 2023 11:36:46 -0700 Subject: [PATCH 067/376] Integrate LLVM at llvm/llvm-project@bfd94882f264 Updates LLVM usage to match [bfd94882f264](https://github.com/llvm/llvm-project/commit/bfd94882f264) PiperOrigin-RevId: 546934735 --- third_party/llvm/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index 278db4f8354cd9..a439c3f924583c 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "dbaa5838c13e5593b9de37b8f3daffe4cb914a17" - LLVM_SHA256 = "79a02eb8733ec1f51c23fdc0cfc123fb023d855fe53ca59515cd8c6cb2af8993" + LLVM_COMMIT = "bfd94882f2648e2a5ed651bca6cfeb4fb7788b86" + LLVM_SHA256 = "9c23082138fa8706ebb4c4e5e2f1873d954202b9b76ef4ee52542ab00262f5dd" tf_http_archive( name = name, From 883928088216e29e34ba42ddc40acff646a76198 Mon Sep 17 00:00:00 2001 From: Terry Heo Date: Mon, 10 Jul 2023 11:47:56 -0700 Subject: [PATCH 068/376] lite: Add TFLite Benchmark Op Profiler supports of GPU delegate //tensorflow/lite/delegates/gpu:tflite_profile is added to access the given TFLite Profiler object via TfLiteContext. Now we can use '--enable_op_profiling=true' with ' --use_gpu=true' PiperOrigin-RevId: 546937910 --- tensorflow/lite/CMakeLists.txt | 1 + tensorflow/lite/delegates/gpu/BUILD | 12 +++++ tensorflow/lite/delegates/gpu/cl/BUILD | 1 + tensorflow/lite/delegates/gpu/cl/api.cc | 13 +++++- tensorflow/lite/delegates/gpu/delegate.cc | 5 +- .../lite/delegates/gpu/tflite_profile.cc | 46 +++++++++++++++++++ .../lite/delegates/gpu/tflite_profile.h | 38 +++++++++++++++ 7 files changed, 113 insertions(+), 3 deletions(-) create mode 100644 tensorflow/lite/delegates/gpu/tflite_profile.cc create mode 100644 tensorflow/lite/delegates/gpu/tflite_profile.h diff --git a/tensorflow/lite/CMakeLists.txt b/tensorflow/lite/CMakeLists.txt index 78a3f922ec655b..bc97bac8a1b102 100644 --- a/tensorflow/lite/CMakeLists.txt +++ b/tensorflow/lite/CMakeLists.txt @@ -337,6 +337,7 @@ if(TFLITE_ENABLE_GPU) ${TFLITE_SOURCE_DIR}/delegates/gpu/api.cc ${TFLITE_SOURCE_DIR}/delegates/gpu/delegate.cc ${TFLITE_SOURCE_DIR}/delegates/gpu/delegate_options.cc + ${TFLITE_SOURCE_DIR}/delegates/gpu/tflite_profile.cc ${TFLITE_SOURCE_DIR}/experimental/acceleration/compatibility/android_info.cc ${TFLITE_DELEGATES_GPU_CL_SRCS} ${TFLITE_DELEGATES_GPU_CL_DEFAULT_SRCS} diff --git a/tensorflow/lite/delegates/gpu/BUILD b/tensorflow/lite/delegates/gpu/BUILD index 40f95082cd86cf..875c2a4f3da7df 100644 --- a/tensorflow/lite/delegates/gpu/BUILD +++ b/tensorflow/lite/delegates/gpu/BUILD @@ -240,6 +240,7 @@ cc_library( }) + [ ":api", ":delegate_options", + ":tflite_profile", "//tensorflow/lite:kernel_api", "//tensorflow/lite:minimal_logging", "//tensorflow/lite/core/async:backend_async_kernel_interface", @@ -267,3 +268,14 @@ cc_library( "@com_google_absl//absl/types:span", ], ) + +cc_library( + name = "tflite_profile", + srcs = ["tflite_profile.cc"], + hdrs = ["tflite_profile.h"], + deps = [ + "//tensorflow/lite/core/api", + "//tensorflow/lite/delegates/gpu/common/task:profiling_info", + "@com_google_absl//absl/time", + ], +) diff --git a/tensorflow/lite/delegates/gpu/cl/BUILD b/tensorflow/lite/delegates/gpu/cl/BUILD index 50ab40d61f8206..d710059af90886 100644 --- a/tensorflow/lite/delegates/gpu/cl/BUILD +++ b/tensorflow/lite/delegates/gpu/cl/BUILD @@ -36,6 +36,7 @@ cc_library( ":tensor", ":tensor_type_util", "//tensorflow/lite/delegates/gpu:api", + "//tensorflow/lite/delegates/gpu:tflite_profile", "//tensorflow/lite/delegates/gpu/cl/kernels:converter", "//tensorflow/lite/delegates/gpu/common:data_type", "//tensorflow/lite/delegates/gpu/common:model", diff --git a/tensorflow/lite/delegates/gpu/cl/api.cc b/tensorflow/lite/delegates/gpu/cl/api.cc index 21462b111af1de..490836435a02df 100644 --- a/tensorflow/lite/delegates/gpu/cl/api.cc +++ b/tensorflow/lite/delegates/gpu/cl/api.cc @@ -43,6 +43,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h" #include "tensorflow/lite/delegates/gpu/common/tensor.h" +#include "tensorflow/lite/delegates/gpu/tflite_profile.h" #ifdef CL_DELEGATE_ALLOW_GL #include @@ -454,6 +455,7 @@ class InferenceRunnerImpl : public CLInferenceRunner { #endif ) : queue_(environment->queue()), + profiling_queue_(environment->profiling_queue()), context_(std::move(context)) #ifdef CL_DELEGATE_ALLOW_GL , @@ -555,8 +557,14 @@ class InferenceRunnerImpl : public CLInferenceRunner { } absl::Status RunWithoutExternalBufferCopy() override { - RETURN_IF_ERROR(context_->AddToQueue(queue_)); - clFlush(queue_->queue()); + if (IsTfLiteProfilerActive()) { + ProfilingInfo profiling_info; + RETURN_IF_ERROR(context_->Profile(profiling_queue_, &profiling_info)); + AddTfLiteProfilerEvents(&profiling_info); + } else { + RETURN_IF_ERROR(context_->AddToQueue(queue_)); + clFlush(queue_->queue()); + } return absl::OkStatus(); } @@ -585,6 +593,7 @@ class InferenceRunnerImpl : public CLInferenceRunner { } CLCommandQueue* queue_; + ProfilingCommandQueue* profiling_queue_; std::unique_ptr context_; #ifdef CL_DELEGATE_ALLOW_GL std::unique_ptr gl_interop_fabric_; diff --git a/tensorflow/lite/delegates/gpu/delegate.cc b/tensorflow/lite/delegates/gpu/delegate.cc index e0ba598c843fe4..f51e01fe036a9a 100644 --- a/tensorflow/lite/delegates/gpu/delegate.cc +++ b/tensorflow/lite/delegates/gpu/delegate.cc @@ -52,6 +52,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/model_builder_helper.h" #include "tensorflow/lite/delegates/gpu/common/quantization_util.h" #include "tensorflow/lite/delegates/gpu/delegate_options.h" +#include "tensorflow/lite/delegates/gpu/tflite_profile.h" #include "tensorflow/lite/delegates/serialization.h" #if defined(__ANDROID__) @@ -171,7 +172,7 @@ class Delegate { delegate_.CopyFromBufferHandle = nullptr; delegate_.CopyToBufferHandle = nullptr; delegate_.FreeBufferHandle = nullptr; - delegate_.flags = kTfLiteDelegateFlagsNone; + delegate_.flags = kTfLiteDelegateFlagsPerOperatorProfiling; options_ = options ? *options : TfLiteGpuDelegateOptionsV2Default(); if (options_.max_delegated_partitions <= 0) { options_.max_delegated_partitions = 1; @@ -1496,6 +1497,8 @@ TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) { telemetry::TelemetryReportDelegateSettings( context, "GpuDelegate::DelegatePrepare", telemetry::TelemetrySource::TFLITE_GPU, delegate_setting); + + SetTfLiteProfiler(context->profiler); return status; } diff --git a/tensorflow/lite/delegates/gpu/tflite_profile.cc b/tensorflow/lite/delegates/gpu/tflite_profile.cc new file mode 100644 index 00000000000000..f0b95553845db4 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/tflite_profile.cc @@ -0,0 +1,46 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/delegates/gpu/tflite_profile.h" + +#include "absl/time/time.h" +#include "tensorflow/lite/core/api/profiler.h" + +namespace tflite { +namespace gpu { + +static void* s_profiler = nullptr; + +bool IsTfLiteProfilerActive() { return s_profiler != nullptr; } + +void SetTfLiteProfiler(void* profiler) { s_profiler = profiler; } + +void* GetTfLiteProfiler() { return s_profiler; } + +void AddTfLiteProfilerEvents(tflite::gpu::ProfilingInfo* profiling_info) { + tflite::Profiler* profile = + reinterpret_cast(GetTfLiteProfiler()); + if (profile == nullptr) return; + + int node_index = 0; + for (const auto& dispatch : profiling_info->dispatches) { + profile->AddEvent( + dispatch.label.c_str(), + Profiler::EventType::DELEGATE_PROFILED_OPERATOR_INVOKE_EVENT, + absl::ToDoubleMicroseconds(dispatch.duration), node_index++); + } +} + +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/tflite_profile.h b/tensorflow/lite/delegates/gpu/tflite_profile.h new file mode 100644 index 00000000000000..6e9d7310ffa04c --- /dev/null +++ b/tensorflow/lite/delegates/gpu/tflite_profile.h @@ -0,0 +1,38 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_TFLITE_PROFILE_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_TFLITE_PROFILE_H_ + +#include "tensorflow/lite/delegates/gpu/common/task/profiling_info.h" + +namespace tflite { +namespace gpu { + +// Returns if TFLite Profiler is active. +bool IsTfLiteProfilerActive(); + +// Save the given TFLite Profiler object (from TfLiteContext) for op profiling. +void SetTfLiteProfiler(void* profiler); + +// Returns saved TFLite Profiler object. +void* GetTfLiteProfiler(); + +// Generate TFLite Profiler events with the given ProfilingInfo object. +void AddTfLiteProfilerEvents(tflite::gpu::ProfilingInfo* profiling_info); + +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_TFLITE_PROFILE_H_ From 7a4aa69c64e8d77869579e0958970a100dd9a135 Mon Sep 17 00:00:00 2001 From: Sagun Bajra Date: Mon, 10 Jul 2023 12:13:18 -0700 Subject: [PATCH 069/376] Increase timeout for the env test. PiperOrigin-RevId: 546944708 --- tensorflow/core/platform/BUILD | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/platform/BUILD b/tensorflow/core/platform/BUILD index 524b890994880e..4409054d26f0a2 100644 --- a/tensorflow/core/platform/BUILD +++ b/tensorflow/core/platform/BUILD @@ -1259,7 +1259,7 @@ tf_cc_test( tf_cc_test( name = "fake_python_env_test", - size = "small", + size = "medium", srcs = ["fake_python_env_test.cc"], args = [ "/some/path/to/pythontest.runfiles/org_tensorflow/stuff/to/run.py", From e7d8ebb233aa927a4b405c43371ff2ff80731f84 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 10 Jul 2023 12:14:28 -0700 Subject: [PATCH 070/376] [PJRT C API] Remove ApiVersion test. The ApiVersion test has now been moved to the test base in [pjrt_c_api_test.cc](third_party/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_test.cc) PiperOrigin-RevId: 546944989 --- tensorflow/compiler/xla/pjrt/c/pjrt_c_api_gpu_test.cc | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_gpu_test.cc b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_gpu_test.cc index 627c8b864c882a..f8b6c7ae5351fa 100644 --- a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_gpu_test.cc +++ b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_gpu_test.cc @@ -85,11 +85,6 @@ TEST_F(PjrtCApiGpuTest, PlatformName) { ASSERT_EQ("gpu", platform_name); } -TEST_F(PjrtCApiGpuTest, ApiVersion) { - CHECK_EQ(api_->pjrt_api_version.major_version, PJRT_API_MAJOR); - CHECK_EQ(api_->pjrt_api_version.minor_version, PJRT_API_MINOR); -} - std::unique_ptr<::pjrt::PJRT_KeyValueCallbackData> CreateTestCKVCallback( absl::flat_hash_map* kv_store, absl::Mutex& mu) { PjRtClient::KeyValueGetCallback kv_get = From cc8cb9dc9cda63765eca3ec78338a7d01064ef26 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 10 Jul 2023 12:18:11 -0700 Subject: [PATCH 071/376] Update TFRT dependency to use revision http://github.com/tensorflow/runtime/commit/46664e9fe40d2a204bbc8b629a778cb8032fd8c1. PiperOrigin-RevId: 546945943 --- third_party/tf_runtime/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/tf_runtime/workspace.bzl b/third_party/tf_runtime/workspace.bzl index 8b5a8623f70020..03069a8bb25376 100644 --- a/third_party/tf_runtime/workspace.bzl +++ b/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "00a31b7ce92e062de48321d9ff50ad414144a47b" - TFRT_SHA256 = "113355c7dd55eb34346e2264544f309068acee5f8102a1a8f2146fc6a571cece" + TFRT_COMMIT = "46664e9fe40d2a204bbc8b629a778cb8032fd8c1" + TFRT_SHA256 = "2d676f58d4e803a0f4e4a9de951a5e2663dad51d535147ae748aec806522e6bf" tf_http_archive( name = "tf_runtime", From 8b116e21d125efe69764316a86af4f109bb3d5b6 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Mon, 10 Jul 2023 13:16:53 -0700 Subject: [PATCH 072/376] [xla:gpu] Remove fp8 check as it is fully supported now PiperOrigin-RevId: 546960448 --- .../compiler/xla/service/gpu/compile_module_to_llvm_ir.cc | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/compile_module_to_llvm_ir.cc b/tensorflow/compiler/xla/service/gpu/compile_module_to_llvm_ir.cc index b27ffc0fbd6e07..f950b5850697ea 100644 --- a/tensorflow/compiler/xla/service/gpu/compile_module_to_llvm_ir.cc +++ b/tensorflow/compiler/xla/service/gpu/compile_module_to_llvm_ir.cc @@ -427,10 +427,7 @@ Status CompileModuleToLlvmIrImpl( RecordHloToLlvmDuration(end_usecs - start_usecs); } - // TODO(ezhulenev): Remove the FP8 check once https://reviews.llvm.org/D140088 - // is submitted. Currently we can't emit LLVM IR with fp8 types. - if (IsXlaRuntimeExecutableEnabled(hlo_module->config()) && - !HasFp8(*hlo_module)) { + if (IsXlaRuntimeExecutableEnabled(hlo_module->config())) { std::vector buffer_sizes; llvm::transform( results->allocations, std::back_inserter(buffer_sizes), From 1ce8c076abf604820c58310670abe815ce549588 Mon Sep 17 00:00:00 2001 From: Fiona Lang Date: Mon, 10 Jul 2023 13:32:45 -0700 Subject: [PATCH 073/376] Update visibility for //third_party/tensorflow/python/framework:tensor. PiperOrigin-RevId: 546965766 --- tensorflow/python/framework/BUILD | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tensorflow/python/framework/BUILD b/tensorflow/python/framework/BUILD index 64729bcd99e1c3..20f54f63c34c7f 100644 --- a/tensorflow/python/framework/BUILD +++ b/tensorflow/python/framework/BUILD @@ -1836,6 +1836,12 @@ py_strict_library( py_strict_library( name = "tensor", srcs = ["tensor.py"], + visibility = visibility + [ + "//tensorflow:internal", + "//tensorflow_models:__subpackages__", + "//third_party/mlperf:__subpackages__", + "//third_party/py/tf_slim:__subpackages__", + ], deps = [ ":common_shapes", ":dtypes", From 6473f74c9311e90d58e6cc935ec0b3edca30d107 Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Mon, 10 Jul 2023 13:34:57 -0700 Subject: [PATCH 074/376] [NFC] Change uses of get_compatible_with_cloud to get_compatible_with_portable. PiperOrigin-RevId: 546966435 --- tensorflow/python/grappler/BUILD | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/grappler/BUILD b/tensorflow/python/grappler/BUILD index 229152816da5cb..26125785138c5d 100644 --- a/tensorflow/python/grappler/BUILD +++ b/tensorflow/python/grappler/BUILD @@ -1,5 +1,5 @@ load("//tensorflow:strict.default.bzl", "py_strict_binary", "py_strict_library") -load("//tensorflow:tensorflow.default.bzl", "cuda_py_strict_test", "get_compatible_with_cloud", "tf_py_strict_test", "tf_pybind_cc_library_wrapper", "tf_python_pybind_extension") +load("//tensorflow:tensorflow.default.bzl", "cuda_py_strict_test", "get_compatible_with_portable", "tf_py_strict_test", "tf_pybind_cc_library_wrapper", "tf_python_pybind_extension") load("//tensorflow/core/platform:build_config.bzl", "tf_protos_grappler") load("//tensorflow:tensorflow.bzl", "if_not_windows") @@ -15,7 +15,7 @@ cc_library( name = "cost_analyzer_lib", srcs = ["cost_analyzer.cc"], hdrs = ["cost_analyzer.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", From 48cecb7ba9b9b65ffdb077bb28485aabfe341f57 Mon Sep 17 00:00:00 2001 From: Fiona Lang Date: Mon, 10 Jul 2023 13:35:24 -0700 Subject: [PATCH 075/376] Update ops.Tensor references to //third_party/tensorflow/python/framework/tensor.py. PiperOrigin-RevId: 546966539 --- .../python/eager/polymorphic_function/BUILD | 16 +- .../polymorphic_function/compiler_ir_test.py | 24 +-- .../polymorphic_function/concrete_function.py | 12 +- .../function_type_utils.py | 12 +- .../polymorphic_function_test.py | 202 +++++++++--------- .../polymorphic_function_xla_jit_test.py | 14 +- .../tracing_compilation_test.py | 124 +++++------ tensorflow/python/tpu/BUILD | 6 +- tensorflow/python/tpu/tensor_tracer.py | 9 +- .../python/tpu/tpu_embedding_for_serving.py | 9 +- tensorflow/python/tpu/tpu_embedding_v1.py | 25 ++- tensorflow/python/tpu/tpu_embedding_v2.py | 9 +- .../tpu/tpu_outside_compilation_test.py | 6 +- tensorflow/python/training/BUILD | 6 + tensorflow/python/training/input.py | 3 +- .../python/training/monitored_session_test.py | 11 +- tensorflow/python/training/moving_averages.py | 3 +- tensorflow/python/training/optimizer.py | 15 +- .../training/sync_replicas_optimizer.py | 3 +- tensorflow/python/training/training_util.py | 3 +- 20 files changed, 266 insertions(+), 246 deletions(-) diff --git a/tensorflow/python/eager/polymorphic_function/BUILD b/tensorflow/python/eager/polymorphic_function/BUILD index e020919e5f9d8f..0f8b19c819d187 100644 --- a/tensorflow/python/eager/polymorphic_function/BUILD +++ b/tensorflow/python/eager/polymorphic_function/BUILD @@ -105,8 +105,8 @@ py_strict_library( "//tensorflow/python/framework:func_graph", "//tensorflow/python/framework:indexed_slices", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", - "//tensorflow/python/framework:tensor_spec", "//tensorflow/python/framework:type_spec", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:default_gradient", @@ -266,8 +266,8 @@ cuda_py_strict_test( "//tensorflow/python/framework:ops", "//tensorflow/python/framework:random_seed", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", - "//tensorflow/python/framework:tensor_spec", "//tensorflow/python/framework:test_lib", "//tensorflow/python/framework:test_ops", "//tensorflow/python/framework:type_spec", @@ -352,7 +352,7 @@ tf_xla_py_strict_test( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:errors", "//tensorflow/python/framework:ops", - "//tensorflow/python/framework:tensor_spec", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:collective_ops", @@ -443,7 +443,6 @@ tf_py_strict_test( python_version = "PY3", deps = [ ":function_type_utils", - ":polymorphic_function", ":tracing_compilation", "//tensorflow/core:protos_all_py", "//tensorflow/core/function/capture:capture_container", @@ -455,7 +454,7 @@ tf_py_strict_test( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:errors", "//tensorflow/python/framework:ops", - "//tensorflow/python/framework:tensor_spec", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:test_lib", "//tensorflow/python/framework:test_ops", "//tensorflow/python/layers", @@ -478,7 +477,6 @@ tf_py_strict_test( "//tensorflow/python/saved_model:save", "//tensorflow/python/util:compat", "//tensorflow/python/util:nest", - "//tensorflow/python/util:tf_decorator", "@absl_py//absl/testing:parameterized", ], ) @@ -585,10 +583,8 @@ py_strict_library( ":composite_tensor_utils", "//tensorflow/core/function/polymorphism:function_type", "//tensorflow/core/function/trace_type", - "//tensorflow/python/framework:composite_tensor", - "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:ops", - "//tensorflow/python/framework:tensor_spec", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:type_spec", "//tensorflow/python/ops:resource_variable_ops", "//tensorflow/python/util:nest", @@ -651,7 +647,7 @@ tf_xla_py_strict_test( "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", - "//tensorflow/python/framework:tensor_spec", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:array_ops_gen", diff --git a/tensorflow/python/eager/polymorphic_function/compiler_ir_test.py b/tensorflow/python/eager/polymorphic_function/compiler_ir_test.py index e3b301307838b5..edc9ad50cb5e66 100644 --- a/tensorflow/python/eager/polymorphic_function/compiler_ir_test.py +++ b/tensorflow/python/eager/polymorphic_function/compiler_ir_test.py @@ -19,7 +19,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_spec +from tensorflow.python.framework import tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_array_ops @@ -33,11 +33,11 @@ class CompilerIrTest(xla_test.XLATestCase): def _compareTwoMethodsCompilerIROutput(self, f, args, kwargs): flat_args = list(args) + list(kwargs.values()) - if not all([isinstance(x, ops.Tensor) for x in flat_args]): + if not all([isinstance(x, tensor.Tensor) for x in flat_args]): self.skipTest('It only support args and kwargs are all tf.Tensor types.') - args_spec = nest.map_structure(tensor_spec.TensorSpec.from_tensor, args) - kwargs_spec = nest.map_structure(tensor_spec.TensorSpec.from_tensor, kwargs) + args_spec = nest.map_structure(tensor.TensorSpec.from_tensor, args) + kwargs_spec = nest.map_structure(tensor.TensorSpec.from_tensor, kwargs) hlo_1 = f.experimental_get_compiler_ir(*args, **kwargs)() hlo_2 = f.experimental_get_compiler_ir(*args_spec, **kwargs_spec)() @@ -105,7 +105,7 @@ def f(x): with self.assertRaisesRegex( ValueError, 'Only support static input shape but got' ): - args_spec = [tensor_spec.TensorSpec((None), dtype=dtypes.float32)] + args_spec = [tensor.TensorSpec((None), dtype=dtypes.float32)] concrete_fn = f.get_concrete_function(*args_spec) _ = compiler_ir.from_concrete_function(concrete_fn)(stage='hlo') @@ -117,7 +117,7 @@ def f2(x): return x[x[0] : 0] args = [ops.convert_to_tensor([1, 2, 3, 4])] - args_spec = nest.map_structure(tensor_spec.TensorSpec.from_tensor, args) + args_spec = nest.map_structure(tensor.TensorSpec.from_tensor, args) concrete_fn = f2.get_concrete_function(*args_spec) if test_util.is_mlir_bridge_enabled(): with self.assertRaisesRegex( @@ -142,17 +142,17 @@ def f4(a, b): kwargs = {'b': a, 'a': b} kwargs_spec = nest.map_structure( - tensor_spec.TensorSpec.from_tensor, kwargs + tensor.TensorSpec.from_tensor, kwargs ) concrete_fn = f4.get_concrete_function(**kwargs_spec) captured_inputs = concrete_fn.captured_inputs captured_spec = compiler_ir.make_handledata_tensor_specs(captured_inputs) self.assertEqual(len(captured_spec), 2) self.assertEqual( - captured_spec[0], tensor_spec.TensorSpec((2), dtype=dtypes.float32) + captured_spec[0], tensor.TensorSpec((2), dtype=dtypes.float32) ) self.assertEqual( - captured_spec[1], tensor_spec.TensorSpec((1), dtype=dtypes.int32) + captured_spec[1], tensor.TensorSpec((1), dtype=dtypes.int32) ) def test_capture_variable_1(self): @@ -224,13 +224,13 @@ def fun_tf(x): return (x * v3 + t4 + v2) * v3 + t5 concrete_fn = fun_tf.get_concrete_function( - tensor_spec.TensorSpec((None,), dtype=dtypes.float32) + tensor.TensorSpec((None,), dtype=dtypes.float32) ) - x = tensor_spec.TensorSpec((10,), dtype=dtypes.float32) + x = tensor.TensorSpec((10,), dtype=dtypes.float32) hlo_1 = compiler_ir.from_concrete_function(concrete_fn, [x])(stage='hlo') self.assertIn('f32[10]', hlo_1) - x = tensor_spec.TensorSpec((20,), dtype=dtypes.float32) + x = tensor.TensorSpec((20,), dtype=dtypes.float32) hlo_2 = compiler_ir.from_concrete_function(concrete_fn, [x])(stage='hlo') self.assertIn('f32[20]', hlo_2) diff --git a/tensorflow/python/eager/polymorphic_function/concrete_function.py b/tensorflow/python/eager/polymorphic_function/concrete_function.py index 5461a54759b1f4..3f3cce06fb92a4 100644 --- a/tensorflow/python/eager/polymorphic_function/concrete_function.py +++ b/tensorflow/python/eager/polymorphic_function/concrete_function.py @@ -36,8 +36,8 @@ from tensorflow.python.framework import func_graph as func_graph_module from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import type_spec from tensorflow.python.ops import array_ops from tensorflow.python.ops import default_gradient @@ -168,7 +168,7 @@ def _construct_forward_backward(self, num_doutputs): signature = [] for t in trainable_outputs: signature.append( - tensor_spec.TensorSpec(*default_gradient.shape_and_dtype(t))) + tensor_lib.TensorSpec(*default_gradient.shape_and_dtype(t))) def _backprop_function(*grad_ys): with ops.device(None): @@ -1177,7 +1177,7 @@ def _call_with_flat_signature(self, args, kwargs): for i, arg in enumerate(args): if not isinstance( - arg, (ops.Tensor, resource_variable_ops.BaseResourceVariable)): + arg, (tensor_lib.Tensor, resource_variable_ops.BaseResourceVariable)): raise TypeError(f"{self._flat_signature_summary()}: expected argument " f"#{i}(zero-based) to be a Tensor; " f"got {type(arg).__name__} ({arg}).") @@ -1391,7 +1391,7 @@ def bool_closure(): concrete_fn.replace_capture_with_deferred_capture( bool_captured_tensor, bool_closure, - spec=tensor_spec.TensorSpec(shape=(), dtype=dtypes.bool)) + spec=tensor_lib.TensorSpec(shape=(), dtype=dtypes.bool)) print(concrete_fn()) # tf.Tensor([5.], shape=(1,), dtype=float32) ``` @@ -1651,7 +1651,7 @@ def pretty_printed_signature(self, verbose=True): def pretty_print_spec(spec): """Returns a string describing the spec for a single argument.""" - if isinstance(spec, tensor_spec.TensorSpec): + if isinstance(spec, tensor_lib.TensorSpec): return "{} Tensor, shape={}".format(spec.dtype.name, spec.shape) elif nest.is_nested(spec): pieces = nest.flatten(spec, expand_composites=False) @@ -1762,7 +1762,7 @@ def _export_to_saved_model_graph(self, object_map, tensor_map, return [] -_pywrap_utils.RegisterType("Tensor", ops.Tensor) +_pywrap_utils.RegisterType("Tensor", tensor_lib.Tensor) _pywrap_utils.RegisterType("EagerTensor", ops.EagerTensor) _pywrap_utils.RegisterType("IndexedSlices", indexed_slices.IndexedSlices) diff --git a/tensorflow/python/eager/polymorphic_function/function_type_utils.py b/tensorflow/python/eager/polymorphic_function/function_type_utils.py index 612caa4fb0c5ff..5da72faa326f52 100644 --- a/tensorflow/python/eager/polymorphic_function/function_type_utils.py +++ b/tensorflow/python/eager/polymorphic_function/function_type_utils.py @@ -23,7 +23,7 @@ from tensorflow.core.function import trace_type from tensorflow.core.function.polymorphism import function_type as function_type_lib from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_spec +from tensorflow.python.framework import tensor from tensorflow.python.framework import type_spec from tensorflow.python.ops import resource_variable_ops from tensorflow.python.util import nest @@ -165,7 +165,7 @@ def to_input_signature(function_type): trace_type.InternalPlaceholderContext(unnest_only=True) ) if any( - not isinstance(arg, tensor_spec.TensorSpec) + not isinstance(arg, tensor.TensorSpec) for arg in nest.flatten([constraint], expand_composites=True) ): # input_signature only supports contiguous TensorSpec composites @@ -465,13 +465,13 @@ def _validate_signature(signature): ) if any( - not isinstance(arg, tensor_spec.TensorSpec) + not isinstance(arg, tensor.TensorSpec) for arg in nest.flatten(signature, expand_composites=True) ): bad_args = [ arg for arg in nest.flatten(signature, expand_composites=True) - if not isinstance(arg, tensor_spec.TensorSpec) + if not isinstance(arg, tensor.TensorSpec) ] raise TypeError( "input_signature must be a possibly nested sequence of " @@ -483,7 +483,7 @@ def _validate_signature(signature): def _to_tensor_or_tensor_spec(x): return ( x - if isinstance(x, (ops.Tensor, tensor_spec.TensorSpec)) + if isinstance(x, (tensor.Tensor, tensor.TensorSpec)) else ops.convert_to_tensor(x) ) @@ -502,7 +502,7 @@ def _get_variable_specs(args): continue if isinstance(arg, resource_variable_ops.VariableSpec): variable_specs.append(arg) - elif not isinstance(arg, tensor_spec.TensorSpec): + elif not isinstance(arg, tensor.TensorSpec): # arg is a CompositeTensor spec. variable_specs.extend(_get_variable_specs(arg._component_specs)) # pylint: disable=protected-access return variable_specs diff --git a/tensorflow/python/eager/polymorphic_function/polymorphic_function_test.py b/tensorflow/python/eager/polymorphic_function/polymorphic_function_test.py index 12414894f7a004..e06f87a04fe6ab 100644 --- a/tensorflow/python/eager/polymorphic_function/polymorphic_function_test.py +++ b/tensorflow/python/eager/polymorphic_function/polymorphic_function_test.py @@ -52,8 +52,8 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import test_ops from tensorflow.python.framework import test_util from tensorflow.python.framework import type_spec @@ -110,7 +110,7 @@ def _spec_for_value(value): """Returns the (nested) TypeSpec for a value.""" if nest.is_nested(value): return nest.map_structure(_spec_for_value, value) - elif isinstance(value, (ops.Tensor, composite_tensor.CompositeTensor)): + elif isinstance(value, (tensor_lib.Tensor, composite_tensor.CompositeTensor)): return type_spec.type_spec_from_value(value) else: return value @@ -408,8 +408,8 @@ def testImplementsWorksWithTensorSpec(self): v = polymorphic_function.function( experimental_implements='func')(lambda x, y: x + y) v = v.get_concrete_function( - tensor_spec.TensorSpec(shape=None, dtype=dtypes.float32), - tensor_spec.TensorSpec(shape=None, dtype=dtypes.float32)) + tensor_lib.TensorSpec(shape=None, dtype=dtypes.float32), + tensor_lib.TensorSpec(shape=None, dtype=dtypes.float32)) x = v(1., 2.) self.assertEqual(x.numpy(), 3.) @@ -546,21 +546,21 @@ def check_trace(x, expected_trace): check_trace( structured_tensor.StructuredTensor.from_pyval({'a': [1]}), structured_tensor.StructuredTensor.Spec._from_fields_and_rank( - fields={'a': tensor_spec.TensorSpec((1,), dtypes.int32)}, rank=0)) + fields={'a': tensor_lib.TensorSpec((1,), dtypes.int32)}, rank=0)) check_trace( structured_tensor.StructuredTensor.from_pyval({'b': [1]}), structured_tensor.StructuredTensor.Spec._from_fields_and_rank( - fields={'b': tensor_spec.TensorSpec((1,), dtypes.int32)}, rank=0)) + fields={'b': tensor_lib.TensorSpec((1,), dtypes.int32)}, rank=0)) check_trace( structured_tensor.StructuredTensor.from_pyval({'c': [1]}), structured_tensor.StructuredTensor.Spec._from_fields_and_rank( - fields={'c': tensor_spec.TensorSpec((1,), dtypes.int32)}, rank=0)) + fields={'c': tensor_lib.TensorSpec((1,), dtypes.int32)}, rank=0)) # But if we call again with only shape different, then do relax: check_trace( # relax & retrace structured_tensor.StructuredTensor.from_pyval({'a': [1, 2]}), structured_tensor.StructuredTensor.Spec._from_fields_and_rank( - fields={'a': tensor_spec.TensorSpec((None,), dtypes.int32)}, + fields={'a': tensor_lib.TensorSpec((None,), dtypes.int32)}, rank=0)) check_trace( # use relaxed graph structured_tensor.StructuredTensor.from_pyval({'a': [1, 2, 3]}), None) @@ -593,13 +593,13 @@ def check_trace(x, expected_trace): check_trace( # shape=[1, 2]: retrace dataset_ops.make_one_shot_iterator(ds_1_2), iterator_ops.IteratorSpec( - tensor_spec.TensorSpec([1, 2], dtypes.float32))) + tensor_lib.TensorSpec([1, 2], dtypes.float32))) check_trace( # shape=[1, 2]: no retrace (use the [1, 2] graph) dataset_ops.make_one_shot_iterator(ds_1_2), None) check_trace( # shape=[2, 2]: relax to [None, 2] and retrace dataset_ops.make_one_shot_iterator(ds_2_2), iterator_ops.IteratorSpec( - tensor_spec.TensorSpec([None, 2], dtypes.float32))) + tensor_lib.TensorSpec([None, 2], dtypes.float32))) check_trace( # shape=[3, 2]: no retrace (use the [None, 2] graph) dataset_ops.make_one_shot_iterator(ds_3_2), None) check_trace( # shape=[4, 2]: no retrace (use the [None, 2] graph) @@ -607,7 +607,7 @@ def check_trace(x, expected_trace): check_trace( # shape=[2, 1]: relax to [None, None] and retrace dataset_ops.make_one_shot_iterator(ds_2_1), iterator_ops.IteratorSpec( - tensor_spec.TensorSpec([None, None], dtypes.float32))) + tensor_lib.TensorSpec([None, None], dtypes.float32))) def testCapturesVariables(self): a = variables.Variable(1.0, trainable=False) @@ -787,7 +787,7 @@ def sq(a): return matmul(a, a) sq_op = sq.get_concrete_function( - tensor_spec.TensorSpec((None, None), dtypes.float32)) + tensor_lib.TensorSpec((None, None), dtypes.float32)) self.assertEqual([None, None], sq_op.output_shapes.as_list()) t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) @@ -806,18 +806,16 @@ def sq(mats): ((a, b),) = mats return matmul(a, b) - sq_op_autonamed = sq.get_concrete_function([(tensor_spec.TensorSpec( - (None, None), - dtypes.float32), tensor_spec.TensorSpec((None, None), dtypes.float32))]) + sq_op_autonamed = sq.get_concrete_function([( + tensor_lib.TensorSpec((None, None), dtypes.float32), + tensor_lib.TensorSpec((None, None), dtypes.float32), + )]) self.assertEqual([None, None], sq_op_autonamed.output_shapes.as_list()) - sq_op = sq.get_concrete_function([(tensor_spec.TensorSpec((None, None), - dtypes.float32, - name='first_mat'), - tensor_spec.TensorSpec( - (None, None), - dtypes.float32, - name='second_mat'))]) + sq_op = sq.get_concrete_function([( + tensor_lib.TensorSpec((None, None), dtypes.float32, name='first_mat'), + tensor_lib.TensorSpec((None, None), dtypes.float32, name='second_mat'), + )]) self.assertEqual([None, None], sq_op.output_shapes.as_list()) t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) @@ -892,7 +890,7 @@ def testShareRendezvous(self): cpu = '/device:CPU:0' - signature = [tensor_spec.TensorSpec([], dtypes.int32)] + signature = [tensor_lib.TensorSpec([], dtypes.int32)] @polymorphic_function.function def send(): @@ -960,8 +958,8 @@ def a_times_b(inputs): t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) sq_op = a_times_b.get_concrete_function( pair( - dict(a=tensor_spec.TensorSpec([2, 2], dtypes.float32, 'a')), - dict(b=tensor_spec.TensorSpec([2, 2], dtypes.float32, 'b')))) + dict(a=tensor_lib.TensorSpec([2, 2], dtypes.float32, 'a')), + dict(b=tensor_lib.TensorSpec([2, 2], dtypes.float32, 'b')))) self.assertEqual(sq_op.output_shapes, tensor_shape.TensorShape([2, 2])) out = sq_op(a=t, b=t) self.assertAllEqual(out, math_ops.matmul(t, t).numpy()) @@ -1116,7 +1114,7 @@ def testShapeInferenceForMoreSpecificInput(self): def f(a): return array_ops.reshape(a, [-1, 3]) - signature = [tensor_spec.TensorSpec(None, dtypes.float32)] + signature = [tensor_lib.TensorSpec(None, dtypes.float32)] compiled = polymorphic_function.function(f, input_signature=signature) @polymorphic_function.function @@ -1522,7 +1520,7 @@ def testConcreteFunctionType(self): def foo(x): return {'input': x, 'capture': y} - cf = foo.get_concrete_function(tensor_spec.TensorSpec([], dtypes.int32)) + cf = foo.get_concrete_function(tensor_lib.TensorSpec([], dtypes.int32)) x = constant_op.constant(2) output = cf(x) self.assertEqual(set(output.keys()), {'input', 'capture'}) @@ -1534,12 +1532,12 @@ def foo(x): self.assertEqual(parameters[0].name, 'x') self.assertEqual( parameters[0].type_constraint, - tensor_spec.TensorSpec([], dtypes.int32), + tensor_lib.TensorSpec([], dtypes.int32), ) captures = cf.function_type.captures self.assertLen(captures, 1) - self.assertEqual(captures[id(y)], tensor_spec.TensorSpec([], dtypes.int32)) + self.assertEqual(captures[id(y)], tensor_lib.TensorSpec([], dtypes.int32)) output = cf.function_type.output self.assertEqual(output, trace_type.from_value({'input': x, 'capture': y})) @@ -1551,8 +1549,8 @@ def testSequenceInputs(self): clipped_list, global_norm = clip_by_global_norm(t_list, constant_op.constant(.2)) for t in clipped_list: - self.assertIsInstance(t, ops.Tensor) - self.assertIsInstance(global_norm, ops.Tensor) + self.assertIsInstance(t, tensor_lib.Tensor) + self.assertIsInstance(global_norm, tensor_lib.Tensor) def testNestedSequenceInputs(self): @@ -1690,7 +1688,7 @@ def foo(a, b): del b # Signatures must consist exclusively of `TensorSpec` objects. - signature = [(2, 3), tensor_spec.TensorSpec([2, 3], dtypes.float32)] + signature = [(2, 3), tensor_lib.TensorSpec([2, 3], dtypes.float32)] with self.assertRaisesRegex(TypeError, 'input_signature.*nested sequence'): polymorphic_function.function(foo, input_signature=signature) @@ -1700,7 +1698,7 @@ def testInputsIncompatibleWithSignatureRaisesError(self): def foo(a): return a - signature = [tensor_spec.TensorSpec(shape=(2,), dtype=dtypes.float32)] + signature = [tensor_lib.TensorSpec(shape=(2,), dtype=dtypes.float32)] defined = polymorphic_function.function(foo, input_signature=signature) # Valid call @@ -1729,7 +1727,7 @@ def foo(a): TypeError, r'Can not cast .*shape=\(3,\).* to .*shape=\(2,\).*' ): defined.get_concrete_function( - tensor_spec.TensorSpec(shape=(3,), dtype=dtypes.float32)) + tensor_lib.TensorSpec(shape=(3,), dtype=dtypes.float32)) def testMismatchedConcreteSignatureRaisesError(self): @@ -1761,8 +1759,8 @@ def foo(a, training=True): return -1.0 * a signature = [ - tensor_spec.TensorSpec([], dtypes.float32), - tensor_spec.TensorSpec([], dtypes.bool), + tensor_lib.TensorSpec([], dtypes.float32), + tensor_lib.TensorSpec([], dtypes.bool), ] defined = polymorphic_function.function(foo, input_signature=signature) a = constant_op.constant(1.0) @@ -1860,8 +1858,8 @@ def py_add(x, y): py_add(array_ops.ones([]), array_ops.ones([])) add = py_add.get_concrete_function( - tensor_spec.TensorSpec(None, dtypes.float32), - tensor_spec.TensorSpec(None, dtypes.float32)) + tensor_lib.TensorSpec(None, dtypes.float32), + tensor_lib.TensorSpec(None, dtypes.float32)) @polymorphic_function.function def py_composite(x, y): @@ -1869,8 +1867,8 @@ def py_composite(x, y): py_composite(array_ops.ones([]), array_ops.ones([])) composite = py_composite.get_concrete_function( - tensor_spec.TensorSpec(None, dtypes.float32), - tensor_spec.TensorSpec(None, dtypes.float32)) + tensor_lib.TensorSpec(None, dtypes.float32), + tensor_lib.TensorSpec(None, dtypes.float32)) with context.graph_mode(), self.cached_session(): with ops.get_default_graph().as_default(): @@ -2287,9 +2285,9 @@ def _uses_symbolic_shapes(w, x, y): return array_ops.reshape(y, [n, x_batch, -1]) conc = _uses_symbolic_shapes.get_concrete_function( - tensor_spec.TensorSpec(None, dtypes.float32), - tensor_spec.TensorSpec(None, dtypes.float32), - tensor_spec.TensorSpec(None, dtypes.float32)) + tensor_lib.TensorSpec(None, dtypes.float32), + tensor_lib.TensorSpec(None, dtypes.float32), + tensor_lib.TensorSpec(None, dtypes.float32)) @polymorphic_function.function def _call_concrete(): @@ -2482,7 +2480,7 @@ def f(x, y): @test_util.run_in_graph_and_eager_modes def testConcreteFunctionMethodWithVarargs(self): - float32_scalar = tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32) + float32_scalar = tensor_lib.TensorSpec(shape=(), dtype=dtypes.float32) class MyModel(module.Module): @@ -2808,8 +2806,8 @@ def f(x, y): return x * 10 + y conc = f.get_concrete_function( - x=tensor_spec.TensorSpec(None, dtypes.int32, name='y'), - y=tensor_spec.TensorSpec(None, dtypes.int32, name='x')) + x=tensor_lib.TensorSpec(None, dtypes.int32, name='y'), + y=tensor_lib.TensorSpec(None, dtypes.int32, name='x')) result = conc(x=constant_op.constant(5), y=constant_op.constant(6)) self.assertAllEqual(result, 56) @@ -2886,7 +2884,7 @@ def func2(x, y=3, *args, **kwargs): def testPrettyPrintedExplicitSignatureWithKeywordArg(self): @polymorphic_function.function( - input_signature=[tensor_spec.TensorSpec(None)]) + input_signature=[tensor_lib.TensorSpec(None)]) def fn(a, b=1): return a + b @@ -3101,8 +3099,8 @@ def func_pos_3args(position_arg1, position_arg2, position_arg3): def testShapeInferencePropagateConstNestedStack(self): @polymorphic_function.function(input_signature=[ - tensor_spec.TensorSpec((None, None), dtype=dtypes.int32), - tensor_spec.TensorSpec((), dtype=dtypes.int32), + tensor_lib.TensorSpec((None, None), dtype=dtypes.int32), + tensor_lib.TensorSpec((), dtype=dtypes.int32), ]) def f(x, s): old_shape = array_ops.shape(x) @@ -3111,7 +3109,7 @@ def f(x, s): return y @polymorphic_function.function(input_signature=[ - tensor_spec.TensorSpec(shape=(3, 6), dtype=dtypes.int32) + tensor_lib.TensorSpec(shape=(3, 6), dtype=dtypes.int32) ]) def g(x): y = f(x, s=5) @@ -3124,8 +3122,8 @@ def g(x): def testShapeInferencePropagateConstNestedUnstackStack(self): @polymorphic_function.function(input_signature=[ - tensor_spec.TensorSpec((None, None), dtype=dtypes.int32), - tensor_spec.TensorSpec((), dtype=dtypes.int32), + tensor_lib.TensorSpec((None, None), dtype=dtypes.int32), + tensor_lib.TensorSpec((), dtype=dtypes.int32), ]) def f(x, s): s0, _ = array_ops_stack.unstack(array_ops.shape(x), axis=0) @@ -3134,7 +3132,7 @@ def f(x, s): return y @polymorphic_function.function(input_signature=[ - tensor_spec.TensorSpec(shape=(3, 6), dtype=dtypes.int32) + tensor_lib.TensorSpec(shape=(3, 6), dtype=dtypes.int32) ]) def g(x): y = f(x, s=5) @@ -3147,9 +3145,9 @@ def g(x): def testShapeInferencePropagateConstNestedConcat(self): @polymorphic_function.function(input_signature=[ - tensor_spec.TensorSpec((), dtype=dtypes.int32), - tensor_spec.TensorSpec((), dtype=dtypes.int32), - tensor_spec.TensorSpec((), dtype=dtypes.int32), + tensor_lib.TensorSpec((), dtype=dtypes.int32), + tensor_lib.TensorSpec((), dtype=dtypes.int32), + tensor_lib.TensorSpec((), dtype=dtypes.int32), ]) def f(d1, d2, d3): new_shape = array_ops.concat([[d1], [d2], [d3]], axis=-1) @@ -3167,9 +3165,9 @@ def g(): def testShapeInferencePropagateConstDoubleNested(self): @polymorphic_function.function(input_signature=[ - tensor_spec.TensorSpec((), dtype=dtypes.int32), - tensor_spec.TensorSpec((), dtype=dtypes.int32), - tensor_spec.TensorSpec((), dtype=dtypes.int32), + tensor_lib.TensorSpec((), dtype=dtypes.int32), + tensor_lib.TensorSpec((), dtype=dtypes.int32), + tensor_lib.TensorSpec((), dtype=dtypes.int32), ]) def f(d1, d2, d3): new_shape = array_ops.concat([[d1], [d2], [d3]], axis=-1) @@ -3417,8 +3415,8 @@ def apply(self, x): def testMethodExtensionType(self): class MaskedTensor(extension_type.ExtensionType): - values: ops.Tensor - mask: ops.Tensor + values: tensor_lib.Tensor + mask: tensor_lib.Tensor @polymorphic_function.function def with_default(self, default_value): @@ -3495,24 +3493,24 @@ def dynamic_unroll(core_fn, def test_unspecified_default_argument(self): wrapped = polymorphic_function.function( lambda x, y=2: x + y, - input_signature=[tensor_spec.TensorSpec((), dtypes.int32)]) + input_signature=[tensor_lib.TensorSpec((), dtypes.int32)]) self.assertEqual(3, wrapped(constant_op.constant(1)).numpy()) def test_concrete_function_from_signature(self): @polymorphic_function.function( - input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)]) + input_signature=[tensor_lib.TensorSpec(None, dtypes.float32)]) def compute(x): return 2. * x concrete = compute.get_concrete_function() self.assertAllClose(1., concrete(constant_op.constant(0.5))) concrete = compute.get_concrete_function( - tensor_spec.TensorSpec(None, dtypes.float32)) + tensor_lib.TensorSpec(None, dtypes.float32)) self.assertAllClose(4., concrete(constant_op.constant(2.))) signature_args, _ = concrete.structured_input_signature self.assertEqual(signature_args, - (tensor_spec.TensorSpec( + (tensor_lib.TensorSpec( None, dtypes.float32, name='x'),)) def testInputSignatureMissingTensorSpecsMethod(self): @@ -3539,7 +3537,7 @@ def f6(self, arg1, arg4=4, **kwargs): m = MyModule() tf_func_dec = polymorphic_function.function( - input_signature=(tensor_spec.TensorSpec([], dtypes.int32),)) + input_signature=(tensor_lib.TensorSpec([], dtypes.int32),)) error_message = 'input_signature missing type constraint' with self.assertRaisesRegex(TypeError, error_message): tf_func_dec(m.f1)(1, 2, 3) @@ -3560,7 +3558,7 @@ def f6(self, arg1, arg4=4, **kwargs): def testInputSignatureMissingTensorSpecsFunction(self): tf_func_dec = polymorphic_function.function( - input_signature=(tensor_spec.TensorSpec([], dtypes.int32),)) + input_signature=(tensor_lib.TensorSpec([], dtypes.int32),)) error_message = 'input_signature missing type constraint' # pylint: disable=unused-argument def f1(arg1, arg2, arg3): @@ -3600,7 +3598,7 @@ def f6(arg1, arg4=4, **kwargs): def testInputSignatureMissingTensorSpecsLambdaFunction(self): tf_func_dec = polymorphic_function.function( - input_signature=(tensor_spec.TensorSpec([], dtypes.int32),)) + input_signature=(tensor_lib.TensorSpec([], dtypes.int32),)) error_message = 'input_signature missing type constraint' with self.assertRaisesRegex(TypeError, error_message): tf_func_dec(lambda ar1, arg2, arg3: None)(1, 2, 3) @@ -3638,7 +3636,7 @@ def f(arg1, arg2, arg3, arg4=4): error_message = 'input_signature missing type constraint' tf_func_dec = polymorphic_function.function( - input_signature=(tensor_spec.TensorSpec([], dtypes.int32),) + input_signature=(tensor_lib.TensorSpec([], dtypes.int32),) ) with self.assertRaisesRegex(TypeError, error_message): tf_func_dec(functools.partial(f, 1))(2, 3) @@ -3690,20 +3688,20 @@ def f(x): return x conc = f.get_concrete_function( - tensor_spec.TensorSpec(None, dtypes.float32, 'y')) + tensor_lib.TensorSpec(None, dtypes.float32, 'y')) conc(y=constant_op.constant(3.0)) signature_args, _ = conc.structured_input_signature self.assertEqual('y', signature_args[0].name) # If name is not specified, the previously named one will be returned. - conc = f.get_concrete_function(tensor_spec.TensorSpec(None, dtypes.float32)) + conc = f.get_concrete_function(tensor_lib.TensorSpec(None, dtypes.float32)) conc(x=constant_op.constant(3.0)) signature_args, _ = conc.structured_input_signature self.assertEqual('y', signature_args[0].name) # New name will return updated signature. conc = f.get_concrete_function( - tensor_spec.TensorSpec(None, dtypes.float32, 'z') + tensor_lib.TensorSpec(None, dtypes.float32, 'z') ) conc(x=constant_op.constant(3.0)) signature_args, _ = conc.structured_input_signature @@ -3714,7 +3712,7 @@ def g(x): return x[0] conc = g.get_concrete_function( - [tensor_spec.TensorSpec(None, dtypes.float32, 'z'), 2]) + [tensor_lib.TensorSpec(None, dtypes.float32, 'z'), 2]) conc(z=constant_op.constant(3.0)) signature_args, _ = conc.structured_input_signature self.assertEqual('z', signature_args[0][0].name) @@ -3756,10 +3754,10 @@ def f(x, y): self.assertEqual( signatures_args, - set(((tensor_spec.TensorSpec([1, 2], dtypes.float32, name='x'), - tensor_spec.TensorSpec([1], dtypes.float32, name='y')), - (tensor_spec.TensorSpec([1, 3], dtypes.int32, name='x'), - tensor_spec.TensorSpec([1], dtypes.int32, name='y'))))) + set(((tensor_lib.TensorSpec([1, 2], dtypes.float32, name='x'), + tensor_lib.TensorSpec([1], dtypes.float32, name='y')), + (tensor_lib.TensorSpec([1, 3], dtypes.int32, name='x'), + tensor_lib.TensorSpec([1], dtypes.int32, name='y'))))) @test_util.assert_no_garbage_created def testFunctionReferenceCycles(self): @@ -3838,10 +3836,10 @@ def non_unique_arg_names(x, **kwargs): return a + b + c + d concrete = non_unique_arg_names.get_concrete_function( - (tensor_spec.TensorSpec(None, dtypes.float32), - tensor_spec.TensorSpec(None, dtypes.float32), - tensor_spec.TensorSpec(None, dtypes.float32)), - d=tensor_spec.TensorSpec(None, dtypes.float32)) + (tensor_lib.TensorSpec(None, dtypes.float32), + tensor_lib.TensorSpec(None, dtypes.float32), + tensor_lib.TensorSpec(None, dtypes.float32)), + d=tensor_lib.TensorSpec(None, dtypes.float32)) self.assertAllClose( 10., concrete(x=constant_op.constant(1.), @@ -3949,9 +3947,9 @@ def func(x): return 2 * x func_a = func.get_concrete_function( - tensor_spec.TensorSpec([None], dtypes.int32)) + tensor_lib.TensorSpec([None], dtypes.int32)) func_b = func.get_concrete_function( - tensor_spec.TensorSpec([None], dtypes.int32)) + tensor_lib.TensorSpec([None], dtypes.int32)) self.assertIs(func_a, func_b) @@ -4049,7 +4047,7 @@ def decorator(f): self.assertEqual(func().numpy(), 2) @parameterized.parameters(*itertools.product( - (None, (tensor_spec.TensorSpec([]),)), # input_signature + (None, (tensor_lib.TensorSpec([]),)), # input_signature (True, False), # autograph (None, converter.Feature.ALL), # autograph_options (None, 'foo.bar'), # implements @@ -4133,7 +4131,7 @@ def func(): self.assertEmpty(graph.captures) @parameterized.parameters(*itertools.product( - (None, (tensor_spec.TensorSpec([]),)), # input_signature + (None, (tensor_lib.TensorSpec([]),)), # input_signature (True, False), # autograph (None, converter.Feature.ALL), # autograph_options (None, 'foo.bar'), # implements @@ -4317,7 +4315,7 @@ def __call__(self, x): f_flexible = Foo() _ = f_flexible.__call__.get_concrete_function( - tensor_spec.TensorSpec(shape=[None], dtype=dtypes.int32)) + tensor_lib.TensorSpec(shape=[None], dtype=dtypes.int32)) tmp_dir = self.create_tempdir() save(f_flexible, tmp_dir.full_path) restored_f_flexible = load(tmp_dir.full_path) @@ -4386,14 +4384,14 @@ def testDouble(self, a): def test_tensor_shape_casted_to_specific(self): @polymorphic_function.function( - input_signature=[tensor_spec.TensorSpec([1])] + input_signature=[tensor_lib.TensorSpec([1])] ) def specific(x): self.assertEqual(x.shape, [1]) return x @polymorphic_function.function( - input_signature=[tensor_spec.TensorSpec(None)] + input_signature=[tensor_lib.TensorSpec(None)] ) def general(x): return specific(x) @@ -4572,7 +4570,7 @@ def closure(): concrete_fn.replace_capture_with_deferred_capture( concrete_fn.captured_inputs[1], closure, - spec=tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32), + spec=tensor_lib.TensorSpec(shape=(), dtype=dtypes.float32), placeholder=concrete_fn.inputs[1]) self.assertAllEqual(concrete_fn(), 8.0) @@ -4589,7 +4587,7 @@ def testRaiseReplaceCaptureWithDeferredTypeSpecMismatch(self): def fn(): deferred_tensor = ops.get_default_graph().capture_call_time_value( lambda: value, - tensor_spec.TensorSpec(shape=(1,), dtype=dtypes.float32)) + tensor_lib.TensorSpec(shape=(1,), dtype=dtypes.float32)) if bool_captured_tensor: return deferred_tensor else: @@ -4615,13 +4613,13 @@ def float_closure(): concrete_fn.replace_capture_with_deferred_capture( bool_captured_tensor, float_closure, - spec=tensor_spec.TensorSpec(shape=(1,), dtype=dtypes.float32)) + spec=tensor_lib.TensorSpec(shape=(1,), dtype=dtypes.float32)) # Test replace without a placeholder concrete_fn.replace_capture_with_deferred_capture( bool_captured_tensor, bool_closure, - spec=tensor_spec.TensorSpec(shape=(), dtype=dtypes.bool)) + spec=tensor_lib.TensorSpec(shape=(), dtype=dtypes.bool)) self.assertAllEqual(concrete_fn(), [5.]) @@ -4633,7 +4631,7 @@ def testConcreteFunctionSetExternalCapture(self): def fn(): deferred_tensor = ops.get_default_graph().capture_call_time_value( lambda: value, - tensor_spec.TensorSpec(shape=(1,), dtype=dtypes.float32)) + tensor_lib.TensorSpec(shape=(1,), dtype=dtypes.float32)) return deferred_tensor + captured_tensor cf = fn.get_concrete_function() @@ -4656,7 +4654,7 @@ def testGraphReplaceCaptureAndSetExternalCapture(self): def fn(): deferred_tensor = ops.get_default_graph().capture_call_time_value( lambda: value, - tensor_spec.TensorSpec(shape=(1,), dtype=dtypes.float32)) + tensor_lib.TensorSpec(shape=(1,), dtype=dtypes.float32)) if bool_captured_tensor: return deferred_tensor else: @@ -4673,7 +4671,7 @@ def closure(): concrete_fn.graph.replace_capture_with_deferred_capture( concrete_fn.captured_inputs[0], closure, - spec=tensor_spec.TensorSpec(shape=(), dtype=dtypes.bool), + spec=tensor_lib.TensorSpec(shape=(), dtype=dtypes.bool), placeholder=concrete_fn.inputs[1]) concrete_fn.set_external_captures([ @@ -4688,7 +4686,7 @@ def testDeferredCapture(self): @polymorphic_function.function def lazy_capture(x): y = ops.get_default_graph().capture_call_time_value( - lambda: value, tensor_spec.TensorSpec(None)) + lambda: value, tensor_lib.TensorSpec(None)) return x + y self.assertAllEqual(lazy_capture(2.0), 3.0) @@ -4703,7 +4701,7 @@ def testNestedDeferredCapture(self): @polymorphic_function.function def inner(x): y = ops.get_default_graph().capture_call_time_value( - lambda: value, tensor_spec.TensorSpec(None)) + lambda: value, tensor_lib.TensorSpec(None)) return x + y @polymorphic_function.function @@ -4723,7 +4721,7 @@ def testNestedDeferredCaptureInTFWhileLoop(self): @polymorphic_function.function def inner(x): y = ops.get_default_graph().capture_call_time_value( - lambda: value, tensor_spec.TensorSpec(None)) + lambda: value, tensor_lib.TensorSpec(None)) return x + y @polymorphic_function.function @@ -4752,15 +4750,15 @@ def testDeferredCaptureWithKey(self): @polymorphic_function.function def lazy_capture(x): w = ops.get_default_graph().capture_call_time_value( - lambda: value0, tensor_spec.TensorSpec(None), key=0) + lambda: value0, tensor_lib.TensorSpec(None), key=0) y = ops.get_default_graph().capture_call_time_value( - lambda: value1, tensor_spec.TensorSpec(None), key=1) + lambda: value1, tensor_lib.TensorSpec(None), key=1) def bad_closure(): raise ValueError('Should not run') z = ops.get_default_graph().capture_call_time_value( - bad_closure, tensor_spec.TensorSpec(None), key=1) + bad_closure, tensor_lib.TensorSpec(None), key=1) return x + y + w + z self.assertAllEqual(lazy_capture(2.0), 7.0) @@ -4774,7 +4772,7 @@ def testDeferredCaptureTypeError(self): @polymorphic_function.function def lazy_capture(x): y = ops.get_default_graph().capture_call_time_value( - lambda: value, tensor_spec.TensorSpec(())) + lambda: value, tensor_lib.TensorSpec(())) return x + y self.assertAllEqual(lazy_capture(2.0), 3.0) diff --git a/tensorflow/python/eager/polymorphic_function/polymorphic_function_xla_jit_test.py b/tensorflow/python/eager/polymorphic_function/polymorphic_function_xla_jit_test.py index 43f67f81870ce2..5598226ba6d966 100644 --- a/tensorflow/python/eager/polymorphic_function/polymorphic_function_xla_jit_test.py +++ b/tensorflow/python/eager/polymorphic_function/polymorphic_function_xla_jit_test.py @@ -24,7 +24,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_spec +from tensorflow.python.framework import tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import collective_ops @@ -49,11 +49,11 @@ class FunctionTest(xla_test.XLATestCase): def _compareTwoMethodsCompilerIROutput(self, f, args, kwargs): """Assert the two differnet methods (tensor_spec inputs or tensor inputs) experimental_get_compiler give same HLO text.""" flat_args = list(args) + list(kwargs.values()) - if not all([isinstance(x, ops.Tensor) for x in flat_args]): + if not all([isinstance(x, tensor.Tensor) for x in flat_args]): self.skipTest('It only support args and kwargs are all tf.Tensor types.') - args_spec = nest.map_structure(tensor_spec.TensorSpec.from_tensor, args) - kwargs_spec = nest.map_structure(tensor_spec.TensorSpec.from_tensor, kwargs) + args_spec = nest.map_structure(tensor.TensorSpec.from_tensor, args) + kwargs_spec = nest.map_structure(tensor.TensorSpec.from_tensor, kwargs) hlo_1 = f.experimental_get_compiler_ir(*args, **kwargs)() hlo_2 = f.experimental_get_compiler_ir(*args_spec, **kwargs_spec)() @@ -389,7 +389,7 @@ def g(x): def testWhileLoopWithUnmodifiedCarriedShape(self): with ops.device('device:{}:0'.format(self.device)): - signature = [tensor_spec.TensorSpec(shape=[None], dtype=dtypes.float32)] + signature = [tensor.TensorSpec(shape=[None], dtype=dtypes.float32)] # We define a signature that specifies unknown vector shape, then test # that tf.shape constness gets properly propagated into the while_loop @@ -407,7 +407,7 @@ def g(x): def testNestedWhileLoopWithUnmodifiedCarriedShape(self): with ops.device('device:{}:0'.format(self.device)): - signature = [tensor_spec.TensorSpec(shape=[None], dtype=dtypes.float32)] + signature = [tensor.TensorSpec(shape=[None], dtype=dtypes.float32)] @polymorphic_function.function( input_signature=signature, jit_compile=True) @@ -432,7 +432,7 @@ def outer(y, shp): def testNestedWhileLoopWithUnmodifiedCarriedShapeSlice(self): with ops.device('device:{}:0'.format(self.device)): signature = [ - tensor_spec.TensorSpec(shape=[None, None], dtype=dtypes.float32) + tensor.TensorSpec(shape=[None, None], dtype=dtypes.float32) ] @polymorphic_function.function( diff --git a/tensorflow/python/eager/polymorphic_function/tracing_compilation_test.py b/tensorflow/python/eager/polymorphic_function/tracing_compilation_test.py index d8dbd707d6b478..ec7c0da10db7e0 100644 --- a/tensorflow/python/eager/polymorphic_function/tracing_compilation_test.py +++ b/tensorflow/python/eager/polymorphic_function/tracing_compilation_test.py @@ -34,7 +34,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_spec +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import test_ops from tensorflow.python.framework import test_util from tensorflow.python.layers import convolutional @@ -221,7 +221,7 @@ def f_py(): @test_util.run_v2_only def testCompilationNumpyArraysConvertedToTensors(self): def f(x): - self.assertIsInstance(x, ops.Tensor) + self.assertIsInstance(x, tensor_lib.Tensor) return x x = random_ops.random_uniform([2, 2]).numpy() @@ -280,7 +280,7 @@ def f(x, dtype): def testCompilationNumpyArraysConvertedToTensorsInKwargs(self): def f(**kwargs): x = kwargs.pop('x') - self.assertIsInstance(x, ops.Tensor) + self.assertIsInstance(x, tensor_lib.Tensor) return x x = random_ops.random_uniform([2, 2]).numpy() @@ -580,7 +580,7 @@ def foo(a): return a function_cache = function_cache_lib.FunctionCache() - signature = [tensor_spec.TensorSpec(shape=(2,), dtype=dtypes.float32)] + signature = [tensor_lib.TensorSpec(shape=(2,), dtype=dtypes.float32)] defined = compiled_fn( foo, input_signature=signature, function_cache=function_cache ) @@ -592,7 +592,7 @@ def foo(a): self.assertAllEqual( a, defined.get_concrete_function( - tensor_spec.TensorSpec((2,), dtype=dtypes.float32) + tensor_lib.TensorSpec((2,), dtype=dtypes.float32) )(a), ) self.assertLen(function_cache, 1) @@ -601,7 +601,7 @@ def bar(a): self.assertEqual(a._shape_tuple(), (2, None)) return a - signature = [tensor_spec.TensorSpec((2, None), dtypes.float32)] + signature = [tensor_lib.TensorSpec((2, None), dtypes.float32)] defined = compiled_fn(bar, input_signature=signature) a = array_ops.ones([2, 1]) out = defined(a) @@ -629,7 +629,7 @@ def f(*_args, **_kwargs): self.assertLen(function_cache, 2) def testInputSignatureWithCompatibleInputs(self): - rank2_spec = tensor_spec.TensorSpec( + rank2_spec = tensor_lib.TensorSpec( shape=(None, None), dtype=dtypes.float32 ) @@ -656,8 +656,8 @@ def expected_foo(a, b): @compiled_fn( input_signature=[ - [tensor_spec.TensorSpec((2, None), dtypes.float32)] * 2, - tensor_spec.TensorSpec((1,), dtypes.float32), + [tensor_lib.TensorSpec((2, None), dtypes.float32)] * 2, + tensor_lib.TensorSpec((1,), dtypes.float32), ], function_cache=function_cache, ) @@ -707,9 +707,9 @@ def expected_bar(a): @compiled_fn( input_signature=[{ - 'a': tensor_spec.TensorSpec((2, None), dtypes.float32), - 'b': tensor_spec.TensorSpec((2, None), dtypes.float32), - 'c': tensor_spec.TensorSpec((1,), dtypes.float32), + 'a': tensor_lib.TensorSpec((2, None), dtypes.float32), + 'b': tensor_lib.TensorSpec((2, None), dtypes.float32), + 'c': tensor_lib.TensorSpec((1,), dtypes.float32), }] ) def bar(a): @@ -744,7 +744,7 @@ def foo(a, b): del b # Signatures must be either lists or tuples on their outermost levels. - signature = {'t1': tensor_spec.TensorSpec([], dtypes.float32)} + signature = {'t1': tensor_lib.TensorSpec([], dtypes.float32)} with self.assertRaisesRegex( TypeError, 'input_signature must be either a tuple or a list.*' ): @@ -755,8 +755,8 @@ def foo(a, b): return [a, b] signature = [ - [tensor_spec.TensorSpec((1,), dtypes.float32)] * 2, - [tensor_spec.TensorSpec((1,), dtypes.float32)] * 2, + [tensor_lib.TensorSpec((1,), dtypes.float32)] * 2, + [tensor_lib.TensorSpec((1,), dtypes.float32)] * 2, ] defined = compiled_fn(foo, input_signature=signature) a = array_ops.ones([1]) @@ -772,7 +772,7 @@ def foo(a, b): def testUnderspecifiedInputSignature(self): @compiled_fn( input_signature=[ - tensor_spec.TensorSpec([], dtypes.float32), + tensor_lib.TensorSpec([], dtypes.float32), ] ) def foo(a, training=True): @@ -794,7 +794,7 @@ def full_function(a, b, c=3.0): partial = functools.partial(full_function, 1, c=4) a, b, c = partial(2.0) - signature = [tensor_spec.TensorSpec([], dtypes.float32)] + signature = [tensor_lib.TensorSpec([], dtypes.float32)] defined = compiled_fn(partial, input_signature=signature) x = constant_op.constant(2.0) func_a, func_b, func_c = defined(x) @@ -808,8 +808,8 @@ def testInputSignatureWithKeywordPositionalArgs(self): @compiled_fn( input_signature=[ - tensor_spec.TensorSpec([], dtypes.float32), - tensor_spec.TensorSpec([], dtypes.int64), + tensor_lib.TensorSpec([], dtypes.float32), + tensor_lib.TensorSpec([], dtypes.int64), ], function_cache=function_cache, ) @@ -848,8 +848,8 @@ def foo(a, b, **kwargs): x = compiled_fn( foo, input_signature=[ - tensor_spec.TensorSpec([], dtypes.float32), - tensor_spec.TensorSpec([], dtypes.int32), + tensor_lib.TensorSpec([], dtypes.float32), + tensor_lib.TensorSpec([], dtypes.int32), ], ).get_concrete_function() result = x(constant_op.constant(5.0), constant_op.constant(5)) @@ -899,15 +899,15 @@ def f(rt): @test_util.run_v2_only def testInputSignatureWithKeywordOnlyArgs(self): def f(a, b, c=3, *, d=4): - self.assertIsInstance(a, ops.Tensor) - self.assertIsInstance(b, ops.Tensor) + self.assertIsInstance(a, tensor_lib.Tensor) + self.assertIsInstance(b, tensor_lib.Tensor) self.assertIsInstance(c, int) - self.assertIsInstance(d, (int, ops.Tensor)) + self.assertIsInstance(d, (int, tensor_lib.Tensor)) return a + b + c + d signature = [ - tensor_spec.TensorSpec(shape=[], dtype=dtypes.int32), - tensor_spec.TensorSpec(shape=[], dtype=dtypes.int32), + tensor_lib.TensorSpec(shape=[], dtype=dtypes.int32), + tensor_lib.TensorSpec(shape=[], dtype=dtypes.int32), ] defined = compiled_fn(f, input_signature=signature) self.assertEqual(defined(1, 2).numpy(), 10) @@ -935,8 +935,8 @@ def f(a, b, c=3, *, d=4): def testInputSignatureWithKeywordOnlyArgsNoDefaults(self): signature = [ - tensor_spec.TensorSpec(shape=[], dtype=dtypes.int32), - tensor_spec.TensorSpec(shape=[], dtype=dtypes.int32), + tensor_lib.TensorSpec(shape=[], dtype=dtypes.int32), + tensor_lib.TensorSpec(shape=[], dtype=dtypes.int32), ] def test_func(a, *, b): @@ -1104,8 +1104,8 @@ def py_add(x, y): py_add(array_ops.ones([]), array_ops.ones([])) add = py_add.get_concrete_function( - tensor_spec.TensorSpec(None, dtypes.float32), - tensor_spec.TensorSpec(None, dtypes.float32), + tensor_lib.TensorSpec(None, dtypes.float32), + tensor_lib.TensorSpec(None, dtypes.float32), ) @compiled_fn( @@ -1116,8 +1116,8 @@ def py_composite(x, y): py_composite(array_ops.ones([]), array_ops.ones([])) composite = py_composite.get_concrete_function( - tensor_spec.TensorSpec(None, dtypes.float32), - tensor_spec.TensorSpec(None, dtypes.float32), + tensor_lib.TensorSpec(None, dtypes.float32), + tensor_lib.TensorSpec(None, dtypes.float32), ) with context.graph_mode(), self.cached_session(): @@ -1188,8 +1188,8 @@ def matmul(x, y): defun_matmul = compiled_fn( matmul, input_signature=[ - tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32), - tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32), + tensor_lib.TensorSpec(shape=(2, 2), dtype=dtypes.float32), + tensor_lib.TensorSpec(shape=(2, 2), dtype=dtypes.float32), ], function_cache=function_cache_lib.FunctionCache(), ) @@ -1450,7 +1450,7 @@ def defined(t): return t z = array_ops.zeros([2, 2]) - z_spec = tensor_spec.TensorSpec.from_tensor(z) + z_spec = tensor_lib.TensorSpec.from_tensor(z) self.assertIs( defined.get_concrete_function(z_spec), defined.get_concrete_function(z) ) @@ -1616,7 +1616,7 @@ def func(x): return array_ops.shape(x) @compiled_fn( - input_signature=[tensor_spec.TensorSpec([None, None], dtypes.float32)] + input_signature=[tensor_lib.TensorSpec([None, None], dtypes.float32)] ) def calls_func(x): return func(x) @@ -2014,8 +2014,8 @@ def fn(a, b): fn(array_ops.ones([]), array_ops.ones([])) fn_op = fn.get_concrete_function( - tensor_spec.TensorSpec(shape=(None,), dtype=dtypes.float32), - tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32), + tensor_lib.TensorSpec(shape=(None,), dtype=dtypes.float32), + tensor_lib.TensorSpec(shape=(), dtype=dtypes.float32), ) self.assertEqual(['a', 'b'], [inp.op.name for inp in fn_op.inputs]) self.assertEqual( @@ -2040,7 +2040,7 @@ def fn(a, b): fn(array_ops.ones([]), array_ops.ones([])) fn_op = fn.get_concrete_function( - tensor_spec.TensorSpec(shape=(None,), dtype=dtypes.float32), + tensor_lib.TensorSpec(shape=(None,), dtype=dtypes.float32), variables.Variable(1.0), ) self.assertEqual(['a', 'b'], [inp.op.name for inp in fn_op.inputs]) @@ -2060,8 +2060,8 @@ def fn(x, z=(1.0, 2.0), y=3.0): fn(array_ops.ones([])) fn_op = fn.get_concrete_function( - x=tensor_spec.TensorSpec(shape=(None,), dtype=dtypes.float32), - y=tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32), + x=tensor_lib.TensorSpec(shape=(None,), dtype=dtypes.float32), + y=tensor_lib.TensorSpec(shape=(), dtype=dtypes.float32), ) self.assertEqual(['x', 'y'], [inp.op.name for inp in fn_op.inputs]) self.assertEqual( @@ -2074,14 +2074,14 @@ def fn(x, z=(1.0, 2.0), y=3.0): fn_op2 = fn.get_concrete_function( z=( - tensor_spec.TensorSpec( + tensor_lib.TensorSpec( shape=(None,), dtype=dtypes.float32, name='z_first' ), - tensor_spec.TensorSpec( + tensor_lib.TensorSpec( shape=(), dtype=dtypes.float32, name='z_second' ), ), - y=tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32, name='custom'), + y=tensor_lib.TensorSpec(shape=(), dtype=dtypes.float32, name='custom'), x=4.0, ) self.assertEqual( @@ -2094,14 +2094,14 @@ def fn(x, z=(1.0, 2.0), y=3.0): ) fn_op3 = fn.get_concrete_function( - tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32, name='custom'), + tensor_lib.TensorSpec(shape=(), dtype=dtypes.float32, name='custom'), z=( - tensor_spec.TensorSpec( + tensor_lib.TensorSpec( shape=(None,), dtype=dtypes.float32, name='z1' ), - tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32, name='z2'), + tensor_lib.TensorSpec(shape=(), dtype=dtypes.float32, name='z2'), ), - y=tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32), + y=tensor_lib.TensorSpec(shape=(), dtype=dtypes.float32), ) self.assertEqual( ['custom', 'z1', 'z2', 'y'], [inp.op.name for inp in fn_op3.inputs] @@ -2120,7 +2120,7 @@ def method(self, x): has_method = HasMethod() compiled_method = compiled_fn(has_method.method) class_op = compiled_method.get_concrete_function( - tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32) + tensor_lib.TensorSpec(shape=(), dtype=dtypes.float32) ) self.assertEqual(['x'], [inp.op.name for inp in class_op.inputs]) self.assertEqual( @@ -2129,7 +2129,7 @@ def method(self, x): ) method_op = compiled_method.get_concrete_function( - tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32) + tensor_lib.TensorSpec(shape=(), dtype=dtypes.float32) ) self.assertEqual(['x'], [inp.op.name for inp in method_op.inputs]) self.assertEqual( @@ -2141,7 +2141,7 @@ def method(self, x): # should always retrace? self.skipTest('Not working') method_op = has_method.method.get_concrete_function( - tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32, name='y') + tensor_lib.TensorSpec(shape=(), dtype=dtypes.float32, name='y') ) self.assertEqual(['y'], [inp.op.name for inp in method_op.inputs]) self.assertEqual( @@ -2160,7 +2160,7 @@ def method(self, x): compiled_method = compiled_fn( has_method.method, input_signature=( - tensor_spec.TensorSpec(shape=None, dtype=dtypes.float64, name='y'), + tensor_lib.TensorSpec(shape=None, dtype=dtypes.float64, name='y'), ), ) @@ -2185,14 +2185,14 @@ def variadic_fn(x, *args, **kwargs): # Call the function to make def_function happy variadic_fn(array_ops.ones([]), array_ops.ones([])) variadic_op = variadic_fn.get_concrete_function( - tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32), - tensor_spec.TensorSpec(shape=None, dtype=dtypes.float32, name='y'), - tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32), - tensor_spec.TensorSpec( + tensor_lib.TensorSpec(shape=(), dtype=dtypes.float32), + tensor_lib.TensorSpec(shape=None, dtype=dtypes.float32, name='y'), + tensor_lib.TensorSpec(shape=(), dtype=dtypes.float32), + tensor_lib.TensorSpec( shape=(), dtype=dtypes.float32, name='second_variadic' ), - z=tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32), - zz=tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32, name='cust'), + z=tensor_lib.TensorSpec(shape=(), dtype=dtypes.float32), + zz=tensor_lib.TensorSpec(shape=(), dtype=dtypes.float32, name='cust'), ) self.assertEqual( ['x', 'y', 'args_1', 'second_variadic', 'z', 'cust'], @@ -2206,10 +2206,10 @@ def variadic_fn(x, *args, **kwargs): def testVariadicInputSignature(self): @compiled_fn( input_signature=( - tensor_spec.TensorSpec(shape=None, dtype=dtypes.float32), - tensor_spec.TensorSpec(shape=None, dtype=dtypes.float32, name='y'), - tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32), - tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32, name='z'), + tensor_lib.TensorSpec(shape=None, dtype=dtypes.float32), + tensor_lib.TensorSpec(shape=None, dtype=dtypes.float32, name='y'), + tensor_lib.TensorSpec(shape=(), dtype=dtypes.float32), + tensor_lib.TensorSpec(shape=(), dtype=dtypes.float32, name='z'), ), name='variadic_fn', ) diff --git a/tensorflow/python/tpu/BUILD b/tensorflow/python/tpu/BUILD index a8f7a81b14954c..d3e94328c64911 100644 --- a/tensorflow/python/tpu/BUILD +++ b/tensorflow/python/tpu/BUILD @@ -263,6 +263,7 @@ pytype_strict_library( "//tensorflow/python/framework:func_graph", "//tensorflow/python/framework:function", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/lib/io:lib", "//tensorflow/python/ops:array_ops", @@ -781,6 +782,7 @@ pytype_strict_library( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:math_ops", @@ -827,6 +829,7 @@ pytype_strict_library( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:array_ops_stack", "//tensorflow/python/ops:embedding_ops", @@ -876,6 +879,7 @@ pytype_strict_library( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:embedding_ops", "//tensorflow/python/ops:math_ops", @@ -926,7 +930,7 @@ tpu_py_strict_test( "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", - "//tensorflow/python/framework:tensor_spec", + "//tensorflow/python/framework:tensor", "//tensorflow/python/lib/io:lib", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:cond", diff --git a/tensorflow/python/tpu/tensor_tracer.py b/tensorflow/python/tpu/tensor_tracer.py index 8a2e69e46961c5..e8550ddeb1de87 100644 --- a/tensorflow/python/tpu/tensor_tracer.py +++ b/tensorflow/python/tpu/tensor_tracer.py @@ -31,6 +31,7 @@ from tensorflow.python.framework import function from tensorflow.python.framework import graph_io from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_util from tensorflow.python.lib.io import file_io from tensorflow.python.ops import array_ops @@ -1300,7 +1301,7 @@ def _filter_execution_path_operations(self, operations, fetches): for fetch in fetches: if isinstance(fetch, ops.Operation): op_fetches.append(fetch) - elif isinstance(fetch, ops.Tensor): + elif isinstance(fetch, tensor_lib.Tensor): op_fetches.append(fetch.op) else: raise RuntimeError('Given fetch:%s is neither a tensor nor an op.' @@ -1741,7 +1742,7 @@ def _process_tensor_fetches(self, tensor_fetches): 'empty list.') fetches = [] for fetch in tensor_fetches: - if isinstance(fetch, ops.Tensor): + if isinstance(fetch, tensor_lib.Tensor): fetches.append(fetch) else: raise RuntimeError('Given tensor_fetch:%s is not a tensor.' % fetch) @@ -1759,7 +1760,7 @@ def _process_op_fetches(self, op_fetches): for fetch in op_fetches: if isinstance(fetch, ops.Operation): fetches.append(fetch) - elif isinstance(fetch, ops.Tensor): + elif isinstance(fetch, tensor_lib.Tensor): fetches.append(fetch.op) else: logging.warning('Ignoring the given op_fetch:%s, which is not an op.' % @@ -1768,7 +1769,7 @@ def _process_op_fetches(self, op_fetches): def _convert_fetches_to_input_format(self, input_fetches, current_fetches): """Changes current_fetches' format, so that it matches input_fetches.""" - if isinstance(input_fetches, ops.Tensor): + if isinstance(input_fetches, tensor_lib.Tensor): if len(current_fetches) != 1: raise RuntimeError('Tensor tracer input/output fetches do not match.') return current_fetches[0] diff --git a/tensorflow/python/tpu/tpu_embedding_for_serving.py b/tensorflow/python/tpu/tpu_embedding_for_serving.py index 9914e084bb1f37..fb8a3205e1358c 100644 --- a/tensorflow/python/tpu/tpu_embedding_for_serving.py +++ b/tensorflow/python/tpu/tpu_embedding_for_serving.py @@ -22,6 +22,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops_stack from tensorflow.python.ops import embedding_ops @@ -292,7 +293,7 @@ def serve_tensors(embedding_features): table = tables[feature.table] if weight is not None: - if isinstance(inp, ops.Tensor): + if isinstance(inp, tensor.Tensor): raise ValueError( "Weight specified for {}, but input is dense.".format(path)) elif type(weight) is not type(inp): @@ -303,7 +304,7 @@ def serve_tensors(embedding_features): raise ValueError("Weight specified for {}, but this is a sequence " "feature.".format(path)) - if isinstance(inp, ops.Tensor): + if isinstance(inp, tensor.Tensor): if feature.max_sequence_length > 0: raise ValueError("Feature {} is a sequence feature but a dense tensor " "was passed.".format(path)) @@ -324,7 +325,7 @@ def serve_tensors(embedding_features): def _embedding_lookup_for_sparse_tensor( inp: sparse_tensor.SparseTensor, weight: Optional[sparse_tensor.SparseTensor], table: tf_variables.Variable, - feature: tpu_embedding_v2_utils.FeatureConfig) -> ops.Tensor: + feature: tpu_embedding_v2_utils.FeatureConfig) -> tensor.Tensor: """Embedding lookup for sparse tensor based on its feature config. Args: @@ -380,7 +381,7 @@ def _embedding_lookup_for_sparse_tensor( def _embedding_lookup_for_ragged_tensor( inp: ragged_tensor.RaggedTensor, weight: Optional[ragged_tensor.RaggedTensor], table: tf_variables.Variable, - feature: tpu_embedding_v2_utils.FeatureConfig) -> ops.Tensor: + feature: tpu_embedding_v2_utils.FeatureConfig) -> tensor.Tensor: """Embedding lookup for ragged tensor based on its feature config. Args: diff --git a/tensorflow/python/tpu/tpu_embedding_v1.py b/tensorflow/python/tpu/tpu_embedding_v1.py index 7b19500025bbc1..259650cd9f8396 100644 --- a/tensorflow/python/tpu/tpu_embedding_v1.py +++ b/tensorflow/python/tpu/tpu_embedding_v1.py @@ -21,6 +21,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import math_ops @@ -179,9 +180,9 @@ def _maybe_build(self): def _apply_combiner_to_embeddings( self, - embeddings: ops.Tensor, - weight: ops.Tensor, - combiner: Optional[Text] = None) -> ops.Tensor: + embeddings: tensor.Tensor, + weight: tensor.Tensor, + combiner: Optional[Text] = None) -> tensor.Tensor: """Apply the combiner to the embedding look up result on second to last axis. Args: @@ -213,8 +214,9 @@ def _apply_combiner_to_embeddings( f"combiner must be one of 'mean', 'sqrtn' or 'sum', got {combiner}") return embeddings - def _pad_or_truncate_with_sequence_length(self, embeddings: ops.Tensor, - sequence_length: int) -> ops.Tensor: + def _pad_or_truncate_with_sequence_length( + self, embeddings: tensor.Tensor, sequence_length: int + ) -> tensor.Tensor: """Pad or truncate the embedding lookup result based on the sequence length. Args: @@ -272,7 +274,7 @@ def embedding_lookup(self, table = self.embedding_tables[feature.table] if weight is not None: - if isinstance(inp, ops.Tensor): + if isinstance(inp, tensor.Tensor): raise ValueError( "Weight specified for {}, but input is dense.".format(path)) elif type(weight) is not type(inp): @@ -283,7 +285,7 @@ def embedding_lookup(self, raise ValueError("Weight specified for {}, but this is a sequence " "feature.".format(path)) - if isinstance(inp, ops.Tensor): + if isinstance(inp, tensor.Tensor): if feature.max_sequence_length > 0: raise ValueError( "Feature {} is a sequence feature but a dense tensor " @@ -307,7 +309,7 @@ def _embedding_lookup_for_sparse_tensor( self, inp: sparse_tensor.SparseTensor, weight: Optional[sparse_tensor.SparseTensor], table: tf_variables.Variable, - feature: tpu_embedding_v2_utils.FeatureConfig) -> ops.Tensor: + feature: tpu_embedding_v2_utils.FeatureConfig) -> tensor.Tensor: """Embedding lookup for sparse tensor based on its feature config. Args: @@ -352,7 +354,7 @@ def _embedding_lookup_for_ragged_tensor( self, inp: ragged_tensor.RaggedTensor, weight: Optional[ragged_tensor.RaggedTensor], table: tf_variables.Variable, - feature: tpu_embedding_v2_utils.FeatureConfig) -> ops.Tensor: + feature: tpu_embedding_v2_utils.FeatureConfig) -> tensor.Tensor: """Embedding lookup for ragged tensor based on its feature config. Args: @@ -398,7 +400,10 @@ def ragged_to_dense_outside_compilation(inp, weight, batch_size, feature): # If the data batch size is a factor of the output batch size, the # divide result will be the sequence length. Ignore the weights and # combiner. - elif output_batch_size > batch_size and output_batch_size % batch_size == 0: + elif ( + output_batch_size > batch_size + and output_batch_size % batch_size == 0 + ): # Pad or truncate in the sequence dimension seq_length = output_batch_size // batch_size inp = inp.to_tensor(shape=(batch_size, seq_length)) diff --git a/tensorflow/python/tpu/tpu_embedding_v2.py b/tensorflow/python/tpu/tpu_embedding_v2.py index 787bf6e23bd6b0..bf954ac0a55e77 100644 --- a/tensorflow/python/tpu/tpu_embedding_v2.py +++ b/tensorflow/python/tpu/tpu_embedding_v2.py @@ -33,6 +33,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework.tensor_shape import TensorShape from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -687,7 +688,7 @@ def tpu_step(tpu_features): full_output_shape = [x * num_cores_per_replica for x in output_shape] + [ feature.table.dim ] - if gradient is not None and not isinstance(gradient, ops.Tensor): + if gradient is not None and not isinstance(gradient, tensor_lib.Tensor): raise ValueError( f"found non-tensor type: {type(gradient)} at path {path}.") if gradient is not None: @@ -992,7 +993,7 @@ def _generate_enqueue_op( # early. for inp, weight, (path, feature) in zip( flat_inputs, flat_weights, flat_features): - if isinstance(inp, ops.Tensor): + if isinstance(inp, tensor_lib.Tensor): self._add_data_for_tensor(inp, weight, indices_or_row_splits, values, weights, int_zeros, float_zeros, path) elif isinstance(inp, sparse_tensor.SparseTensor): @@ -1310,7 +1311,7 @@ def generate_enqueue_ops(): def _split_fn(ts, idx): if ts is None: return None - elif isinstance(ts, ops.Tensor): + elif isinstance(ts, tensor_lib.Tensor): return array_ops.split( ts, num_or_size_splits=self._num_cores_per_replica, @@ -1389,7 +1390,7 @@ def _get_input_shapes( else: tensor = maybe_tensor - if isinstance(tensor, ops.Tensor): + if isinstance(tensor, tensor_lib.Tensor): input_shapes.append( self._get_input_shape_for_tensor(tensor, feature, per_replica, path) ) diff --git a/tensorflow/python/tpu/tpu_outside_compilation_test.py b/tensorflow/python/tpu/tpu_outside_compilation_test.py index eb0ce826ca69c6..505aea97aa021a 100644 --- a/tensorflow/python/tpu/tpu_outside_compilation_test.py +++ b/tensorflow/python/tpu/tpu_outside_compilation_test.py @@ -33,7 +33,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_spec +from tensorflow.python.framework import tensor from tensorflow.python.lib.io import tf_record from tensorflow.python.ops import array_ops from tensorflow.python.ops import cond @@ -812,11 +812,11 @@ def train_step(x): partitioned_tpu_fn = _tpu_partitioned_call_wrapper(tpu_fn) concrete = partitioned_tpu_fn.get_concrete_function( - x=tensor_spec.TensorSpec( + x=tensor.TensorSpec( shape=(1), dtype=dtypes.float32, name="input_tensor")) self.assertIsInstance( - concrete(array_ops.ones((1), dtype=dtypes.float32))[0], ops.Tensor) + concrete(array_ops.ones((1), dtype=dtypes.float32))[0], tensor.Tensor) if __name__ == "__main__": diff --git a/tensorflow/python/training/BUILD b/tensorflow/python/training/BUILD index f65c6536005e40..1c46c7e05cda5c 100644 --- a/tensorflow/python/training/BUILD +++ b/tensorflow/python/training/BUILD @@ -297,6 +297,7 @@ py_strict_library( "//tensorflow/python/framework:indexed_slices", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/layers:layers_util", "//tensorflow/python/ops:array_ops", @@ -347,6 +348,7 @@ py_strict_library( "//tensorflow/python/distribute:reduce_util", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:cond", "//tensorflow/python/ops:control_flow_ops", "//tensorflow/python/ops:init_ops", @@ -379,6 +381,7 @@ py_strict_library( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:indexed_slices", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:control_flow_ops", "//tensorflow/python/ops:gradients", @@ -532,6 +535,7 @@ py_strict_library( "//tensorflow/python/distribute:distribute_lib", "//tensorflow/python/framework:indexed_slices", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:control_flow_ops", "//tensorflow/python/ops:data_flow_ops", @@ -1080,6 +1084,7 @@ py_strict_library( "//tensorflow/python/framework", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:cond", "//tensorflow/python/ops:init_ops", "//tensorflow/python/ops:resource_variable_ops", @@ -1687,6 +1692,7 @@ tf_py_strict_test( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:errors", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:control_flow_assert", diff --git a/tensorflow/python/training/input.py b/tensorflow/python/training/input.py index ee1d019f75630e..23bd73220042c1 100644 --- a/tensorflow/python/training/input.py +++ b/tensorflow/python/training/input.py @@ -27,6 +27,7 @@ from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape from tensorflow.python.layers import utils from tensorflow.python.ops import array_ops @@ -251,7 +252,7 @@ def string_input_producer(string_tensor, @end_compatibility """ not_null_err = "string_input_producer requires a non-null input tensor" - if not isinstance(string_tensor, ops.Tensor) and not string_tensor: + if not isinstance(string_tensor, tensor_lib.Tensor) and not string_tensor: raise ValueError(not_null_err) with ops.name_scope(name, "input_producer", [string_tensor]) as name: diff --git a/tensorflow/python/training/monitored_session_test.py b/tensorflow/python/training/monitored_session_test.py index 6e0e1a99ec2205..7ad3e772f2a983 100644 --- a/tensorflow/python/training/monitored_session_test.py +++ b/tensorflow/python/training/monitored_session_test.py @@ -33,6 +33,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_assert @@ -84,8 +85,9 @@ def test_defaults_empty_graph(self): self.assertTrue(isinstance(scaffold.init_op, ops.Operation)) self.assertEqual(None, scaffold.init_feed_dict) self.assertEqual(None, scaffold.init_fn) - self.assertTrue(isinstance(scaffold.ready_op, ops.Tensor)) - self.assertTrue(isinstance(scaffold.ready_for_local_init_op, ops.Tensor)) + self.assertTrue(isinstance(scaffold.ready_op, tensor.Tensor)) + self.assertTrue(isinstance( + scaffold.ready_for_local_init_op, tensor.Tensor)) self.assertTrue(isinstance(scaffold.local_init_op, ops.Operation)) self.assertEqual(None, scaffold.local_init_feed_dict) self.assertTrue(isinstance(scaffold.saver, saver_lib.Saver)) @@ -107,8 +109,9 @@ def test_defaults_no_variables(self): self.assertTrue(isinstance(scaffold.init_op, ops.Operation)) self.assertEqual(None, scaffold.init_feed_dict) self.assertEqual(None, scaffold.init_fn) - self.assertTrue(isinstance(scaffold.ready_op, ops.Tensor)) - self.assertTrue(isinstance(scaffold.ready_for_local_init_op, ops.Tensor)) + self.assertTrue(isinstance(scaffold.ready_op, tensor.Tensor)) + self.assertTrue(isinstance( + scaffold.ready_for_local_init_op, tensor.Tensor)) self.assertTrue(isinstance(scaffold.local_init_op, ops.Operation)) self.assertEqual(None, scaffold.local_init_feed_dict) self.assertTrue(isinstance(scaffold.saver, saver_lib.Saver)) diff --git a/tensorflow/python/training/moving_averages.py b/tensorflow/python/training/moving_averages.py index 1f3d05a2163d2f..d310f3488f1524 100644 --- a/tensorflow/python/training/moving_averages.py +++ b/tensorflow/python/training/moving_averages.py @@ -17,6 +17,7 @@ from tensorflow.python.distribute import reduce_util as ds_reduce_util from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.ops import cond from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import init_ops @@ -531,7 +532,7 @@ def apply(self, var_list=None): if var_list is None: var_list = variables.trainable_variables() for v in var_list: - if (isinstance(v, ops.Tensor) + if (isinstance(v, tensor.Tensor) and ops.executing_eagerly_outside_functions()): raise TypeError( "tf.train.ExponentialMovingAverage does not support non-Variable" diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py index 4af9a1ebe43664..aa59f2e343cdf8 100644 --- a/tensorflow/python/training/optimizer.py +++ b/tensorflow/python/training/optimizer.py @@ -26,6 +26,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradients @@ -115,7 +116,7 @@ def target(self): return self._v._ref() # pylint: disable=protected-access def update_op(self, optimizer, g): - if isinstance(g, ops.Tensor): + if isinstance(g, tensor.Tensor): update_op = optimizer._apply_dense(g, self._v) # pylint: disable=protected-access if self._v.constraint is not None: with ops.control_dependencies([update_op]): @@ -197,7 +198,7 @@ def update_op(self, optimizer, g): def _get_processor(v): """The processor of v.""" if context.executing_eagerly(): - if isinstance(v, ops.Tensor): + if isinstance(v, tensor.Tensor): return _TensorProcessor(v) else: return _DenseResourceVariableProcessor(v) @@ -208,7 +209,7 @@ def _get_processor(v): return _DenseResourceVariableProcessor(v) if isinstance(v, variables.Variable): return _RefVariableProcessor(v) - if isinstance(v, ops.Tensor): + if isinstance(v, tensor.Tensor): return _TensorProcessor(v) raise NotImplementedError("Trying to optimize unsupported type ", v) @@ -690,7 +691,7 @@ def apply_gradients( raise TypeError( "Gradient must be convertible to a Tensor" " or IndexedSlices, or None: %s" % g) - if not isinstance(g, (ops.Tensor, indexed_slices.IndexedSlices)): + if not isinstance(g, (tensor.Tensor, indexed_slices.IndexedSlices)): raise TypeError( "Gradient must be a Tensor, IndexedSlices, or None: %s" % g) p = _get_processor(v) @@ -739,7 +740,7 @@ def apply_gradients( apply_updates = state_ops.assign_add(global_step, 1, name=name) if not context.executing_eagerly(): - if isinstance(apply_updates, ops.Tensor): + if isinstance(apply_updates, tensor.Tensor): apply_updates = apply_updates.op train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP) if apply_updates not in train_op: @@ -791,7 +792,7 @@ def update(v, g): except TypeError: raise TypeError("Gradient must be convertible to a Tensor" " or IndexedSlices, or None: %s" % g) - if not isinstance(g, (ops.Tensor, indexed_slices.IndexedSlices)): + if not isinstance(g, (tensor.Tensor, indexed_slices.IndexedSlices)): raise TypeError( "Gradient must be a Tensor, IndexedSlices, or None: %s" % g) p = _get_processor(v) @@ -834,7 +835,7 @@ def finish(self, update_ops): kwargs={"name": name}) if not context.executing_eagerly(): - if isinstance(apply_updates, ops.Tensor): + if isinstance(apply_updates, tensor.Tensor): apply_updates = apply_updates.op train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP) if apply_updates not in train_op: diff --git a/tensorflow/python/training/sync_replicas_optimizer.py b/tensorflow/python/training/sync_replicas_optimizer.py index 195c928764a0f4..26c26dc1ff7627 100644 --- a/tensorflow/python/training/sync_replicas_optimizer.py +++ b/tensorflow/python/training/sync_replicas_optimizer.py @@ -17,6 +17,7 @@ from tensorflow.python.distribute import distribute_lib from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import data_flow_ops @@ -277,7 +278,7 @@ def apply_gradients(self, grads_and_vars, global_step=None, name=None): if grad is None: aggregated_grad.append(None) # pass-through. continue - elif isinstance(grad, ops.Tensor): + elif isinstance(grad, tensor.Tensor): grad_accum = data_flow_ops.ConditionalAccumulator( grad.dtype, shape=var.get_shape(), diff --git a/tensorflow/python/training/training_util.py b/tensorflow/python/training/training_util.py index 050828f7637c6d..778ad9771f8591 100644 --- a/tensorflow/python/training/training_util.py +++ b/tensorflow/python/training/training_util.py @@ -17,6 +17,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import graph_io from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.ops import cond from tensorflow.python.ops import init_ops from tensorflow.python.ops import resource_variable_ops @@ -333,7 +334,7 @@ def assert_global_step(global_step_tensor): global_step_tensor: `Tensor` to test. """ if not (isinstance(global_step_tensor, variables.Variable) or - isinstance(global_step_tensor, ops.Tensor) or + isinstance(global_step_tensor, tensor.Tensor) or resource_variable_ops.is_resource_variable(global_step_tensor)): raise TypeError('Existing "global_step" must be a Variable or Tensor: %s.' % global_step_tensor) From 20b22c53244b83b2d734179220a840b19f251296 Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Mon, 10 Jul 2023 13:42:12 -0700 Subject: [PATCH 076/376] Add get_compatible_with_cloud to `service/gpu:gpu_asm_opts_util` PiperOrigin-RevId: 546968794 --- tensorflow/compiler/xla/service/gpu/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 316330826a364b..b2c7fd02999d60 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -2945,6 +2945,7 @@ cc_library( name = "gpu_asm_opts_util", srcs = ["gpu_asm_opts_util.cc"], hdrs = ["gpu_asm_opts_util.h"], + compatible_with = get_compatible_with_cloud(), copts = tsl_copts(), deps = [ "//tensorflow/compiler/xla:xla_proto_cc", From 52b4f7e4246dc36fb0d31fb746ac9049c50c5606 Mon Sep 17 00:00:00 2001 From: Jieying Luo Date: Mon, 10 Jul 2023 13:46:06 -0700 Subject: [PATCH 077/376] [TF:PJRT] In GpuDevice, always gets receive stream from tensorflow_accelerator_device_info()->default_context. In GpuDevice, tensorflow_accelerator_device_info()->default_context is GPUDeviceContext. When tensorflow_accelerator_device_info()->use_pjrt_tensor_buffer is true, the recv_dev_context is a PjRtDeviceContext which does not have a pointer to the stream. PiperOrigin-RevId: 546970116 --- tensorflow/core/common_runtime/gpu/gpu_util.cc | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/common_runtime/gpu/gpu_util.cc b/tensorflow/core/common_runtime/gpu/gpu_util.cc index fdc56d0a35f9f9..b699239fdb979b 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_util.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_util.cc @@ -219,8 +219,13 @@ void GPUUtil::DeviceToDeviceCopy( DeviceMemoryBase gpu_src_ptr(src_ptr, total_bytes); void* dst_ptr = GetBase(output); DeviceMemoryBase gpu_dst_ptr(dst_ptr, total_bytes); - auto recv_stream = - static_cast(recv_dev_context)->stream(); + // For GpuDevice, always gets receive stream from + // dst->tensorflow_accelerator_device_info()->default_context which is + // GPUDeviceContext. + stream_executor::Stream* recv_stream = + static_cast( + dst->tensorflow_accelerator_device_info()->default_context) + ->stream(); if (recv_stream == nullptr) { done(errors::Internal("No recv gpu stream is available.")); return; From 7bf78095e6e1309dbc5bd35dfb5e45606cc91924 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 10 Jul 2023 14:09:22 -0700 Subject: [PATCH 078/376] Adds a tiny amount of deterministic salt (< 0.01%) to each term in the objective function so that the optimal solution will be unique for the vast majority of models. Also removes some solver parameters -- originally added to achieve determinism a different way -- that were triggering a significantly slower CP-SAT recipe. Combined, these changes result in a solver that is (a) nearly always deterministic, and (b) lightning fast. PiperOrigin-RevId: 546979099 --- .../xla/hlo/experimental/auto_sharding/BUILD | 2 ++ .../auto_sharding/auto_sharding_solver.cc | 32 +++++++++++++------ .../auto_sharding/auto_sharding_solver.h | 1 + 3 files changed, 26 insertions(+), 9 deletions(-) diff --git a/tensorflow/compiler/xla/hlo/experimental/auto_sharding/BUILD b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/BUILD index 3b1817c2790e23..6f785714bcaf59 100644 --- a/tensorflow/compiler/xla/hlo/experimental/auto_sharding/BUILD +++ b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/BUILD @@ -59,6 +59,8 @@ cc_library( deps = [ ":auto_sharding_strategy", "//tensorflow/compiler/xla:statusor", + "//tensorflow/tsl/platform:hash", + "//tensorflow/tsl/platform:types", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@com_google_ortools//ortools/linear_solver", diff --git a/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc index 6e0cb4b9ae85ff..c31315ea83d8a7 100644 --- a/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc +++ b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -26,6 +27,8 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h" +#include "tensorflow/tsl/platform/hash.h" +#include "tensorflow/tsl/platform/types.h" #include "ortools/linear_solver/linear_solver.h" #include "ortools/linear_solver/linear_solver.pb.h" #ifdef PLATFORM_GOOGLE @@ -104,6 +107,13 @@ void PrintLargestInstructions( } } +// Adds deterministic noise to the coefficient using the name & salt multiplier. +void AddSalt(const std::string& name, double saltiplier, double* coeff) { + if (saltiplier <= 0.0) return; + const tsl::uint64 hash = tsl::Hash64(name); // stable across runs & platforms + *coeff *= 1.0 + saltiplier * hash / std::numeric_limits::max(); +} + // We formulate the auto sharding process as the following ILP problem: // Variables: // s[i]: Sharding strategy one-hot vector. @@ -162,12 +172,8 @@ AutoShardingSolverResult CallORToolsSolver( #ifdef PLATFORM_GOOGLE if (solver->ProblemType() == operations_research::MPSolver::SAT_INTEGER_PROGRAMMING) { - // Set random_seed, interleave_search and share_binary_clauses for - // determinism, and num_workers for parallelism. - solver_parameter_str = absl::StrCat( - "share_binary_clauses:false,random_seed:1,interleave_" - "search:true,num_workers:", - num_workers); + // Set num_workers for parallelism. + solver_parameter_str = absl::StrCat("num_workers:", num_workers); solver->SetSolverSpecificParametersAsString(solver_parameter_str); } #endif @@ -206,8 +212,10 @@ AutoShardingSolverResult CallORToolsSolver( for (size_t j = 0; j < s[i].size(); ++j) { double accumulated_coefficient = solver->MutableObjective()->GetCoefficient(s[i][j]); + double coefficient = request.c[i][j] + request.d[i][j]; + AddSalt(absl::StrCat(i, "S", j), request.saltiplier, &coefficient); solver->MutableObjective()->SetCoefficient( - s[i][j], accumulated_coefficient + request.c[i][j] + request.d[i][j]); + s[i][j], accumulated_coefficient + coefficient); } } // Edge costs @@ -215,8 +223,10 @@ AutoShardingSolverResult CallORToolsSolver( for (size_t j = 0; j < e[i].size(); ++j) { double accumulated_coefficient = solver->MutableObjective()->GetCoefficient(e[i][j]); + double coefficient = request.r[i][j]; + AddSalt(absl::StrCat(i, "E", j), request.saltiplier, &coefficient); solver->MutableObjective()->SetCoefficient( - e[i][j], accumulated_coefficient + request.r[i][j]); + e[i][j], accumulated_coefficient + coefficient); } } @@ -460,6 +470,7 @@ AutoShardingSolverResult CallORToolsSolver( } // Return value + double unsalted_objective = 0.0; std::vector chosen_strategy(request.num_nodes, -1), e_val(num_edges, -1); for (int i = 0; i < request.num_nodes; ++i) { @@ -467,6 +478,7 @@ AutoShardingSolverResult CallORToolsSolver( // if lhs == 1 if (s[i][j]->solution_value() > 0.5) { chosen_strategy[i] = j; + unsalted_objective += request.c[i][j] + request.d[i][j]; break; } } @@ -476,11 +488,13 @@ AutoShardingSolverResult CallORToolsSolver( // if lhs == 1 if (e[i][j]->solution_value() > 0.5) { e_val[i] = j; + unsalted_objective += request.r[i][j]; break; } } } + LOG(INFO) << "Unsalted objective value: " << unsalted_objective; LOG(INFO) << "N = " << request.num_nodes; if (request.memory_budget < 0) { LOG(INFO) << "memory budget: -1"; @@ -492,7 +506,7 @@ AutoShardingSolverResult CallORToolsSolver( request.instruction_names); return AutoShardingSolverResult( std::make_tuple(std::move(chosen_strategy), std::move(e_val), - solver->Objective().Value()), + unsalted_objective), false); } diff --git a/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h index 5225b91e143c43..88bc578bceec56 100644 --- a/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h +++ b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h @@ -44,6 +44,7 @@ struct AutoShardingSolverRequest { std::vector instruction_names; std::optional solver_timeout_in_seconds; bool crash_at_infinity_costs_check = false; + double saltiplier = 0.0001; // Modifies each objective term by at most 0.01% }; struct AutoShardingSolverResult { From 50904301327580580324e9744d4ea27ef3cbb179 Mon Sep 17 00:00:00 2001 From: Fiona Lang Date: Mon, 10 Jul 2023 14:29:03 -0700 Subject: [PATCH 079/376] Update ops.Tensor references to //third_party/tensorflow/python/framework/tensor.py. PiperOrigin-RevId: 546984067 --- .../python/kernel_tests/array_ops/BUILD | 11 ++- .../kernel_tests/array_ops/array_ops_test.py | 80 ++++++++++--------- .../array_ops/constant_op_test.py | 5 +- .../python/kernel_tests/data_structures/BUILD | 14 +++- .../data_structures/barrier_ops_test.py | 3 +- .../data_structures/fifo_queue_test.py | 11 +-- .../padding_fifo_queue_test.py | 7 +- tensorflow/python/kernel_tests/linalg/BUILD | 8 +- .../kernel_tests/linalg/linalg_ops_test.py | 5 +- .../linalg/linear_operator_util_test.py | 3 +- .../python/kernel_tests/variables/BUILD | 6 +- .../variables/resource_variable_ops_test.py | 5 +- .../kernel_tests/variables/variables_test.py | 3 +- tensorflow/python/ops/numpy_ops/tests/BUILD | 1 - .../python/ops/numpy_ops/tests/extensions.py | 14 ++-- .../python/ops/numpy_ops/tests/test_util.py | 3 +- tensorflow/python/util/BUILD | 2 + tensorflow/python/util/dispatch_test.py | 18 +++-- tensorflow/python/util/variable_utils_test.py | 11 +-- 19 files changed, 122 insertions(+), 88 deletions(-) diff --git a/tensorflow/python/kernel_tests/array_ops/BUILD b/tensorflow/python/kernel_tests/array_ops/BUILD index 771439070ec402..d5a07b25752942 100644 --- a/tensorflow/python/kernel_tests/array_ops/BUILD +++ b/tensorflow/python/kernel_tests/array_ops/BUILD @@ -23,10 +23,12 @@ cuda_py_strict_test( "//tensorflow/python/eager:def_function", "//tensorflow/python/framework:config", "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:errors", - "//tensorflow/python/framework:for_generated_wrappers", + "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", - "//tensorflow/python/framework:tensor_spec", + "//tensorflow/python/framework:tensor", + "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:test_lib", "//tensorflow/python/framework:test_ops", "//tensorflow/python/ops:array_ops", @@ -216,8 +218,11 @@ cuda_py_strict_test( "//tensorflow/python/eager:def_function", "//tensorflow/python/framework", "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:errors", - "//tensorflow/python/framework:for_generated_wrappers", + "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", + "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:gradient_checker", diff --git a/tensorflow/python/kernel_tests/array_ops/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops/array_ops_test.py index 7090bd7a3c49ae..2cfeb07d558154 100644 --- a/tensorflow/python/kernel_tests/array_ops/array_ops_test.py +++ b/tensorflow/python/kernel_tests/array_ops/array_ops_test.py @@ -32,8 +32,8 @@ from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import test_ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops @@ -90,7 +90,7 @@ def testNonBatchMatrixDynamicallyDefined(self): expected_transposed = [[1, 4], [2, 5], [3, 6]] # Shape (3, 2) @def_function.function(input_signature=[ - tensor_spec.TensorSpec(shape=None, dtype=dtypes.int32) + tensor_lib.TensorSpec(shape=None, dtype=dtypes.int32) ]) def transpose(matrix): self.assertIs(matrix.shape.ndims, None) @@ -109,7 +109,7 @@ def testBatchMatrixDynamicallyDefined(self): expected_transposed = [matrix_0_t, matrix_1_t] # Shape (2, 3, 2) @def_function.function(input_signature=[ - tensor_spec.TensorSpec(shape=None, dtype=dtypes.int32) + tensor_lib.TensorSpec(shape=None, dtype=dtypes.int32) ]) def transpose(matrix): self.assertIs(matrix.shape.ndims, None) @@ -244,8 +244,8 @@ def func(ph_tensor, ph_mask): return array_ops.boolean_mask(ph_tensor, ph_mask) f = func.get_concrete_function( - tensor_spec.TensorSpec(None, dtypes.int32), - tensor_spec.TensorSpec([None], dtypes.bool)) + tensor_lib.TensorSpec(None, dtypes.int32), + tensor_lib.TensorSpec([None], dtypes.bool)) arr = np.array([[1, 2], [3, 4]], np.int32) mask = np.array([False, True]) masked_tensor = f(arr, mask) @@ -260,8 +260,8 @@ def func(tensor, mask): with self.assertRaisesRegex(ValueError, "dimensions must be specified"): _ = func.get_concrete_function( - tensor_spec.TensorSpec([None, 2], dtypes.int32), - tensor_spec.TensorSpec(None, dtypes.bool)) + tensor_lib.TensorSpec([None, 2], dtypes.int32), + tensor_lib.TensorSpec(None, dtypes.bool)) def testMaskHasMoreDimsThanTensorRaises(self): mask = [[True, True], [False, False]] @@ -314,7 +314,7 @@ def testMaskWithAxisNonConstTensor(self): @def_function.function( autograph=False, input_signature=[ - tensor_spec.TensorSpec(shape=None, dtype=dtypes.int32) + tensor_lib.TensorSpec(shape=None, dtype=dtypes.int32) ]) def f(axis): return array_ops.boolean_mask([1, 2, 3], [True, False, True], axis=axis) @@ -595,10 +595,12 @@ def casts_to_bool_nparray(x): except NotImplementedError: return False - if isinstance(spec, bool) or \ - (isinstance(spec, ops.Tensor) and spec.dtype == dtypes.bool) or \ - (isinstance(spec, np.ndarray) and spec.dtype == bool) or \ - (isinstance(spec, (list, tuple)) and casts_to_bool_nparray(spec)): + if ( + isinstance(spec, bool) + or (isinstance(spec, tensor_lib.Tensor) and spec.dtype == dtypes.bool) + or (isinstance(spec, np.ndarray) and spec.dtype == bool) + or (isinstance(spec, (list, tuple)) and casts_to_bool_nparray(spec)) + ): tensor = self.test.evaluate(op) np_spec = eval_if_tensor(spec) self.test.assertAllEqual(self.x_np[np_spec], tensor) @@ -753,7 +755,7 @@ def func(inp): return inp[array_ops.newaxis, :, 0] f = func.get_concrete_function( - tensor_spec.TensorSpec([2, 2], dtypes.int16)) + tensor_lib.TensorSpec([2, 2], dtypes.int16)) # TODO(b/190416665): Allow the constant to be eagerly copied/created on # the GPU. @@ -892,7 +894,7 @@ def f(x): y = x[...] self.assertAllEqual(y.get_shape().ndims, None) - _ = f.get_concrete_function(tensor_spec.TensorSpec(None, dtypes.float32)) + _ = f.get_concrete_function(tensor_lib.TensorSpec(None, dtypes.float32)) def testScalarInput(self): c = constant_op.constant(3) @@ -916,7 +918,7 @@ def f1(x): tensor_shape.TensorShape([2, None, 7])) _ = f1.get_concrete_function( - tensor_spec.TensorSpec((5, None, 7), dtypes.float32)) + tensor_lib.TensorSpec((5, None, 7), dtypes.float32)) @def_function.function def f2(x): @@ -925,7 +927,7 @@ def f2(x): None])) _ = f2.get_concrete_function( - tensor_spec.TensorSpec((5, None, 7), dtypes.float32)) + tensor_lib.TensorSpec((5, None, 7), dtypes.float32)) @def_function.function def f3(x): @@ -934,7 +936,7 @@ def f3(x): None])) _ = f3.get_concrete_function( - tensor_spec.TensorSpec((5, None, 7), dtypes.float32)) + tensor_lib.TensorSpec((5, None, 7), dtypes.float32)) @def_function.function def f4(x): @@ -943,7 +945,7 @@ def f4(x): tensor_shape.TensorShape([2, None, 2])) _ = f4.get_concrete_function( - tensor_spec.TensorSpec((5, None, 7), dtypes.float32)) + tensor_lib.TensorSpec((5, None, 7), dtypes.float32)) @def_function.function def f5(x): @@ -952,7 +954,7 @@ def f5(x): tensor_shape.TensorShape([2, None, 0])) _ = f5.get_concrete_function( - tensor_spec.TensorSpec((5, None, 7), dtypes.float32)) + tensor_lib.TensorSpec((5, None, 7), dtypes.float32)) @def_function.function def f6(x): @@ -961,7 +963,7 @@ def f6(x): tensor_shape.TensorShape([2, None, 1, 0])) _ = f6.get_concrete_function( - tensor_spec.TensorSpec((5, None, 7), dtypes.float32)) + tensor_lib.TensorSpec((5, None, 7), dtypes.float32)) @def_function.function def f7(x): @@ -970,7 +972,7 @@ def f7(x): tensor_shape.TensorShape([2, None, 1, 0])) _ = f7.get_concrete_function( - tensor_spec.TensorSpec((5, None, 7), dtypes.float32)) + tensor_lib.TensorSpec((5, None, 7), dtypes.float32)) @def_function.function def f8(x): @@ -979,7 +981,7 @@ def f8(x): tensor_shape.TensorShape([2, None, 1, 0])) _ = f8.get_concrete_function( - tensor_spec.TensorSpec((5, None, 7), dtypes.float32)) + tensor_lib.TensorSpec((5, None, 7), dtypes.float32)) @def_function.function def f9(x): @@ -988,7 +990,7 @@ def f9(x): tensor_shape.TensorShape([1, None, 1, 0])) _ = f9.get_concrete_function( - tensor_spec.TensorSpec((5, None, 7), dtypes.float32)) + tensor_lib.TensorSpec((5, None, 7), dtypes.float32)) @def_function.function def f10(x): @@ -997,7 +999,7 @@ def f10(x): tensor_shape.TensorShape([5, None, 1, 4])) _ = f10.get_concrete_function( - tensor_spec.TensorSpec((5, None, 7), dtypes.float32)) + tensor_lib.TensorSpec((5, None, 7), dtypes.float32)) def testTensorValuedIndexShape(self): with self.session(): @@ -1008,8 +1010,8 @@ def f1(x, y): self.tensorShapeEqual(z.get_shape(), tensor_shape.TensorShape([3, 7])) _ = f1.get_concrete_function( - tensor_spec.TensorSpec((5, 3, 7)), - tensor_spec.TensorSpec((), dtypes.int32)) + tensor_lib.TensorSpec((5, 3, 7)), + tensor_lib.TensorSpec((), dtypes.int32)) @def_function.function def f2(x, y): @@ -1017,8 +1019,8 @@ def f2(x, y): self.tensorShapeEqual(z.get_shape(), tensor_shape.TensorShape([3, 7])) _ = f2.get_concrete_function( - tensor_spec.TensorSpec((5, 3, 7)), - tensor_spec.TensorSpec((), dtypes.int32)) + tensor_lib.TensorSpec((5, 3, 7)), + tensor_lib.TensorSpec((), dtypes.int32)) @def_function.function def f3(x, y): @@ -1026,8 +1028,8 @@ def f3(x, y): self.tensorShapeEqual(z.get_shape(), tensor_shape.TensorShape([2, 7])) _ = f3.get_concrete_function( - tensor_spec.TensorSpec((5, 3, 7)), - tensor_spec.TensorSpec((), dtypes.int32)) + tensor_lib.TensorSpec((5, 3, 7)), + tensor_lib.TensorSpec((), dtypes.int32)) @def_function.function def f4(x, y, s): @@ -1036,9 +1038,9 @@ def f4(x, y, s): 7])) _ = f4.get_concrete_function( - tensor_spec.TensorSpec((5, 3, 7)), - tensor_spec.TensorSpec((), dtypes.int32), - tensor_spec.TensorSpec((), dtypes.int32)) + tensor_lib.TensorSpec((5, 3, 7)), + tensor_lib.TensorSpec((), dtypes.int32), + tensor_lib.TensorSpec((), dtypes.int32)) class GradSliceChecker(object): @@ -1076,7 +1078,7 @@ def __getitem__(self, spec): # compute analytic gradient for slice np_val_grad = (2 * self.varnp * self.varnp) np_sliceval_grad = np.zeros(self.var.get_shape()) - if isinstance(spec, ops.Tensor): + if isinstance(spec, tensor_lib.Tensor): spec = self.test.evaluate(spec) np_sliceval_grad[spec] = np_val_grad[spec] # verify gradient @@ -1615,7 +1617,7 @@ def testIdentityVariable(self): v = resource_variable_ops.ResourceVariable(1.0) self.evaluate(v.initializer) result = array_ops.identity(v) - self.assertIsInstance(result, ops.Tensor) + self.assertIsInstance(result, tensor_lib.Tensor) self.assertAllEqual(result, v) @@ -2387,8 +2389,8 @@ def func(params, indices): params=params, indices=indices, batch_dims=batch_dims) # pylint: disable=cell-var-from-loop f = func.get_concrete_function( - tensor_spec.TensorSpec(params_ph_shape, dtypes.float32), - tensor_spec.TensorSpec(indices_ph_shape, dtypes.int32)) + tensor_lib.TensorSpec(params_ph_shape, dtypes.float32), + tensor_lib.TensorSpec(indices_ph_shape, dtypes.int32)) params_val = np.ones(dtype=np.float32, shape=params_shape) indices_val = np.ones(dtype=np.int32, shape=indices_shape) @@ -2419,7 +2421,7 @@ def testRepeat(self, array, repeats, axis): array = np.array(array) @def_function.function( - input_signature=[tensor_spec.TensorSpec(None, dtypes.int32)] * 2) + input_signature=[tensor_lib.TensorSpec(None, dtypes.int32)] * 2) def repeat_fn(array, repeats): return array_ops.repeat(array, repeats, axis) @@ -2560,7 +2562,7 @@ def stop_gradient_f(x): y = stop_gradient_f(x) self.assertIsNone(tape.gradient(y, x)) # stop_gradient converts ResourceVariable to Tensor - self.assertIsInstance(y, ops.Tensor) + self.assertIsInstance(y, tensor_lib.Tensor) self.assertAllEqual(y, x) if __name__ == "__main__": diff --git a/tensorflow/python/kernel_tests/array_ops/constant_op_test.py b/tensorflow/python/kernel_tests/array_ops/constant_op_test.py index a580075f3fa3d7..eb974c2ddc342d 100644 --- a/tensorflow/python/kernel_tests/array_ops/constant_op_test.py +++ b/tensorflow/python/kernel_tests/array_ops/constant_op_test.py @@ -26,6 +26,7 @@ from tensorflow.python.framework import errors_impl from tensorflow.python.framework import importer from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops @@ -296,7 +297,7 @@ def testAsTensorForTensorInput(self): def testAsTensorForNonTensorInput(self): with ops.Graph().as_default(): x = ops.convert_to_tensor(10.0) - self.assertTrue(isinstance(x, ops.Tensor)) + self.assertTrue(isinstance(x, tensor.Tensor)) def testAsTensorForShapeInput(self): with self.cached_session(): @@ -381,7 +382,7 @@ def testIdTensor(self): with ops.Graph().as_default(): x = constant_op.constant(2.0, shape=[6], name="input") id_op = array_ops.identity(x, name="id") - self.assertTrue(isinstance(id_op.op.inputs[0], ops.Tensor)) + self.assertTrue(isinstance(id_op.op.inputs[0], tensor.Tensor)) self.assertProtoEquals("name: 'id' op: 'Identity' input: 'input' " "attr { key: 'T' value { type: DT_FLOAT } }", id_op.op.node_def) diff --git a/tensorflow/python/kernel_tests/data_structures/BUILD b/tensorflow/python/kernel_tests/data_structures/BUILD index 8ed917fa2add14..bdce5457d26e9a 100644 --- a/tensorflow/python/kernel_tests/data_structures/BUILD +++ b/tensorflow/python/kernel_tests/data_structures/BUILD @@ -16,8 +16,10 @@ tf_py_strict_test( "no_mac", # TODO(b/129706424): Re-enable this test on Mac. ], deps = [ + "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:errors", - "//tensorflow/python/framework:for_generated_wrappers", + "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:data_flow_ops", "//tensorflow/python/platform:client_testlib", @@ -96,8 +98,11 @@ tf_py_strict_test( "//tensorflow/python/eager:context", "//tensorflow/python/eager:def_function", "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:errors", - "//tensorflow/python/framework:for_generated_wrappers", + "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", + "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:test_lib", "//tensorflow/python/module", "//tensorflow/python/ops:array_ops", @@ -247,8 +252,11 @@ cuda_py_strict_test( srcs = ["padding_fifo_queue_test.py"], deps = [ "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:errors", - "//tensorflow/python/framework:for_generated_wrappers", + "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", + "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:data_flow_ops", diff --git a/tensorflow/python/kernel_tests/data_structures/barrier_ops_test.py b/tensorflow/python/kernel_tests/data_structures/barrier_ops_test.py index 112c7454a99094..ec10af3f818ae9 100644 --- a/tensorflow/python/kernel_tests/data_structures/barrier_ops_test.py +++ b/tensorflow/python/kernel_tests/data_structures/barrier_ops_test.py @@ -21,6 +21,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import data_flow_ops from tensorflow.python.platform import test @@ -35,7 +36,7 @@ def testConstructorWithShapes(self): shapes=((1, 2, 3), (8,)), shared_name="B", name="B") - self.assertTrue(isinstance(b.barrier_ref, ops.Tensor)) + self.assertTrue(isinstance(b.barrier_ref, tensor.Tensor)) self.assertProtoEquals(""" name:'B' op:'Barrier' attr { diff --git a/tensorflow/python/kernel_tests/data_structures/fifo_queue_test.py b/tensorflow/python/kernel_tests/data_structures/fifo_queue_test.py index 8111a6997b9d2e..8ee4f9c03df151 100644 --- a/tensorflow/python/kernel_tests/data_structures/fifo_queue_test.py +++ b/tensorflow/python/kernel_tests/data_structures/fifo_queue_test.py @@ -28,6 +28,7 @@ from tensorflow.python.framework import dtypes as dtypes_lib from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util from tensorflow.python.module import module @@ -45,7 +46,7 @@ class FIFOQueueTest(test.TestCase): def testConstructor(self): with ops.Graph().as_default(): q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, name="Q") - self.assertTrue(isinstance(q.queue_ref, ops.Tensor)) + self.assertTrue(isinstance(q.queue_ref, tensor.Tensor)) self.assertProtoEquals(""" name:'Q' device: "/device:CPU:*" op:'FIFOQueueV2' attr { key: 'component_types' value { list { type: DT_FLOAT } } } @@ -61,7 +62,7 @@ def testMultiQueueConstructor(self): 5, (dtypes_lib.int32, dtypes_lib.float32), shared_name="foo", name="Q") - self.assertTrue(isinstance(q.queue_ref, ops.Tensor)) + self.assertTrue(isinstance(q.queue_ref, tensor.Tensor)) self.assertProtoEquals(""" name:'Q' device: "/device:CPU:*" op:'FIFOQueueV2' attr { key: 'component_types' value { list { @@ -80,7 +81,7 @@ def testConstructorWithShapes(self): shapes=(tensor_shape.TensorShape([1, 1, 2, 3]), tensor_shape.TensorShape([5, 8])), name="Q") - self.assertTrue(isinstance(q.queue_ref, ops.Tensor)) + self.assertTrue(isinstance(q.queue_ref, tensor.Tensor)) self.assertProtoEquals(""" name:'Q' device: "/device:CPU:*" op:'FIFOQueueV2' attr { key: 'component_types' value { list { @@ -1645,7 +1646,7 @@ def testConstructor(self): names=("i", "j"), shared_name="foo", name="Q") - self.assertTrue(isinstance(q.queue_ref, ops.Tensor)) + self.assertTrue(isinstance(q.queue_ref, tensor.Tensor)) self.assertProtoEquals(""" name:'Q' device: "/device:CPU:*" op:'FIFOQueueV2' attr { key: 'component_types' value { list { @@ -1666,7 +1667,7 @@ def testConstructorWithShapes(self): shapes=(tensor_shape.TensorShape([1, 1, 2, 3]), tensor_shape.TensorShape([5, 8])), name="Q") - self.assertTrue(isinstance(q.queue_ref, ops.Tensor)) + self.assertTrue(isinstance(q.queue_ref, tensor.Tensor)) self.assertProtoEquals(""" name:'Q' device: "/device:CPU:*" op:'FIFOQueueV2' attr { key: 'component_types' value { list { diff --git a/tensorflow/python/kernel_tests/data_structures/padding_fifo_queue_test.py b/tensorflow/python/kernel_tests/data_structures/padding_fifo_queue_test.py index db21e1bee59931..1b4c9d6b961ac7 100644 --- a/tensorflow/python/kernel_tests/data_structures/padding_fifo_queue_test.py +++ b/tensorflow/python/kernel_tests/data_structures/padding_fifo_queue_test.py @@ -23,6 +23,7 @@ from tensorflow.python.framework import dtypes as dtypes_lib from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops @@ -37,7 +38,7 @@ def testConstructor(self): with ops.Graph().as_default(): q = data_flow_ops.PaddingFIFOQueue( 10, dtypes_lib.float32, ((None,),), name="Q") - self.assertTrue(isinstance(q.queue_ref, ops.Tensor)) + self.assertTrue(isinstance(q.queue_ref, tensor.Tensor)) self.assertProtoEquals(""" name:'Q' op:'PaddingFIFOQueueV2' attr { key: 'component_types' value { list { type: DT_FLOAT } } } @@ -53,7 +54,7 @@ def testMultiQueueConstructor(self): 5, (dtypes_lib.int32, dtypes_lib.float32), ((), ()), shared_name="foo", name="Q") - self.assertTrue(isinstance(q.queue_ref, ops.Tensor)) + self.assertTrue(isinstance(q.queue_ref, tensor.Tensor)) self.assertProtoEquals(""" name:'Q' op:'PaddingFIFOQueueV2' attr { key: 'component_types' value { list { @@ -72,7 +73,7 @@ def testConstructorWithShapes(self): shapes=(tensor_shape.TensorShape([1, 1, 2, 3]), tensor_shape.TensorShape([5, 8])), name="Q") - self.assertTrue(isinstance(q.queue_ref, ops.Tensor)) + self.assertTrue(isinstance(q.queue_ref, tensor.Tensor)) self.assertProtoEquals(""" name:'Q' op:'PaddingFIFOQueueV2' attr { key: 'component_types' value { list { diff --git a/tensorflow/python/kernel_tests/linalg/BUILD b/tensorflow/python/kernel_tests/linalg/BUILD index fa8f91b61fb7a2..ada2c64403c01b 100644 --- a/tensorflow/python/kernel_tests/linalg/BUILD +++ b/tensorflow/python/kernel_tests/linalg/BUILD @@ -132,7 +132,9 @@ cuda_py_strict_test( tags = ["no_windows_gpu"], deps = [ "//tensorflow/python/framework:constant_op", - "//tensorflow/python/framework:for_generated_wrappers", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:linalg_ops", @@ -557,7 +559,9 @@ cuda_py_strict_test( shard_count = 5, tags = ["optonly"], deps = [ - "//tensorflow/python/framework:for_generated_wrappers", + "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", + "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:linalg_ops", "//tensorflow/python/ops:math_ops", diff --git a/tensorflow/python/kernel_tests/linalg/linalg_ops_test.py b/tensorflow/python/kernel_tests/linalg/linalg_ops_test.py index 827d1545e716cb..88d51257b517be 100644 --- a/tensorflow/python/kernel_tests/linalg/linalg_ops_test.py +++ b/tensorflow/python/kernel_tests/linalg/linalg_ops_test.py @@ -23,6 +23,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import linalg_ops @@ -180,10 +181,10 @@ def testShapeInferenceStaticBatchWith(self, num_rows_fn, num_columns_fn): batch_shape=batch_shape) self.assertEqual(4, identity_matrix.shape.ndims) self.assertEqual((2, 3), identity_matrix.shape[:2]) - if num_rows is not None and not isinstance(num_rows, ops.Tensor): + if num_rows is not None and not isinstance(num_rows, tensor.Tensor): self.assertEqual(2, identity_matrix.shape[-2]) - if num_columns is not None and not isinstance(num_columns, ops.Tensor): + if num_columns is not None and not isinstance(num_columns, tensor.Tensor): self.assertEqual(3, identity_matrix.shape[-1]) @parameterized.parameters( diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_util_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_util_test.py index e28d1b2cae2cde..e1ca2f5ce6bcad 100644 --- a/tensorflow/python/kernel_tests/linalg/linear_operator_util_test.py +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_util_test.py @@ -17,6 +17,7 @@ import numpy as np from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import linalg_ops @@ -97,7 +98,7 @@ def test_zero_batch_matrices_returned_as_empty_list(self): def test_one_batch_matrix_returned_after_tensor_conversion(self): arr = rng.rand(2, 3, 4) tensor, = linear_operator_util.broadcast_matrix_batch_dims([arr]) - self.assertTrue(isinstance(tensor, ops.Tensor)) + self.assertTrue(isinstance(tensor, tensor_lib.Tensor)) self.assertAllClose(arr, self.evaluate(tensor)) diff --git a/tensorflow/python/kernel_tests/variables/BUILD b/tensorflow/python/kernel_tests/variables/BUILD index 0d965052fb029d..9a945cf7e3d0e2 100644 --- a/tensorflow/python/kernel_tests/variables/BUILD +++ b/tensorflow/python/kernel_tests/variables/BUILD @@ -84,6 +84,7 @@ cuda_py_strict_test( "//tensorflow/python/framework:indexed_slices", "//tensorflow/python/framework:memory_checker", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/framework:test_lib", @@ -174,8 +175,11 @@ tf_py_strict_test( "//tensorflow/python/eager:context", "//tensorflow/python/eager:def_function", "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:errors", - "//tensorflow/python/framework:for_generated_wrappers", + "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", + "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:cond", diff --git a/tensorflow/python/kernel_tests/variables/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/variables/resource_variable_ops_test.py index ebe9788667c28f..6920466f64cb7a 100644 --- a/tensorflow/python/kernel_tests/variables/resource_variable_ops_test.py +++ b/tensorflow/python/kernel_tests/variables/resource_variable_ops_test.py @@ -38,6 +38,7 @@ from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import memory_checker from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.framework import test_ops @@ -989,7 +990,7 @@ def gradient_func(*grad): result = tape.gradient(out, v) self.assertAllEqual(out, 5.) - self.assertIsInstance(result, ops.Tensor) + self.assertIsInstance(result, tensor_lib.Tensor) self.assertAllEqual(result, 2.) def testToFromProtoCachedValue(self): @@ -1805,7 +1806,7 @@ def testCompositeTensorTypeSpec(self): def testVariableInExtensionType(self): class MaskVariable(extension_type.ExtensionType): variable: resource_variable_ops.ResourceVariable - mask: ops.Tensor + mask: tensor_lib.Tensor v = resource_variable_ops.ResourceVariable([1., 2.]) self.evaluate(v.initializer) diff --git a/tensorflow/python/kernel_tests/variables/variables_test.py b/tensorflow/python/kernel_tests/variables/variables_test.py index 929452b104e7c6..45b88857090313 100644 --- a/tensorflow/python/kernel_tests/variables/variables_test.py +++ b/tensorflow/python/kernel_tests/variables/variables_test.py @@ -26,6 +26,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops @@ -377,7 +378,7 @@ def testOperatorWrapping(self): for attr in functools.WRAPPER_ASSIGNMENTS: self.assertEqual( getattr(variables.Variable.__add__, attr), - getattr(ops.Tensor.__add__, attr)) + getattr(tensor.Tensor.__add__, attr)) @test_util.run_deprecated_v1 def testOperators(self): diff --git a/tensorflow/python/ops/numpy_ops/tests/BUILD b/tensorflow/python/ops/numpy_ops/tests/BUILD index 1b1c3b1442be3a..a9cfd781790e35 100644 --- a/tensorflow/python/ops/numpy_ops/tests/BUILD +++ b/tensorflow/python/ops/numpy_ops/tests/BUILD @@ -23,7 +23,6 @@ py_strict_library( ":extensions", "//tensorflow:tensorflow_py", "//tensorflow/python/framework:dtypes", - "//tensorflow/python/framework:ops", "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:gradient_checker_v2", "//tensorflow/python/ops/numpy_ops:np_array_ops", diff --git a/tensorflow/python/ops/numpy_ops/tests/extensions.py b/tensorflow/python/ops/numpy_ops/tests/extensions.py index b915dff755bc42..901a662ff80c7e 100644 --- a/tensorflow/python/ops/numpy_ops/tests/extensions.py +++ b/tensorflow/python/ops/numpy_ops/tests/extensions.py @@ -447,7 +447,7 @@ def eval_on_shapes(f, static_argnums=(), allow_static_outputs=False): def abstractify(args): def _abstractify(x): x = _canonicalize_jit_arg(x) - if isinstance(x, (ops.Tensor, tf_np.ndarray)): + if isinstance(x, (tensor_lib.Tensor, tf_np.ndarray)): return tensor_lib.TensorSpec(x.shape, x.dtype) else: return x @@ -472,7 +472,7 @@ def recorder(args, kwargs, res): def is_tensor_like(x): if hasattr(x, "_type_spec"): return True # x is a CompositeTensor - return isinstance(x, (tf_np.ndarray, ops.Tensor)) + return isinstance(x, (tf_np.ndarray, tensor_lib.Tensor)) py_values = nest.map_structure( lambda x: None if is_tensor_like(x) else x, res ) @@ -494,7 +494,7 @@ def is_tensor_like(x): # pylint: disable=missing-docstring def f_return(*args): def to_tensor_spec(x): - if isinstance(x, ops.Tensor): + if isinstance(x, tensor_lib.Tensor): return tensor_lib.TensorSpec(x.shape, x.dtype) else: return x @@ -1574,14 +1574,14 @@ def dataset_as_numpy(dataset): # Type check for Tensors and Datasets for ds_el in flat_ds: - if not isinstance(ds_el, (ops.Tensor, dataset_ops.DatasetV2)): + if not isinstance(ds_el, (tensor_lib.Tensor, dataset_ops.DatasetV2)): types = nest.map_structure(type, nested_ds) raise ValueError("Arguments to dataset_as_numpy must be (possibly nested " "structure of) tf.Tensors or tf.data.Datasets. Got: %s" % types) for ds_el in flat_ds: - if isinstance(ds_el, ops.Tensor): + if isinstance(ds_el, tensor_lib.Tensor): np_el = tf_np.asarray(ds_el) elif isinstance(ds_el, dataset_ops.DatasetV2): np_el = _eager_dataset_iterator(ds_el) @@ -1888,7 +1888,7 @@ def wrapper(*args): flattened_input_args = nest.flatten(args) flattened_per_device_args = [[] for _ in devices] for arg in flattened_input_args: - if isinstance(arg, ops.Tensor): + if isinstance(arg, tensor_lib.Tensor): # TODO(nareshmodi): Try and use the dynamic shape instead. if (not arg.shape.rank) or arg.shape[0] != len(devices): # TODO(nareshmodi): Fix this restriction @@ -1932,7 +1932,7 @@ def wrapper(*args): tensors = [] for j, device in enumerate(devices): assert isinstance( - flattened_results[j][i], ops.Tensor + flattened_results[j][i], tensor_lib.Tensor ), "currently only tensor return items are supported" tensors.append(flattened_results[j][i]) final_tree.append(ShardedNdArray(tensors)) diff --git a/tensorflow/python/ops/numpy_ops/tests/test_util.py b/tensorflow/python/ops/numpy_ops/tests/test_util.py index 27840cc70251fc..cf178a3f5dfbcc 100644 --- a/tensorflow/python/ops/numpy_ops/tests/test_util.py +++ b/tensorflow/python/ops/numpy_ops/tests/test_util.py @@ -32,7 +32,6 @@ import numpy.random as npr from tensorflow.python.util import nest -from tensorflow.python.framework import ops from tensorflow.python.framework import tensor from tensorflow.python.framework import dtypes from tensorflow.python.ops import gradient_checker_v2 @@ -83,7 +82,7 @@ def _dtype(x): - if isinstance(x, ops.Tensor): + if isinstance(x, tensor.Tensor): return x.dtype.as_numpy_dtype return (getattr(x, 'dtype', None) or onp.dtype(python_scalar_dtypes.get(type(x), None)) or diff --git a/tensorflow/python/util/BUILD b/tensorflow/python/util/BUILD index 61b5377ecc8f41..9d295572fe3002 100644 --- a/tensorflow/python/util/BUILD +++ b/tensorflow/python/util/BUILD @@ -275,6 +275,7 @@ tf_py_strict_test( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:extension_type", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_conversion", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:array_ops", @@ -1113,6 +1114,7 @@ tf_py_strict_test( "//tensorflow/python/framework:composite_tensor", "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:resource_variable_ops", "//tensorflow/python/ops:variables", diff --git a/tensorflow/python/util/dispatch_test.py b/tensorflow/python/util/dispatch_test.py index db01441afba9f5..7bb8e8f8898f6a 100644 --- a/tensorflow/python/util/dispatch_test.py +++ b/tensorflow/python/util/dispatch_test.py @@ -23,6 +23,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import extension_type from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_conversion from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops @@ -105,13 +106,13 @@ def is_tensor_like(self): @classmethod def _overload_all_operators(cls): # pylint: disable=invalid-name """Register overloads for all operators.""" - for operator in ops.Tensor.OVERLOADABLE_OPERATORS: + for operator in tensor_lib.Tensor.OVERLOADABLE_OPERATORS: cls._overload_operator(operator) @classmethod def _overload_operator(cls, operator): # pylint: disable=invalid-name - """Overload an operator with the same overloading as `ops.Tensor`.""" - tensor_oper = getattr(ops.Tensor, operator) + """Overload an operator with the same overloading as `tensor_lib.Tensor`.""" + tensor_oper = getattr(tensor_lib.Tensor, operator) # Compatibility with Python 2: # Python 2 unbound methods have type checks for the first arg, @@ -459,13 +460,13 @@ def testGlobalDispatcherLinearOperators(self): class MaskedTensor(extension_type.ExtensionType): """Simple ExtensionType for testing v2 dispatch.""" - values: ops.Tensor - mask: ops.Tensor + values: tensor_lib.Tensor + mask: tensor_lib.Tensor class SillyTensor(extension_type.ExtensionType): """Simple ExtensionType for testing v2 dispatch.""" - value: ops.Tensor + value: tensor_lib.Tensor how_silly: float @@ -565,7 +566,7 @@ def masked_concat(values, axis, name=None): dispatch.unregister_dispatch_for(masked_concat) def testDispatchForUnion(self): - MaybeMasked = typing.Union[MaskedTensor, ops.Tensor] + MaybeMasked = typing.Union[MaskedTensor, tensor_lib.Tensor] @dispatch.dispatch_for_api(math_ops.add, { "x": MaybeMasked, @@ -936,7 +937,8 @@ def testGetApisWithTypeBasedDispatch(self): self.assertIn(array_ops.concat, dispatch_apis) def testTypeBasedDispatchTargetsFor(self): - MaskedTensorList = typing.List[typing.Union[MaskedTensor, ops.Tensor]] + MaskedTensorList = typing.List[ + typing.Union[MaskedTensor, tensor_lib.Tensor]] try: @dispatch.dispatch_for_api(math_ops.add) diff --git a/tensorflow/python/util/variable_utils_test.py b/tensorflow/python/util/variable_utils_test.py index 41c81812e3322b..9aaa0d0e1b5fc8 100644 --- a/tensorflow/python/util/variable_utils_test.py +++ b/tensorflow/python/util/variable_utils_test.py @@ -18,6 +18,7 @@ from tensorflow.python.framework import composite_tensor from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variables @@ -65,9 +66,9 @@ def test_convert_variables_to_tensors(self): results = variable_utils.convert_variables_to_tensors(data) expected_results = [1, 2, 3, [4], 5, ct] # Only ResourceVariables are converted to Tensors. - self.assertIsInstance(results[0], ops.Tensor) - self.assertIsInstance(results[1], ops.Tensor) - self.assertIsInstance(results[2], ops.Tensor) + self.assertIsInstance(results[0], tensor.Tensor) + self.assertIsInstance(results[1], tensor.Tensor) + self.assertIsInstance(results[2], tensor.Tensor) self.assertIsInstance(results[3], list) self.assertIsInstance(results[4], int) self.assertIs(results[5], ct) @@ -82,7 +83,7 @@ def test_convert_variables_in_composite_tensor(self): self.assertIsInstance(ct2.component, resource_variable_ops.ResourceVariable) result = variable_utils.convert_variables_to_tensors(ct2) - self.assertIsInstance(result.component, ops.Tensor) + self.assertIsInstance(result.component, tensor.Tensor) self.assertAllEqual(result.component, 1) def test_replace_variables_with_atoms(self): @@ -99,7 +100,7 @@ def test_replace_variables_with_atoms(self): # Only ResourceVariables are replaced with int 0s. self.assertIsInstance(results[0], int) self.assertIsInstance(results[1], int) - self.assertIsInstance(results[2], ops.Tensor) + self.assertIsInstance(results[2], tensor.Tensor) self.assertIsInstance(results[3], list) self.assertIsInstance(results[4], int) results[2] = self.evaluate(results[2]) From 9d61e7f9581e646a50749837f86ffb5c1fb806e4 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 10 Jul 2023 14:37:12 -0700 Subject: [PATCH 080/376] Do not run msan against `se_gpu_pjrt_client_test`. PiperOrigin-RevId: 546986278 --- tensorflow/compiler/xla/pjrt/gpu/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/compiler/xla/pjrt/gpu/BUILD b/tensorflow/compiler/xla/pjrt/gpu/BUILD index 2a740c3aa7a017..75967151a6947a 100644 --- a/tensorflow/compiler/xla/pjrt/gpu/BUILD +++ b/tensorflow/compiler/xla/pjrt/gpu/BUILD @@ -94,6 +94,7 @@ xla_cc_test( "gpu", "no_oss", "noasan", + "nomsan", "requires-gpu-nvidia:2", ], deps = [ From fa803202f7c0b524777d095b59d89fed658962d4 Mon Sep 17 00:00:00 2001 From: Bixia Zheng Date: Mon, 10 Jul 2023 14:45:02 -0700 Subject: [PATCH 081/376] [xla][gpu] Enhance GpuAsyncTracker to handle Send and Recv. Split kGpuAsyncStream into kGpuAsyncStreamSend and kGpuAsyncStreamRecv. This allows the interleave of Send and Recv operation and can also prevent interleaving two Send and two Recv operations. Add a test. PiperOrigin-RevId: 546988445 --- .../xla/service/gpu/gpu_hlo_schedule.cc | 39 +++++--- .../xla/service/gpu/gpu_hlo_schedule_test.cc | 97 +++++++++++++++++++ 2 files changed, 125 insertions(+), 11 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc index 6f9cf219715754..1e63e1f7dfc37c 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc @@ -241,9 +241,15 @@ SchedulerConfig GetSchedulerConfig(const GpuDeviceInfo& gpu_info) { } // GPU specific resources for latency hiding scheduler. +// +// We use two resources to model collective operations: a resource for sending +// data and a resource for receiving data. All collective operations require +// both resources while the Send and Recv operations requires only the single +// resource corresponding to the operation. enum class GpuResourceType { - kGpuAsyncStream = 0, // The async stream for collectives. - kNumTargetResources = 1, + kGpuAsyncStreamSend = 0, // The resource for sending data. + kGpuAsyncStreamRecv = 1, // The resource for receiving data. + kNumTargetResources = 2, }; // Base GPU async tracker that enables async tracking only for async collectives @@ -285,11 +291,20 @@ class GpuAsyncTracker : public GpuAsyncTrackerBase { ResourceUsageType usage = op.outer == HloOpcode::kAsyncStart ? ResourceUsageType::kResourceRelease : ResourceUsageType::kResourceOccupy; - - const int64_t gpu_stream_resource = - GetFirstTargetDefinedResource() + - static_cast(GpuResourceType::kGpuAsyncStream); - return {std::make_pair(gpu_stream_resource, usage)}; + ResourcesVector resources; + auto add_resource = [&](GpuResourceType resource_type) { + const int64_t gpu_stream_resource = GetFirstTargetDefinedResource() + + static_cast(resource_type); + resources.push_back(std::make_pair(gpu_stream_resource, usage)); + }; + + if (op.inner != HloOpcode::kRecv) { + add_resource(GpuResourceType::kGpuAsyncStreamSend); + } + if (op.inner != HloOpcode::kSend) { + add_resource(GpuResourceType::kGpuAsyncStreamRecv); + } + return resources; } return GpuAsyncTrackerBase::GetResourcesFromInstruction(instr); } @@ -304,9 +319,9 @@ class GpuAsyncTracker : public GpuAsyncTrackerBase { if (resource_type < first_target_resource) { return GpuAsyncTrackerBase::GetNumAvailableResources(resource_type); } - CHECK_EQ(resource_type, + CHECK_LT(resource_type, first_target_resource + - static_cast(GpuResourceType::kGpuAsyncStream)); + static_cast(GpuResourceType::kNumTargetResources)); // We will allow upto 1 outstanding collective on the async stream. This // controls the number of collectives in flight in the schedule (a @@ -329,8 +344,10 @@ class GpuAsyncTracker : public GpuAsyncTrackerBase { CHECK_LE(resource_type, first_target_resource + GetNumTargetDefinedResources()); switch (resource_type - first_target_resource) { - case static_cast(GpuResourceType::kGpuAsyncStream): - return "kGpuAsyncStream"; + case static_cast(GpuResourceType::kGpuAsyncStreamSend): + return "kGpuAsyncStreamSend"; + case static_cast(GpuResourceType::kGpuAsyncStreamRecv): + return "kGpuAsyncStreamRecv"; default: return "kUnsupportedResource"; } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc index 57d9157215ffc4..c77806b89fcdf6 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc @@ -501,6 +501,103 @@ TEST_F(GpuHloScheduleTest, LHSSendRecv) { EXPECT_TRUE(HasValidFingerprint(module.get())); } +// Checks that the two pairs of (Recv, RecvDone) and (Send, SendDone) do not +// interleave. +TEST_F(GpuHloScheduleTest, LHSSendRecvPairs2) { + const char* hlo_text = R"( + HloModule test + while_cond { + param = (u32[], f32[1, 1024, 1024]) parameter(0) + count = get-tuple-element(%param), index=0 + ub = u32[] constant(25) + ROOT cond_result = pred[] compare(count, ub), direction=LT + } + + while_body { + param = (u32[], f32[1, 1024, 1024]) parameter(0) + count = get-tuple-element(%param), index=0 + send-data = get-tuple-element(%param), index=1 + + after-all-0 = token[] after-all() + recv-0 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all-0), channel_id=1, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{0, 1}}" + } + send-0 = (f32[1, 1024, 1024], u32[], token[]) send(send-data, after-all-0), + channel_id=1, control-predecessors={recv-0}, frontend_attributes={ + _xla_send_recv_source_target_pairs="{{0, 1}}" + } + recv-done-0 = (f32[1, 1024, 1024], token[]) recv-done(recv-0), channel_id=1 + send-done-0 = token[] send-done(send-0), control-predecessors={recv-done-0}, channel_id=1 + recv-data-0 = f32[1, 1024, 1024] get-tuple-element(recv-done-0), index=0 + + c1 = u32[] constant(1) + new_count = u32[] add(count, c1) + replica = u32[] replica-id() + c10 = u32[] constant(10) + sum = u32[] add(replica, c10) + sum2 = u32[] add(sum, count) + conv = f32[] convert(sum2) + s1 = f32[1, 1024, 1024] broadcast(conv), dimensions={} + + after-all-1 = token[] after-all() + recv-1 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all-1), channel_id=2, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{1, 0}}" + } + send-1 = (f32[1, 1024, 1024], u32[], token[]) send(send-data, after-all-1), + channel_id=2, control-predecessors={recv-1}, frontend_attributes={ + _xla_send_recv_source_target_pairs="{{1, 0}}" + } + recv-done-1 = (f32[1, 1024, 1024], token[]) recv-done(recv-1), channel_id=2 + send-done-1 = token[] send-done(send-1), control-predecessors={recv-done-1}, channel_id=2 + recv-data-1 = f32[1, 1024, 1024] get-tuple-element(recv-done-1), index=0 + + s2 = f32[1, 1024, 1024] add(recv-data-0, s1) + s = f32[1, 1024, 1024] add(recv-data-1, s2) + + ROOT result = (u32[], f32[1, 1024, 1024]) tuple(new_count, s) + } + + ENTRY test_computation { + c0 = u32[] constant(0) + f0 = f32[] constant(0.0) + init = f32[1, 1024, 1024] broadcast(f0), dimensions={} + while_init = (u32[], f32[1, 1024, 1024]) tuple(c0, init) + while_result = (u32[], f32[1, 1024, 1024]) while(while_init), + body=while_body, condition=while_cond + ROOT entry_result = f32[1, 1024, 1024] get-tuple-element(while_result), index=1 + } + )"; + + TF_ASSERT_OK_AND_ASSIGN( + auto module, + ParseAndReturnVerifiedModule( + hlo_text, GetModuleConfig(/*enable_latency_hiding_scheduler=*/true, + /*enable_gpu_async_tracker=*/true))); + SequentialHloOrdering order = BuildHloOrdering(module.get()); + HloComputation* while_body = module->GetComputationWithName("while_body"); + const std::vector& instruction_sequence = + order.SequentialOrder(*while_body)->instructions(); + auto get_index = [&](absl::string_view hlo_name) { + return absl::c_find_if(instruction_sequence, + [hlo_name](HloInstruction* instruction) { + return instruction->name() == hlo_name; + }) - + instruction_sequence.begin(); + }; + + EXPECT_TRUE(HasValidFingerprint(module.get())); + + EXPECT_LT(get_index("recv-1"), get_index("send-1")); + EXPECT_LT(get_index("send-1"), get_index("recv-done-1")); + EXPECT_GE(get_index("send-done-1") - get_index("send-1"), 14); + EXPECT_LT(abs(get_index("send-done-1") - get_index("result")), 2); + + EXPECT_LT(get_index("recv-done-0"), get_index("recv-1")); + EXPECT_LT(get_index("send-done-0"), get_index("send-1")); +} + class GpuHloScheduleParameterizedTest : public GpuHloScheduleTest, public ::testing::WithParamInterface {}; From fb2cf18785046424857a1d818ef4997102bef794 Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Mon, 10 Jul 2023 14:52:42 -0700 Subject: [PATCH 082/376] [NFC] Change uses of get_compatible_with_cloud to get_compatible_with_portable. PiperOrigin-RevId: 546990509 --- tensorflow/compiler/mlir/tf2xla/transforms/BUILD | 8 ++++---- tensorflow/core/data/service/BUILD | 6 +++--- tensorflow/core/ir/BUILD | 4 ++-- tensorflow/core/ir/importexport/tests/BUILD | 4 ++-- .../core/ir/importexport/tests/graphdef_to_mlir/BUILD | 4 ++-- .../core/ir/importexport/tests/mlir_to_graphdef/BUILD | 4 ++-- tensorflow/core/ir/tests/BUILD | 4 ++-- tensorflow/core/ir/types/BUILD | 4 ++-- tensorflow/core/transforms/BUILD | 4 ++-- tensorflow/core/transforms/cf_sink/BUILD | 4 ++-- tensorflow/core/transforms/consolidate_attrs/BUILD | 4 ++-- tensorflow/core/transforms/const_dedupe_hoist/BUILD | 4 ++-- tensorflow/core/transforms/constant_folding/BUILD | 4 ++-- tensorflow/core/transforms/cse/BUILD | 4 ++-- .../core/transforms/drop_unregistered_attribute/BUILD | 4 ++-- .../core/transforms/eliminate_passthrough_iter_args/BUILD | 4 ++-- tensorflow/core/transforms/func_to_graph/BUILD | 4 ++-- tensorflow/core/transforms/functional_to_region/BUILD | 4 ++-- tensorflow/core/transforms/graph_compactor/BUILD | 4 ++-- tensorflow/core/transforms/graph_to_func/BUILD | 4 ++-- tensorflow/core/transforms/legacy_call/BUILD | 4 ++-- tensorflow/core/transforms/region_to_functional/BUILD | 4 ++-- tensorflow/core/transforms/remapper/BUILD | 4 ++-- tensorflow/core/transforms/shape_inference/BUILD | 4 ++-- tensorflow/core/transforms/toposort/BUILD | 4 ++-- 25 files changed, 53 insertions(+), 53 deletions(-) diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/BUILD b/tensorflow/compiler/mlir/tf2xla/transforms/BUILD index f8aada0c497aa3..5733d107fa4574 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/transforms/BUILD @@ -4,7 +4,7 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow/tsl/platform:build_config_root.bzl", "if_static") package( @@ -15,7 +15,7 @@ package( gentbl_cc_library( name = "legalize_tf_patterns_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], @@ -35,7 +35,7 @@ gentbl_cc_library( gentbl_cc_library( name = "xla_legalize_tf_passes_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( [ @@ -54,7 +54,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tf_xla_passes_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( [ diff --git a/tensorflow/core/data/service/BUILD b/tensorflow/core/data/service/BUILD index 4855381b7f8e74..c3eaaf3f60da1b 100644 --- a/tensorflow/core/data/service/BUILD +++ b/tensorflow/core/data/service/BUILD @@ -6,7 +6,7 @@ load( "tf_proto_library", "tf_protos_profiler_service", ) -load("//tensorflow:tensorflow.default.bzl", "cc_header_only_library", "get_compatible_with_cloud", "tf_grpc_cc_dependencies") +load("//tensorflow:tensorflow.default.bzl", "cc_header_only_library", "get_compatible_with_portable", "tf_grpc_cc_dependencies") load( "//tensorflow:tensorflow.bzl", "tf_cc_test", @@ -331,7 +331,7 @@ tf_cc_test( cc_grpc_library( name = "dispatcher_cc_grpc_proto", srcs = [":dispatcher_proto"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), # copybara:uncomment copts = ["-Wthread-safety-analysis"], generate_mocks = True, grpc_only = True, @@ -982,7 +982,7 @@ tf_cc_test( cc_grpc_library( name = "worker_cc_grpc_proto", srcs = [":worker_proto"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), # copybara:uncomment copts = ["-Wthread-safety-analysis"], generate_mocks = True, grpc_only = True, diff --git a/tensorflow/core/ir/BUILD b/tensorflow/core/ir/BUILD index c1618ee0cdf169..3900892c21d85f 100644 --- a/tensorflow/core/ir/BUILD +++ b/tensorflow/core/ir/BUILD @@ -1,10 +1,10 @@ load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow:tensorflow.bzl", "tf_cc_test") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_compatible_with = get_compatible_with_cloud(), + default_compatible_with = get_compatible_with_portable(), default_visibility = [ "//tensorflow/compiler/mlir/tensorflow:__subpackages__", "//tensorflow/core:__subpackages__", diff --git a/tensorflow/core/ir/importexport/tests/BUILD b/tensorflow/core/ir/importexport/tests/BUILD index 71fa743ef9b557..f25d004408d86f 100644 --- a/tensorflow/core/ir/importexport/tests/BUILD +++ b/tensorflow/core/ir/importexport/tests/BUILD @@ -1,9 +1,9 @@ load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_compatible_with = get_compatible_with_cloud(), + default_compatible_with = get_compatible_with_portable(), default_visibility = [":__subpackages__"], licenses = ["notice"], # Apache 2.0 ) diff --git a/tensorflow/core/ir/importexport/tests/graphdef_to_mlir/BUILD b/tensorflow/core/ir/importexport/tests/graphdef_to_mlir/BUILD index 71fa743ef9b557..f25d004408d86f 100644 --- a/tensorflow/core/ir/importexport/tests/graphdef_to_mlir/BUILD +++ b/tensorflow/core/ir/importexport/tests/graphdef_to_mlir/BUILD @@ -1,9 +1,9 @@ load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_compatible_with = get_compatible_with_cloud(), + default_compatible_with = get_compatible_with_portable(), default_visibility = [":__subpackages__"], licenses = ["notice"], # Apache 2.0 ) diff --git a/tensorflow/core/ir/importexport/tests/mlir_to_graphdef/BUILD b/tensorflow/core/ir/importexport/tests/mlir_to_graphdef/BUILD index 71fa743ef9b557..f25d004408d86f 100644 --- a/tensorflow/core/ir/importexport/tests/mlir_to_graphdef/BUILD +++ b/tensorflow/core/ir/importexport/tests/mlir_to_graphdef/BUILD @@ -1,9 +1,9 @@ load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_compatible_with = get_compatible_with_cloud(), + default_compatible_with = get_compatible_with_portable(), default_visibility = [":__subpackages__"], licenses = ["notice"], # Apache 2.0 ) diff --git a/tensorflow/core/ir/tests/BUILD b/tensorflow/core/ir/tests/BUILD index 315309f01fdce1..14304dfdee7d5e 100644 --- a/tensorflow/core/ir/tests/BUILD +++ b/tensorflow/core/ir/tests/BUILD @@ -1,10 +1,10 @@ load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") load("//tensorflow:tensorflow.bzl", "tf_native_cc_binary") -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_compatible_with = get_compatible_with_cloud(), + default_compatible_with = get_compatible_with_portable(), default_visibility = [":__subpackages__"], licenses = ["notice"], # Apache 2.0 ) diff --git a/tensorflow/core/ir/types/BUILD b/tensorflow/core/ir/types/BUILD index dcaa5ed2b765f4..f638202aa752c0 100644 --- a/tensorflow/core/ir/types/BUILD +++ b/tensorflow/core/ir/types/BUILD @@ -1,10 +1,10 @@ load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow:tensorflow.bzl", "tf_cc_test") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_compatible_with = get_compatible_with_cloud(), + default_compatible_with = get_compatible_with_portable(), default_visibility = ["//tensorflow/core:__subpackages__"], licenses = ["notice"], # Apache 2.0 ) diff --git a/tensorflow/core/transforms/BUILD b/tensorflow/core/transforms/BUILD index cd20ec56853754..bc4d89f8aafbeb 100644 --- a/tensorflow/core/transforms/BUILD +++ b/tensorflow/core/transforms/BUILD @@ -1,10 +1,10 @@ load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test") -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_compatible_with = get_compatible_with_cloud(), + default_compatible_with = get_compatible_with_portable(), default_visibility = [ "//tensorflow/core:__subpackages__", "//tensorflow/tools/tfg_graph_transforms:__subpackages__", diff --git a/tensorflow/core/transforms/cf_sink/BUILD b/tensorflow/core/transforms/cf_sink/BUILD index e5a77916b5a1cb..6cf78e4eeee0c9 100644 --- a/tensorflow/core/transforms/cf_sink/BUILD +++ b/tensorflow/core/transforms/cf_sink/BUILD @@ -1,9 +1,9 @@ -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_compatible_with = get_compatible_with_cloud(), + default_compatible_with = get_compatible_with_portable(), default_visibility = [ "//tensorflow/core:__subpackages__", ], diff --git a/tensorflow/core/transforms/consolidate_attrs/BUILD b/tensorflow/core/transforms/consolidate_attrs/BUILD index 50525a5058200e..5558172c0669e2 100644 --- a/tensorflow/core/transforms/consolidate_attrs/BUILD +++ b/tensorflow/core/transforms/consolidate_attrs/BUILD @@ -1,9 +1,9 @@ -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_compatible_with = get_compatible_with_cloud(), + default_compatible_with = get_compatible_with_portable(), default_visibility = [ "//tensorflow/core:__subpackages__", ], diff --git a/tensorflow/core/transforms/const_dedupe_hoist/BUILD b/tensorflow/core/transforms/const_dedupe_hoist/BUILD index b6a81a5a93f848..381b666a80a711 100644 --- a/tensorflow/core/transforms/const_dedupe_hoist/BUILD +++ b/tensorflow/core/transforms/const_dedupe_hoist/BUILD @@ -1,9 +1,9 @@ -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_compatible_with = get_compatible_with_cloud(), + default_compatible_with = get_compatible_with_portable(), default_visibility = [ "//tensorflow/core:__subpackages__", ], diff --git a/tensorflow/core/transforms/constant_folding/BUILD b/tensorflow/core/transforms/constant_folding/BUILD index 1b5b0fb43f4c34..e64e9d868f2677 100644 --- a/tensorflow/core/transforms/constant_folding/BUILD +++ b/tensorflow/core/transforms/constant_folding/BUILD @@ -1,9 +1,9 @@ -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_compatible_with = get_compatible_with_cloud(), + default_compatible_with = get_compatible_with_portable(), default_visibility = [ "//tensorflow/core:__subpackages__", ], diff --git a/tensorflow/core/transforms/cse/BUILD b/tensorflow/core/transforms/cse/BUILD index a6c6914204cd8f..6a4dd774bbcbbc 100644 --- a/tensorflow/core/transforms/cse/BUILD +++ b/tensorflow/core/transforms/cse/BUILD @@ -1,9 +1,9 @@ -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_compatible_with = get_compatible_with_cloud(), + default_compatible_with = get_compatible_with_portable(), default_visibility = [ "//tensorflow/core:__subpackages__", ], diff --git a/tensorflow/core/transforms/drop_unregistered_attribute/BUILD b/tensorflow/core/transforms/drop_unregistered_attribute/BUILD index 98a5fe7d236f19..73cc8918341602 100644 --- a/tensorflow/core/transforms/drop_unregistered_attribute/BUILD +++ b/tensorflow/core/transforms/drop_unregistered_attribute/BUILD @@ -1,9 +1,9 @@ -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_compatible_with = get_compatible_with_cloud(), + default_compatible_with = get_compatible_with_portable(), default_visibility = [ "//tensorflow/core:__subpackages__", ], diff --git a/tensorflow/core/transforms/eliminate_passthrough_iter_args/BUILD b/tensorflow/core/transforms/eliminate_passthrough_iter_args/BUILD index fe69bb24386eb2..b19c211a8abdcb 100644 --- a/tensorflow/core/transforms/eliminate_passthrough_iter_args/BUILD +++ b/tensorflow/core/transforms/eliminate_passthrough_iter_args/BUILD @@ -1,9 +1,9 @@ -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_compatible_with = get_compatible_with_cloud(), + default_compatible_with = get_compatible_with_portable(), default_visibility = [ "//tensorflow/core:__subpackages__", ], diff --git a/tensorflow/core/transforms/func_to_graph/BUILD b/tensorflow/core/transforms/func_to_graph/BUILD index 4cd2e365f3d384..0c62a5a2f90894 100644 --- a/tensorflow/core/transforms/func_to_graph/BUILD +++ b/tensorflow/core/transforms/func_to_graph/BUILD @@ -1,9 +1,9 @@ -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_compatible_with = get_compatible_with_cloud(), + default_compatible_with = get_compatible_with_portable(), default_visibility = [ "//tensorflow/core:__subpackages__", ], diff --git a/tensorflow/core/transforms/functional_to_region/BUILD b/tensorflow/core/transforms/functional_to_region/BUILD index 14addc62ce7b47..428c83441aad56 100644 --- a/tensorflow/core/transforms/functional_to_region/BUILD +++ b/tensorflow/core/transforms/functional_to_region/BUILD @@ -1,9 +1,9 @@ -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_compatible_with = get_compatible_with_cloud(), + default_compatible_with = get_compatible_with_portable(), default_visibility = [ "//tensorflow/core:__subpackages__", ], diff --git a/tensorflow/core/transforms/graph_compactor/BUILD b/tensorflow/core/transforms/graph_compactor/BUILD index 360246876f52df..35635b3e0169e5 100644 --- a/tensorflow/core/transforms/graph_compactor/BUILD +++ b/tensorflow/core/transforms/graph_compactor/BUILD @@ -1,9 +1,9 @@ -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_compatible_with = get_compatible_with_cloud(), + default_compatible_with = get_compatible_with_portable(), default_visibility = [ "//tensorflow/core:__subpackages__", ], diff --git a/tensorflow/core/transforms/graph_to_func/BUILD b/tensorflow/core/transforms/graph_to_func/BUILD index c4bcc7fb83bac2..69023cc514fb83 100644 --- a/tensorflow/core/transforms/graph_to_func/BUILD +++ b/tensorflow/core/transforms/graph_to_func/BUILD @@ -1,9 +1,9 @@ -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_compatible_with = get_compatible_with_cloud(), + default_compatible_with = get_compatible_with_portable(), default_visibility = [ "//tensorflow/core:__subpackages__", ], diff --git a/tensorflow/core/transforms/legacy_call/BUILD b/tensorflow/core/transforms/legacy_call/BUILD index c010fdb7333637..1784c67edc712d 100644 --- a/tensorflow/core/transforms/legacy_call/BUILD +++ b/tensorflow/core/transforms/legacy_call/BUILD @@ -1,9 +1,9 @@ -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_compatible_with = get_compatible_with_cloud(), + default_compatible_with = get_compatible_with_portable(), default_visibility = [ "//tensorflow/core:__subpackages__", ], diff --git a/tensorflow/core/transforms/region_to_functional/BUILD b/tensorflow/core/transforms/region_to_functional/BUILD index b49cb34ce65075..becc78b878bd8d 100644 --- a/tensorflow/core/transforms/region_to_functional/BUILD +++ b/tensorflow/core/transforms/region_to_functional/BUILD @@ -1,9 +1,9 @@ -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_compatible_with = get_compatible_with_cloud(), + default_compatible_with = get_compatible_with_portable(), default_visibility = [ "//tensorflow/core:__subpackages__", ], diff --git a/tensorflow/core/transforms/remapper/BUILD b/tensorflow/core/transforms/remapper/BUILD index 4a3ab37da4294b..a75461d412b23b 100644 --- a/tensorflow/core/transforms/remapper/BUILD +++ b/tensorflow/core/transforms/remapper/BUILD @@ -1,10 +1,10 @@ -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_compatible_with = get_compatible_with_cloud(), + default_compatible_with = get_compatible_with_portable(), default_visibility = [ "//tensorflow/core:__subpackages__", ], diff --git a/tensorflow/core/transforms/shape_inference/BUILD b/tensorflow/core/transforms/shape_inference/BUILD index c1fd69fbe2b619..d9eb50e9762b2a 100644 --- a/tensorflow/core/transforms/shape_inference/BUILD +++ b/tensorflow/core/transforms/shape_inference/BUILD @@ -1,9 +1,9 @@ -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_compatible_with = get_compatible_with_cloud(), + default_compatible_with = get_compatible_with_portable(), default_visibility = [ "//tensorflow/core/transforms:__subpackages__", ], diff --git a/tensorflow/core/transforms/toposort/BUILD b/tensorflow/core/transforms/toposort/BUILD index 7b39cc616a414d..be03bed1002e0b 100644 --- a/tensorflow/core/transforms/toposort/BUILD +++ b/tensorflow/core/transforms/toposort/BUILD @@ -1,9 +1,9 @@ -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_compatible_with = get_compatible_with_cloud(), + default_compatible_with = get_compatible_with_portable(), default_visibility = [ "//tensorflow/compiler:__subpackages__", "//tensorflow/core:__subpackages__", From 7a161e5fc1b7b9b71a0604f57ef505f1238e0421 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Mon, 10 Jul 2023 14:59:41 -0700 Subject: [PATCH 083/376] Add a transform for torch_index_select to gather. At the moment, StableHLO neither has torch_index_select specification nor the interpreter for it. Some torch_index_select can be interpreted as gather, so this pass allows torch_index_select ops to lower to gather ops. This implementation is mostly derived from the [client HLO API](https://github.com/tensorflow/tensorflow/blob/3f0b7fed349f257dc2e6cfeec9611c7e86f9d0bb/tensorflow/compiler/xla/client/lib/slicing.cc#L267-L312). PiperOrigin-RevId: 546992292 --- tensorflow/compiler/xla/mlir_hlo/BUILD | 1 + .../mlir_hlo/mhlo/transforms/CMakeLists.txt | 1 + .../legalize_torch_index_select_to_gather.cc | 157 +++++++++++++++ .../mlir_hlo/mhlo/transforms/mhlo_passes.td | 5 + .../xla/mlir_hlo/mhlo/transforms/passes.h | 2 + .../xla/mlir_hlo/mhlo/transforms/rewriters.h | 4 + ...legalize-torch-index-select-to-gather.mlir | 189 ++++++++++++++++++ 7 files changed, 359 insertions(+) create mode 100644 tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/legalize_torch_index_select_to_gather/legalize_torch_index_select_to_gather.cc create mode 100644 tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-torch-index-select-to-gather.mlir diff --git a/tensorflow/compiler/xla/mlir_hlo/BUILD b/tensorflow/compiler/xla/mlir_hlo/BUILD index 9cf6296ae7c05a..fd47f29995196e 100644 --- a/tensorflow/compiler/xla/mlir_hlo/BUILD +++ b/tensorflow/compiler/xla/mlir_hlo/BUILD @@ -734,6 +734,7 @@ cc_library( "mhlo/transforms/legalize_to_linalg/legalize_to_linalg.cc", "mhlo/transforms/legalize_to_standard/generated_legalize_to_standard.inc", "mhlo/transforms/legalize_to_standard/legalize_to_standard.cc", + "mhlo/transforms/legalize_torch_index_select_to_gather/legalize_torch_index_select_to_gather.cc", "mhlo/transforms/legalize_trigonometric_to_approximation/legalize_trigonometric_to_approximation.cc", "mhlo/transforms/lower_complex/generated_lower_complex.inc", "mhlo/transforms/lower_complex/lower_complex.cc", diff --git a/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/CMakeLists.txt b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/CMakeLists.txt index 84c1449a3aaa40..e0fa81c241f3d1 100644 --- a/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/CMakeLists.txt +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/CMakeLists.txt @@ -47,6 +47,7 @@ add_mlir_library(MhloPasses legalize_einsum_to_dot_general/legalize_einsum_to_dot_general.cc legalize_gather_to_torch_index_select/legalize_gather_to_torch_index_select.cc legalize_shape_computations/legalize_shape_computations.cc + legalize_torch_index_select_to_gather/legalize_torch_index_select_to_gather.cc legalize_trigonometric_to_approximation/legalize_trigonometric_to_approximation.cc lower_complex/lower_complex.cc lower_complex/lower_complex_patterns.td diff --git a/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/legalize_torch_index_select_to_gather/legalize_torch_index_select_to_gather.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/legalize_torch_index_select_to_gather/legalize_torch_index_select_to_gather.cc new file mode 100644 index 00000000000000..daaafd01572399 --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/legalize_torch_index_select_to_gather/legalize_torch_index_select_to_gather.cc @@ -0,0 +1,157 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include +#include + +#include "llvm/ADT/SmallVector.h" +#include "mhlo/IR/hlo_ops.h" +#include "mhlo/transforms/passes.h" +#include "mhlo/transforms/rewriters.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir { +namespace mhlo { + +#define GEN_PASS_DEF_LEGALIZETORCHINDEXSELECTTOGATHERPASS +#include "mhlo/transforms/mhlo_passes.h.inc" + +namespace { + +struct TorchIndexSelectIsGather : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TorchIndexSelectOp op, + PatternRewriter &rewriter) const override { + auto operand = op.getOperand(); + auto operandTy = operand.getType(); + if (!operandTy.hasRank()) { + return rewriter.notifyMatchFailure(op, "unranked operand"); + } + + auto index = op.getIndex(); + if (!operand.getType().hasStaticShape() || + !index.getType().hasStaticShape()) { + return rewriter.notifyMatchFailure( + op, "operand and index must have static shapes"); + } + + int64_t dim = static_cast(op.getDim()); + int64_t batchDims = op.getBatchDims(); + if (dim < batchDims) { + return rewriter.notifyMatchFailure( + op, "dim must be greater than or equal to the number of batch dims"); + } + + int64_t indexVectorDim = index.getType().getRank(); + auto indexTy = index.getType(); + auto indexElementTy = indexTy.getElementType().dyn_cast(); + if (!indexElementTy) { + return rewriter.notifyMatchFailure( + op, "index must have integer element type"); + } + + if (index.getType().getElementType().getIntOrFloatBitWidth() == 64 && + operandTy.getShape()[dim] < std::numeric_limits::max()) { + index = rewriter.create( + op.getLoc(), index, rewriter.getIntegerType(32, /*isSigned=*/false)); + } + + if (batchDims > 0) { + llvm::SmallVector newIndexShape(indexTy.getShape()); + newIndexShape.push_back(1); + auto newIndexType = RankedTensorType::get( + newIndexShape, index.getType().getElementType()); + + llvm::SmallVector toConcat; + for (auto batchDim = 0; batchDim < batchDims; ++batchDim) { + toConcat.push_back( + rewriter.create(op.getLoc(), newIndexType, batchDim)); + } + toConcat.push_back( + rewriter.create(op.getLoc(), newIndexType, index)); + index = rewriter.create(op.getLoc(), ValueRange(toConcat), + indexVectorDim); + } + + llvm::SmallVector offsetDims; + llvm::SmallVector collapsedSliceDims; + llvm::SmallVector startIndexMap; + llvm::SmallVector sliceSizes(operandTy.getShape()); + for (auto i = 0; i < operandTy.getRank(); ++i) { + if (i < batchDims || i == dim) { + sliceSizes[i] = std::min(sliceSizes[i], static_cast(1)); + collapsedSliceDims.push_back(i); + startIndexMap.push_back(i); + } else { + if (i < dim) { + offsetDims.push_back(i); + } else { + offsetDims.push_back(i + indexVectorDim - (1 + batchDims)); + } + } + } + + auto gatherDimensionNumbersAttr = GatherDimensionNumbersAttr::get( + rewriter.getContext(), offsetDims, collapsedSliceDims, startIndexMap, + indexVectorDim); + + auto sliceSizesAttr = rewriter.getI64TensorAttr(sliceSizes); + + auto gatherOp = + rewriter.create(op.getLoc(), operand, index, + gatherDimensionNumbersAttr, sliceSizesAttr); + rewriter.replaceOp(op, gatherOp); + return success(); + } +}; + +struct LegalizeTorchIndexSelectToGatherPass + : public impl::LegalizeTorchIndexSelectToGatherPassBase< + LegalizeTorchIndexSelectToGatherPass> { + /// Perform the lowering of standard dialect operations to approximations. + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + populateTorchIndexSelectToGatherPatterns(&getContext(), &patterns); + if (failed( + applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + return signalPassFailure(); + } +}; +} // namespace + +void populateTorchIndexSelectToGatherPatterns(mlir::MLIRContext *context, + RewritePatternSet *patterns) { + patterns->add(context); +} + +std::unique_ptr> +createLegalizeTorchIndexSelectToGatherPass() { + return std::make_unique(); +} + +} // namespace mhlo +} // namespace mlir diff --git a/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/mhlo_passes.td b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/mhlo_passes.td index 1f740b45734df3..c8d3b9b6cf5806 100644 --- a/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/mhlo_passes.td +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/mhlo_passes.td @@ -93,6 +93,11 @@ def LegalizeGatherToTorchIndexSelectPass : Pass<"mhlo-legalize-gather-to-torch-i let constructor = "createLegalizeGatherToTorchIndexSelectPass()"; } +def LegalizeTorchIndexSelectToGatherPass : Pass<"mhlo-legalize-torch-index-select-to-gather", "func::FuncOp"> { + let summary = "Legalizes torch index select to a gather."; + let constructor = "createLegalizeTorchIndexSelectToGatherPass()"; +} + def LegalizeTanhToApproximationPass : Pass<"mhlo-legalize-trigonometric-to-approximation", "func::FuncOp"> { let summary = "Legalize trigonometric operations from standard dialect to an approximation."; diff --git a/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/passes.h b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/passes.h index 50deb8d5b3353b..316aa825584d41 100644 --- a/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/passes.h +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/passes.h @@ -160,6 +160,8 @@ std::unique_ptr> createLegalizeEinsumToDotGeneralPass(); std::unique_ptr> createLegalizeGatherToTorchIndexSelectPass(); +std::unique_ptr> +createLegalizeTorchIndexSelectToGatherPass(); std::unique_ptr> createFlattenTuplePass(); // Creates a pass for expanding mhlo.tuple ops. diff --git a/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/rewriters.h b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/rewriters.h index dcca8e78cd8671..f2e4f85148f8be 100644 --- a/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/rewriters.h +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/rewriters.h @@ -49,6 +49,10 @@ void populateEinsumToDotGeneralPatterns(mlir::MLIRContext *context, void populateGatherToTorchIndexSelectPatterns(mlir::MLIRContext *context, RewritePatternSet *patterns); +// Rewrite patterns for torch index select to equivalent gather legalization. +void populateTorchIndexSelectToGatherPatterns(mlir::MLIRContext *context, + RewritePatternSet *patterns); + void populateMhloToStdPatterns(RewritePatternSet *patterns, MLIRContext *ctx); // Collection of rewrite patterns for lowering all mhlo ops to their diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-torch-index-select-to-gather.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-torch-index-select-to-gather.mlir new file mode 100644 index 00000000000000..19230564c80108 --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-torch-index-select-to-gather.mlir @@ -0,0 +1,189 @@ +// RUN: mlir-hlo-opt -mhlo-legalize-torch-index-select-to-gather -split-input-file %s -o - | FileCheck %s + +// CHECK-LABEL: @index_select_to_gather_convert_index_type +func.func @index_select_to_gather_convert_index_type(%arg0 : tensor<5x1x5xi64>, %arg1 : tensor<2xi64>) -> tensor<2x1x5xi64> { + // CHECK: [[ARG1:%.+]] = mhlo.convert %arg1 : (tensor<2xi64>) -> tensor<2xui32> + // CHECK: [[RES:%.+]] = "mhlo.gather"(%arg0, [[ARG1]]) { + // CHECK-SAME: dimension_numbers = #mhlo.gather< + // CHECK-SAME: offset_dims = [1, 2], + // CHECK-SAME: collapsed_slice_dims = [0], + // CHECK-SAME: start_index_map = [0], + // CHECK-SAME: index_vector_dim = 1 + // CHECK-SAME: >, + // CHECK-SAME: indices_are_sorted = false, + // CHECK-SAME: slice_sizes = dense<[1, 1, 5]> : tensor<3xi64> + // CHECK-SAME: } : (tensor<5x1x5xi64>, tensor<2xui32>) -> tensor<2x1x5xi64> + %0 = "mhlo.torch_index_select"(%arg0, %arg1) { + dim = 0 : i64, + batch_dims = 0 : i64 + } : (tensor<5x1x5xi64>, tensor<2xi64>) -> tensor<2x1x5xi64> + // CHECK: return [[RES]] : tensor<2x1x5xi64> + func.return %0 : tensor<2x1x5xi64> +} + +// ----- + +// CHECK-LABEL: @index_select_to_gather_multi_offset_dims +func.func @index_select_to_gather_multi_offset_dims(%arg0 : tensor<5x1x5xi32>, %arg1 : tensor<2xi32>) -> tensor<2x1x5xi32> { + // CHECK: [[RES:%.+]] = "mhlo.gather"(%arg0, %arg1) { + // CHECK-SAME: dimension_numbers = #mhlo.gather< + // CHECK-SAME: offset_dims = [1, 2], + // CHECK-SAME: collapsed_slice_dims = [0], + // CHECK-SAME: start_index_map = [0], + // CHECK-SAME: index_vector_dim = 1 + // CHECK-SAME: >, + // CHECK-SAME: indices_are_sorted = false, + // CHECK-SAME: slice_sizes = dense<[1, 1, 5]> : tensor<3xi64> + // CHECK-SAME: } : (tensor<5x1x5xi32>, tensor<2xi32>) -> tensor<2x1x5xi32> + %0 = "mhlo.torch_index_select"(%arg0, %arg1) { + dim = 0 : i64, + batch_dims = 0 : i64 + } : (tensor<5x1x5xi32>, tensor<2xi32>) -> tensor<2x1x5xi32> + // CHECK: return [[RES]] : tensor<2x1x5xi32> + func.return %0 : tensor<2x1x5xi32> +} + +// ----- + +// CHECK-LABEL: @index_select_to_gather_larger_output +func.func @index_select_to_gather_larger_output(%arg0 : tensor<5x4xf32>, %arg1 : tensor<1x3x1xi32>) -> tensor<1x3x1x4xf32> { + // CHECK: [[RES:%.+]] = "mhlo.gather"(%arg0, %arg1) { + // CHECK-SAME: dimension_numbers = #mhlo.gather< + // CHECK-SAME: offset_dims = [3], + // CHECK-SAME: collapsed_slice_dims = [0], + // CHECK-SAME: start_index_map = [0], + // CHECK-SAME: index_vector_dim = 3 + // CHECK-SAME: >, + // CHECK-SAME: indices_are_sorted = false, + // CHECK-SAME: slice_sizes = dense<[1, 4]> : tensor<2xi64> + // CHECK-SAME: } : (tensor<5x4xf32>, tensor<1x3x1xi32>) -> tensor<1x3x1x4xf32> + %0 = "mhlo.torch_index_select"(%arg0, %arg1) { + dim = 0 : i64, + batch_dims = 0 : i64 + } : (tensor<5x4xf32>, tensor<1x3x1xi32>) -> tensor<1x3x1x4xf32> + // CHECK: return [[RES]] : tensor<1x3x1x4xf32> + func.return %0 : tensor<1x3x1x4xf32> +} + +// ----- + +// CHECK-LABEL: @index_select_to_gather_regular_map +func.func @index_select_to_gather_regular_map(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi32>) -> tensor<2x4xi32> { + // CHECK: [[RES:%.+]] = "mhlo.gather"(%arg0, %arg1) { + // CHECK-SAME: dimension_numbers = #mhlo.gather< + // CHECK-SAME: offset_dims = [1], + // CHECK-SAME: collapsed_slice_dims = [0], + // CHECK-SAME: start_index_map = [0], + // CHECK-SAME: index_vector_dim = 1 + // CHECK-SAME: >, + // CHECK-SAME: indices_are_sorted = false, + // CHECK-SAME: slice_sizes = dense<[1, 4]> : tensor<2xi64> + // CHECK-SAME: } : (tensor<3x4xi32>, tensor<2xi32>) -> tensor<2x4xi32> + %0 = "mhlo.torch_index_select"(%arg0, %arg1) { + dim = 0 : i64, + batch_dims = 0 : i64 + } : (tensor<3x4xi32>, tensor<2xi32>) -> tensor<2x4xi32> + // CHECK: return [[RES]] : tensor<2x4xi32> + func.return %0 : tensor<2x4xi32> +} + +// ----- + +// CHECK-LABEL: @index_select_to_gather_reverse_map +func.func @index_select_to_gather_reverse_map(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi32>) -> tensor<3x2xi32> { + // CHECK: [[RES:%.+]] = "mhlo.gather"(%arg0, %arg1) { + // CHECK-SAME: dimension_numbers = #mhlo.gather< + // CHECK-SAME: offset_dims = [0], + // CHECK-SAME: collapsed_slice_dims = [1], + // CHECK-SAME: start_index_map = [1], + // CHECK-SAME: index_vector_dim = 1 + // CHECK-SAME: >, + // CHECK-SAME: indices_are_sorted = false, + // CHECK-SAME: slice_sizes = dense<[3, 1]> : tensor<2xi64> + // CHECK-SAME: } : (tensor<3x4xi32>, tensor<2xi32>) -> tensor<3x2xi32> + %0 = "mhlo.torch_index_select"(%arg0, %arg1) { + dim = 1 : i64, + batch_dims = 0 : i64 + } : (tensor<3x4xi32>, tensor<2xi32>) -> tensor<3x2xi32> + // CHECK: return [[RES]] : tensor<3x2xi32> + func.return %0 : tensor<3x2xi32> +} + +// ----- + +// CHECK-LABEL: @index_select_to_gather_batch_dim_greater_than_1 +func.func @index_select_to_gather_batch_dim_greater_than_1(%arg0 : tensor<5x1x5xi32>, %arg1 : tensor<2xi32>) -> tensor<2x5xi32> { + // CHECK: [[ARG0:%.+]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x1xi32> + // CHECK: [[ARG1:%.+]] = mhlo.reshape %arg1 : (tensor<2xi32>) -> tensor<2x1xi32> + // CHECK: [[ARG2:%.+]] = "mhlo.concatenate"([[ARG0]], [[ARG1]]) {dimension = 1 : i64} : (tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x2xi32> + // CHECK: [[RES:%.+]] = "mhlo.gather"(%arg0, [[ARG2]]) { + // CHECK-SAME: dimension_numbers = #mhlo.gather< + // CHECK-SAME: offset_dims = [1], + // CHECK-SAME: collapsed_slice_dims = [0, 1], + // CHECK-SAME: start_index_map = [0, 1], + // CHECK-SAME: index_vector_dim = 1 + // CHECK-SAME: >, + // CHECK-SAME: indices_are_sorted = false, + // CHECK-SAME: slice_sizes = dense<[1, 1, 5]> : tensor<3xi64> + // CHECK-SAME: } : (tensor<5x1x5xi32>, tensor<2x2xi32>) -> tensor<2x5xi32> + %0 = "mhlo.torch_index_select"(%arg0, %arg1) { + dim = 1 : i64, + batch_dims = 1 : i64 + } : (tensor<5x1x5xi32>, tensor<2xi32>) -> tensor<2x5xi32> + func.return %0 : tensor<2x5xi32> +} + +// ----- + +func.func @index_select_to_gather_unranked(%arg0 : tensor<*xi32>, %arg1 : tensor<*xi32>) -> tensor<*xi32> { + // CHECK: mhlo.torch_index_select + %0 = "mhlo.torch_index_select"(%arg0, %arg1) { + dim = 0 : i64, + batch_dims = 0 : i64 + } : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> + func.return %0 : tensor<*xi32> +} + +// ----- + +func.func @index_select_to_gather_non_static_operand(%arg0 : tensor<5x1x?xi32>, %arg1 : tensor<2xi32>) -> tensor<2x1x5xi32> { + // CHECK: mhlo.torch_index_select + %0 = "mhlo.torch_index_select"(%arg0, %arg1) { + dim = 0 : i64, + batch_dims = 0 : i64 + } : (tensor<5x1x?xi32>, tensor<2xi32>) -> tensor<2x1x5xi32> + func.return %0 : tensor<2x1x5xi32> +} + +// ----- + +func.func @index_select_to_gather_non_static_index(%arg0 : tensor<5x1x5xi32>, %arg1 : tensor) -> tensor<2x1x5xi32> { + // CHECK: mhlo.torch_index_select + %0 = "mhlo.torch_index_select"(%arg0, %arg1) { + dim = 0 : i64, + batch_dims = 0 : i64 + } : (tensor<5x1x5xi32>, tensor) -> tensor<2x1x5xi32> + func.return %0 : tensor<2x1x5xi32> +} + +// ----- + +func.func @index_select_to_gather_dim_less_than_batch_dims(%arg0 : tensor<5x1x5xi32>, %arg1 : tensor<2xi32>) -> tensor<2x1x5xi32> { + // CHECK: mhlo.torch_index_select + %0 = "mhlo.torch_index_select"(%arg0, %arg1) { + dim = 0 : i64, + batch_dims = 1 : i64 + } : (tensor<5x1x5xi32>, tensor<2xi32>) -> tensor<2x1x5xi32> + func.return %0 : tensor<2x1x5xi32> +} + +// ----- + +func.func @index_select_to_gather_non_integer_index(%arg0 : tensor<5x1x5xi32>, %arg1 : tensor<2xf32>) -> tensor<2x1x5xi32> { + // CHECK: mhlo.torch_index_select + %0 = "mhlo.torch_index_select"(%arg0, %arg1) { + dim = 0 : i64, + batch_dims = 0 : i64 + } : (tensor<5x1x5xi32>, tensor<2xf32>) -> tensor<2x1x5xi32> + func.return %0 : tensor<2x1x5xi32> +} From c4efe17499f3b4f0276635deb08c7624923f837e Mon Sep 17 00:00:00 2001 From: Fiona Lang Date: Mon, 10 Jul 2023 15:08:58 -0700 Subject: [PATCH 084/376] Update ops.Tensor references to //third_party/tensorflow/python/framework/tensor.py. PiperOrigin-RevId: 546994884 --- tensorflow/python/saved_model/BUILD | 13 +++++++------ tensorflow/python/saved_model/builder_impl.py | 5 +++-- .../saved_model/function_deserialization.py | 13 +++++++------ .../python/saved_model/model_utils/BUILD | 3 ++- .../saved_model/model_utils/export_output.py | 13 +++++++------ .../model_utils/export_output_test.py | 8 ++++---- .../nested_structure_coder_test.py | 19 ++++++++----------- .../python/saved_model/saved_model_test.py | 5 +++-- .../saved_model/signature_def_utils_impl.py | 5 +++-- .../saved_model/signature_serialization.py | 15 +++++++-------- tensorflow/python/saved_model/utils_test.py | 5 +++-- tensorflow/python/trackable/BUILD | 1 + tensorflow/python/trackable/resource.py | 3 ++- 13 files changed, 57 insertions(+), 51 deletions(-) diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD index 85e99a18630fbf..49168e55b893cf 100644 --- a/tensorflow/python/saved_model/BUILD +++ b/tensorflow/python/saved_model/BUILD @@ -76,6 +76,7 @@ py_strict_library( "//tensorflow/core:protos_all_py", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/lib/io:lib", "//tensorflow/python/ops:variables", "//tensorflow/python/platform:tf_logging", @@ -187,6 +188,7 @@ tf_py_strict_test( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:errors", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:test_lib", "//tensorflow/python/framework:test_ops", "//tensorflow/python/lib/io:lib", @@ -248,11 +250,11 @@ tf_py_strict_test( ":utils", "//tensorflow/core:protos_all_py", "//tensorflow/python/eager:context", - "//tensorflow/python/eager:def_function", "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:control_flow_ops", @@ -275,6 +277,7 @@ py_strict_library( "//tensorflow/core:protos_all_py", "//tensorflow/python/framework:errors", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/util:deprecation", "//tensorflow/python/util:tf_export", @@ -326,8 +329,7 @@ py_strict_library( "//tensorflow/python/eager:function", "//tensorflow/python/eager/polymorphic_function:attributes", "//tensorflow/python/framework:composite_tensor", - "//tensorflow/python/framework:ops", - "//tensorflow/python/framework:tensor_spec", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:resource_variable_ops", "//tensorflow/python/trackable:base", "//tensorflow/python/types:core", @@ -734,7 +736,7 @@ py_strict_library( "//tensorflow/python/framework:function_def_to_graph", "//tensorflow/python/framework:op_def_registry", "//tensorflow/python/framework:ops", - "//tensorflow/python/framework:tensor_spec", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:type_spec", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:custom_gradient", @@ -771,10 +773,9 @@ tf_py_strict_test( "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:extension_type", - "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", - "//tensorflow/python/framework:tensor_spec", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/framework:test_lib", "//tensorflow/python/framework:type_spec", diff --git a/tensorflow/python/saved_model/builder_impl.py b/tensorflow/python/saved_model/builder_impl.py index bf2e6241d7caab..18bdc53a3d77e8 100644 --- a/tensorflow/python/saved_model/builder_impl.py +++ b/tensorflow/python/saved_model/builder_impl.py @@ -25,6 +25,7 @@ from tensorflow.core.protobuf import saver_pb2 from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.lib.io import file_io from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging @@ -514,7 +515,7 @@ def _add_train_op(self, train_op): TypeError if Train op is not of type `Operation`. """ if train_op is not None: - if (not isinstance(train_op, ops.Tensor) and + if (not isinstance(train_op, tensor.Tensor) and not isinstance(train_op, ops.Operation)): raise TypeError(f"`train_op` {train_op} needs to be a Tensor or Op.") ops.add_to_collection(constants.TRAIN_OP_KEY, train_op) @@ -737,7 +738,7 @@ def _asset_path_from_tensor(path_tensor): Raises: TypeError if tensor does not match expected op type, dtype or value. """ - if not isinstance(path_tensor, ops.Tensor): + if not isinstance(path_tensor, tensor.Tensor): raise TypeError(f"Asset path tensor {path_tensor} must be a Tensor.") if path_tensor.op.type != "Const": raise TypeError(f"Asset path tensor {path_tensor} must be of type constant." diff --git a/tensorflow/python/saved_model/function_deserialization.py b/tensorflow/python/saved_model/function_deserialization.py index 58d185e185b666..4b1c57a746e240 100644 --- a/tensorflow/python/saved_model/function_deserialization.py +++ b/tensorflow/python/saved_model/function_deserialization.py @@ -30,7 +30,7 @@ from tensorflow.python.framework import function_def_to_graph as function_def_lib from tensorflow.python.framework import op_def_registry from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_spec +from tensorflow.python.framework import tensor from tensorflow.python.framework import type_spec from tensorflow.python.ops import array_ops from tensorflow.python.ops import custom_gradient @@ -44,7 +44,8 @@ def _is_tensor(t): - return isinstance(t, (ops.Tensor, resource_variable_ops.BaseResourceVariable)) + return isinstance( + t, (tensor.Tensor, resource_variable_ops.BaseResourceVariable)) # TODO(b/205016027): Update this to just use ConcreteFunction.__call__ with the @@ -72,7 +73,7 @@ def _call_concrete_function(function, inputs): flatten_expected = nest.flatten(expected_structure, expand_composites=True) tensor_inputs = [] for arg, expected in zip(flatten_inputs, flatten_expected): - if isinstance(expected, tensor_spec.TensorSpec): + if isinstance(expected, tensor.TensorSpec): tensor_inputs.append( ops.convert_to_tensor(arg, dtype_hint=expected.dtype)) elif isinstance(expected, resource_variable_ops.VariableSpec): @@ -89,7 +90,7 @@ def _try_convert_to_tensor_spec(arg, dtype_hint): # Note: try conversion in a FuncGraph to avoid polluting current context. with func_graph_lib.FuncGraph(name="guess_conversion").as_default(): result = ops.convert_to_tensor(arg, dtype_hint=dtype_hint) - return tensor_spec.TensorSpec(shape=result.shape, dtype=result.dtype) + return tensor.TensorSpec(shape=result.shape, dtype=result.dtype) except (TypeError, ValueError): return None @@ -103,10 +104,10 @@ def _concrete_function_callable_with(function, inputs, allow_conversion): return False for arg, expected in zip(flatten_inputs, nest.flatten(expected_structure)): - if isinstance(expected, tensor_spec.TensorSpec): + if isinstance(expected, tensor.TensorSpec): if allow_conversion: arg = _try_convert_to_tensor_spec(arg, dtype_hint=expected.dtype) - if not _is_tensor(arg) and not isinstance(arg, tensor_spec.TensorSpec): + if not _is_tensor(arg) and not isinstance(arg, tensor.TensorSpec): return False if arg.dtype != expected.dtype: return False diff --git a/tensorflow/python/saved_model/model_utils/BUILD b/tensorflow/python/saved_model/model_utils/BUILD index 86485c3b7619b9..4fd87596a74aed 100644 --- a/tensorflow/python/saved_model/model_utils/BUILD +++ b/tensorflow/python/saved_model/model_utils/BUILD @@ -43,6 +43,7 @@ py_strict_library( "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/saved_model:signature_def_utils", ], @@ -59,8 +60,8 @@ py_strict_test( "//tensorflow/python/eager:context", "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", - "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:control_flow_ops", "//tensorflow/python/ops:metrics", diff --git a/tensorflow/python/saved_model/model_utils/export_output.py b/tensorflow/python/saved_model/model_utils/export_output.py index c38b12525d90d9..8903a08fba798d 100644 --- a/tensorflow/python/saved_model/model_utils/export_output.py +++ b/tensorflow/python/saved_model/model_utils/export_output.py @@ -21,6 +21,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_util from tensorflow.python.saved_model import signature_def_utils @@ -86,7 +87,7 @@ def _wrap_and_check_outputs( for key, value in outputs.items(): error_name = error_label or single_output_default_name key = self._check_output_key(key, error_name) - if not isinstance(value, ops.Tensor): + if not isinstance(value, tensor.Tensor): raise ValueError( '{} output value must be a Tensor; got {}.'.format( error_name, value)) @@ -128,12 +129,12 @@ def __init__(self, scores=None, classes=None): `Tensor` with the correct dtype. """ if (scores is not None - and not (isinstance(scores, ops.Tensor) + and not (isinstance(scores, tensor.Tensor) and scores.dtype.is_floating)): raise ValueError('Classification scores must be a float32 Tensor; ' 'got {}'.format(scores)) if (classes is not None - and not (isinstance(classes, ops.Tensor) + and not (isinstance(classes, tensor.Tensor) and dtypes.as_dtype(classes.dtype) == dtypes.string)): raise ValueError('Classification classes must be a string Tensor; ' 'got {}'.format(classes)) @@ -186,7 +187,7 @@ def __init__(self, value): Raises: ValueError: if the value is not a `Tensor` with dtype tf.float32. """ - if not (isinstance(value, ops.Tensor) and value.dtype.is_floating): + if not (isinstance(value, tensor.Tensor) and value.dtype.is_floating): raise ValueError('Regression output value must be a float32 Tensor; ' 'got {}'.format(value)) self._value = value @@ -355,7 +356,7 @@ def _wrap_and_check_metrics(self, metrics): val_name = key + self._SEPARATOR_CHAR + self.METRIC_VALUE_SUFFIX op_name = key + self._SEPARATOR_CHAR + self.METRIC_UPDATE_SUFFIX - if not isinstance(metric_val, ops.Tensor): + if not isinstance(metric_val, tensor.Tensor): raise ValueError( '{} output value must be a Tensor; got {}.'.format( key, metric_val)) @@ -368,7 +369,7 @@ def _wrap_and_check_metrics(self, metrics): # We must wrap any ops (or variables) in a Tensor before export, as the # SignatureDef proto expects tensors only. See b/109740581 metric_op_tensor = metric_op - if not isinstance(metric_op, ops.Tensor): + if not isinstance(metric_op, tensor.Tensor): with ops.control_dependencies([metric_op]): metric_op_tensor = constant_op.constant([], name='metric_op_wrapper') diff --git a/tensorflow/python/saved_model/model_utils/export_output_test.py b/tensorflow/python/saved_model/model_utils/export_output_test.py index 072208e9f30868..9c84e544ec35d2 100644 --- a/tensorflow/python/saved_model/model_utils/export_output_test.py +++ b/tensorflow/python/saved_model/model_utils/export_output_test.py @@ -20,8 +20,8 @@ from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import metrics as metrics_module @@ -385,15 +385,15 @@ def test_metric_op_is_tensor(self): self.assertTrue(outputter.metrics['metrics_1/update_op'].name.startswith( 'mean/update_op')) self.assertIsInstance( - outputter.metrics['metrics_1/update_op'], ops.Tensor) - self.assertIsInstance(outputter.metrics['metrics_1/value'], ops.Tensor) + outputter.metrics['metrics_1/update_op'], tensor.Tensor) + self.assertIsInstance(outputter.metrics['metrics_1/value'], tensor.Tensor) self.assertEqual(outputter.metrics['metrics_2/value'], metrics['metrics_2'][0]) self.assertTrue(outputter.metrics['metrics_2/update_op'].name.startswith( 'metric_op_wrapper')) self.assertIsInstance( - outputter.metrics['metrics_2/update_op'], ops.Tensor) + outputter.metrics['metrics_2/update_op'], tensor.Tensor) if __name__ == '__main__': diff --git a/tensorflow/python/saved_model/nested_structure_coder_test.py b/tensorflow/python/saved_model/nested_structure_coder_test.py index f010471da138de..c2b9e12d437605 100644 --- a/tensorflow/python/saved_model/nested_structure_coder_test.py +++ b/tensorflow/python/saved_model/nested_structure_coder_test.py @@ -26,10 +26,9 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import extension_type -from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_util from tensorflow.python.framework import test_util from tensorflow.python.framework import type_spec @@ -165,7 +164,7 @@ def testDtype(self): self.assertEqual(structure, decoded) def testEncodeDecodeTensorSpec(self): - structure = [tensor_spec.TensorSpec([1, 2, 3], dtypes.int64, "hello")] + structure = [tensor.TensorSpec([1, 2, 3], dtypes.int64, "hello")] self.assertTrue(nested_structure_coder.can_encode(structure)) encoded = nested_structure_coder.encode_structure(structure) expected = struct_pb2.StructuredValue() @@ -181,7 +180,7 @@ def testEncodeDecodeTensorSpec(self): self.assertEqual(structure, decoded) def testEncodeDecodeTensorSpecWithNoName(self): - structure = [tensor_spec.TensorSpec([1, 2, 3], dtypes.int64)] + structure = [tensor.TensorSpec([1, 2, 3], dtypes.int64)] self.assertTrue(nested_structure_coder.can_encode(structure)) encoded = nested_structure_coder.encode_structure(structure) expected = struct_pb2.StructuredValue() @@ -276,12 +275,12 @@ def testEncodeDecodeExtensionTypeSpec(self): class Zoo(extension_type.ExtensionType): __name__ = "tf.nested_structure_coder_test.Zoo" zookeepers: typing.Tuple[str, ...] - animals: typing.Mapping[str, ops.Tensor] + animals: typing.Mapping[str, tensor.Tensor] structure = [ Zoo.Spec( zookeepers=["Zoey", "Zack"], - animals={"tiger": tensor_spec.TensorSpec([16])}) + animals={"tiger": tensor.TensorSpec([16])}) ] self.assertTrue(nested_structure_coder.can_encode(structure)) @@ -327,8 +326,7 @@ def testDecodeUnknownTensorSpec(self): def testEncodeDecodeBoundedTensorSpec(self): structure = [ - tensor_spec.BoundedTensorSpec([1, 2, 3], dtypes.int64, 0, 10, - "hello_0_10") + tensor.BoundedTensorSpec([1, 2, 3], dtypes.int64, 0, 10, "hello_0_10") ] self.assertTrue(nested_structure_coder.can_encode(structure)) encoded = nested_structure_coder.encode_structure(structure) @@ -350,8 +348,7 @@ def testEncodeDecodeBoundedTensorSpec(self): def testEncodeDecodeBoundedTensorSpecNoName(self): structure = [ - tensor_spec.BoundedTensorSpec((28, 28, 3), dtypes.float64, -2, - (1, 1, 20)) + tensor.BoundedTensorSpec((28, 28, 3), dtypes.float64, -2, (1, 1, 20)) ] self.assertTrue(nested_structure_coder.can_encode(structure)) encoded = nested_structure_coder.encode_structure(structure) @@ -378,7 +375,7 @@ def testEncodeDataSetSpec(self): dataset_ops.DatasetSpec({ "rt": ragged_tensor.RaggedTensorSpec([10, None], dtypes.int32), "st": sparse_tensor.SparseTensorSpec([10, 20], dtypes.float32), - "t": tensor_spec.TensorSpec([10, 8], dtypes.string) + "t": tensor.TensorSpec([10, 8], dtypes.string) }) ] self.assertTrue(nested_structure_coder.can_encode(structure)) diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py index a0f6478ae55fa1..87fd9151bba05c 100644 --- a/tensorflow/python/saved_model/saved_model_test.py +++ b/tensorflow/python/saved_model/saved_model_test.py @@ -24,6 +24,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import test_ops from tensorflow.python.framework import test_util from tensorflow.python.lib.io import file_io @@ -939,7 +940,7 @@ def testTrainOp(self): "AssignAddVariableOp") else: self.assertIsInstance( - loader_impl.get_train_op(meta_graph_def), ops.Tensor) + loader_impl.get_train_op(meta_graph_def), tensor_lib.Tensor) def testTrainOpGroup(self): export_dir = self._get_export_dir("test_train_op_group") @@ -995,7 +996,7 @@ def testTrainOpAfterVariables(self): "AssignAddVariableOp") else: self.assertIsInstance( - loader_impl.get_train_op(meta_graph_def), ops.Tensor) + loader_impl.get_train_op(meta_graph_def), tensor_lib.Tensor) with self.session(graph=ops.Graph()) as sess: loader.load(sess, ["pre_foo"], export_dir) diff --git a/tensorflow/python/saved_model/signature_def_utils_impl.py b/tensorflow/python/saved_model/signature_def_utils_impl.py index b2911b174b2239..5de175e2a85687 100644 --- a/tensorflow/python/saved_model/signature_def_utils_impl.py +++ b/tensorflow/python/saved_model/signature_def_utils_impl.py @@ -19,6 +19,7 @@ from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_util from tensorflow.python.saved_model import signature_constants from tensorflow.python.saved_model import utils_impl as utils @@ -104,7 +105,7 @@ def regression_signature_def(examples, predictions): """ if examples is None: raise ValueError('Regression `examples` cannot be None.') - if not isinstance(examples, ops.Tensor): + if not isinstance(examples, tensor_lib.Tensor): raise ValueError('Expected regression `examples` to be of type Tensor. ' f'Found `examples` of type {type(examples)}.') if predictions is None: @@ -157,7 +158,7 @@ def classification_signature_def(examples, classes, scores): """ if examples is None: raise ValueError('Classification `examples` cannot be None.') - if not isinstance(examples, ops.Tensor): + if not isinstance(examples, tensor_lib.Tensor): raise ValueError('Classification `examples` must be a string Tensor. ' f'Found `examples` of type {type(examples)}.') if classes is None and scores is None: diff --git a/tensorflow/python/saved_model/signature_serialization.py b/tensorflow/python/saved_model/signature_serialization.py index 9cadfc9076e3ff..38362c8087a838 100644 --- a/tensorflow/python/saved_model/signature_serialization.py +++ b/tensorflow/python/saved_model/signature_serialization.py @@ -20,8 +20,7 @@ from tensorflow.python.eager import function as defun from tensorflow.python.eager.polymorphic_function import attributes from tensorflow.python.framework import composite_tensor -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_spec +from tensorflow.python.framework import tensor from tensorflow.python.ops import resource_variable_ops from tensorflow.python.saved_model import function_serialization from tensorflow.python.saved_model import revived_types @@ -192,7 +191,7 @@ def signature_wrapper(**kwargs): if signature_function.structured_input_signature is not None: # The structured input signature may contain other non-tensor arguments. inputs = filter( - lambda x: isinstance(x, tensor_spec.TensorSpec), + lambda x: isinstance(x, tensor.TensorSpec), nest.flatten( signature_function.structured_input_signature, expand_composites=True, @@ -207,10 +206,10 @@ def signature_wrapper(**kwargs): inputs, ): keyword = compat.as_str(keyword) - if isinstance(inp, tensor_spec.TensorSpec): - spec = tensor_spec.TensorSpec(inp.shape, inp.dtype, name=keyword) + if isinstance(inp, tensor.TensorSpec): + spec = tensor.TensorSpec(inp.shape, inp.dtype, name=keyword) else: - spec = tensor_spec.TensorSpec.from_tensor(inp, name=keyword) + spec = tensor.TensorSpec.from_tensor(inp, name=keyword) tensor_spec_signature[keyword] = spec final_concrete = wrapped_function._get_concrete_function_garbage_collected( # pylint: disable=protected-access **tensor_spec_signature @@ -240,7 +239,7 @@ def signature_wrapper(**kwargs): arg_names[-len_default:], # pylint: disable=protected-access flattened_defaults or [], ): - if not isinstance(default, ops.Tensor): + if not isinstance(default, tensor.Tensor): continue defaults.setdefault(signature_key, {})[arg] = default return concrete_signatures, wrapped_functions, defaults @@ -269,7 +268,7 @@ def _normalize_outputs(outputs, function_name, signature_key): f"the function {compat.as_str_any(function_name)} used to generate " f"the SavedModel signature {signature_key!r}." ) - if not isinstance(value, (ops.Tensor, composite_tensor.CompositeTensor)): + if not isinstance(value, (tensor.Tensor, composite_tensor.CompositeTensor)): raise ValueError( f"Got a non-Tensor value {value!r} for key {key!r} in the output of " f"the function {compat.as_str_any(function_name)} used to generate " diff --git a/tensorflow/python/saved_model/utils_test.py b/tensorflow/python/saved_model/utils_test.py index a30cd5253d1136..52d44b8999b603 100644 --- a/tensorflow/python/saved_model/utils_test.py +++ b/tensorflow/python/saved_model/utils_test.py @@ -21,6 +21,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -102,7 +103,7 @@ def testGetTensorFromInfoDense(self): expected = array_ops.placeholder(dtypes.float32, 1, name="x") tensor_info = utils.build_tensor_info(expected) actual = utils.get_tensor_from_tensor_info(tensor_info) - self.assertIsInstance(actual, ops.Tensor) + self.assertIsInstance(actual, tensor.Tensor) self.assertEqual(expected.name, actual.name) @test_util.run_v1_only( @@ -134,7 +135,7 @@ def testGetTensorFromInfoInOtherGraph(self): array_ops.placeholder(dtypes.float32, 1, name="other") actual = utils.get_tensor_from_tensor_info(tensor_info, graph=expected_graph) - self.assertIsInstance(actual, ops.Tensor) + self.assertIsInstance(actual, tensor.Tensor) self.assertIs(actual.graph, expected_graph) self.assertEqual(expected.name, actual.name) diff --git a/tensorflow/python/trackable/BUILD b/tensorflow/python/trackable/BUILD index f84fad992352b0..6d787f3c242e86 100644 --- a/tensorflow/python/trackable/BUILD +++ b/tensorflow/python/trackable/BUILD @@ -185,6 +185,7 @@ py_strict_library( "//tensorflow/python/eager:context", "//tensorflow/python/eager:def_function", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/util:tf_decorator", "//tensorflow/python/util:tf_export", ], diff --git a/tensorflow/python/trackable/resource.py b/tensorflow/python/trackable/resource.py index 823d70b8f10c34..ee4a5c1361cbba 100644 --- a/tensorflow/python/trackable/resource.py +++ b/tensorflow/python/trackable/resource.py @@ -21,6 +21,7 @@ from tensorflow.python.eager import context from tensorflow.python.eager import def_function from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.trackable import base from tensorflow.python.util import tf_contextlib from tensorflow.python.util.tf_export import tf_export @@ -152,7 +153,7 @@ def _resource_handle(self): @_resource_handle.setter def _resource_handle(self, value): - if isinstance(value, (ops.Tensor, ops.EagerTensor)): + if isinstance(value, (tensor.Tensor, ops.EagerTensor)): value._parent_trackable = weakref.ref(self) # pylint: disable=protected-access self._resource_handle_value = value From 29da06547177e610a0c1cf7083711b06bf9ec583 Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Mon, 10 Jul 2023 15:10:13 -0700 Subject: [PATCH 085/376] [NFC] Change uses of get_compatible_with_cloud to get_compatible_with_portable. PiperOrigin-RevId: 546995216 --- tensorflow/dtensor/mlir/BUILD | 6 +++--- tensorflow/dtensor/mlir/dtensor_dialect/BUILD | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tensorflow/dtensor/mlir/BUILD b/tensorflow/dtensor/mlir/BUILD index 95c79716168a9b..a6ef38cf6bb63d 100644 --- a/tensorflow/dtensor/mlir/BUILD +++ b/tensorflow/dtensor/mlir/BUILD @@ -3,7 +3,7 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("@bazel_skylib//rules:build_test.bzl", "build_test") load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -18,7 +18,7 @@ package( gentbl_cc_library( name = "tensorflow_dtensor_ops_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-op-decls"], @@ -46,7 +46,7 @@ gentbl_cc_library( gentbl_cc_library( name = "dtensor_passes_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [( [ "-gen-pass-decls", diff --git a/tensorflow/dtensor/mlir/dtensor_dialect/BUILD b/tensorflow/dtensor/mlir/dtensor_dialect/BUILD index cee04f594da170..feb41e226c21c0 100644 --- a/tensorflow/dtensor/mlir/dtensor_dialect/BUILD +++ b/tensorflow/dtensor/mlir/dtensor_dialect/BUILD @@ -1,7 +1,7 @@ # DTensor MLIR dialect. load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -21,7 +21,7 @@ td_library( "ir/dtensor_dialect.td", "ir/dtensor_ops.td", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "@llvm-project//mlir:FuncTdFiles", "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", @@ -31,7 +31,7 @@ td_library( gentbl_cc_library( name = "DialectIncGen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-op-decls"], From 0b91ac89fb6e51208b126314448af50ae607b284 Mon Sep 17 00:00:00 2001 From: Juan Martinez Castellanos Date: Mon, 10 Jul 2023 15:14:37 -0700 Subject: [PATCH 086/376] Apply pytype to conjugate_gradient.py. PiperOrigin-RevId: 546996250 --- tensorflow/python/ops/linalg/sparse/BUILD | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/ops/linalg/sparse/BUILD b/tensorflow/python/ops/linalg/sparse/BUILD index ec7abda24c2910..c14f3e9c0e1a83 100644 --- a/tensorflow/python/ops/linalg/sparse/BUILD +++ b/tensorflow/python/ops/linalg/sparse/BUILD @@ -1,4 +1,5 @@ load("//tensorflow:strict.default.bzl", "py_strict_library") +load("//tensorflow:pytype.default.bzl", "pytype_strict_library") # Description: Sparse CSR support for TensorFlow. load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") @@ -53,7 +54,7 @@ py_strict_library( srcs = ["__init__.py"], ) -py_strict_library( +pytype_strict_library( name = "conjugate_gradient", srcs = ["conjugate_gradient.py"], deps = [ From 9ce00c72ebbac87a5a3c3e205723e2b563c1ef43 Mon Sep 17 00:00:00 2001 From: Fiona Lang Date: Mon, 10 Jul 2023 15:15:34 -0700 Subject: [PATCH 087/376] Update ops.Tensor references to //third_party/tensorflow/python/framework/tensor.py. PiperOrigin-RevId: 546996508 --- tensorflow/dtensor/python/BUILD | 7 ++++-- tensorflow/dtensor/python/api.py | 19 ++++++++------ tensorflow/dtensor/python/dtensor_device.py | 3 ++- tensorflow/dtensor/python/input_util.py | 5 ++-- tensorflow/dtensor/python/layout.py | 3 ++- tensorflow/dtensor/python/save_restore.py | 28 +++++++++++++-------- 6 files changed, 41 insertions(+), 24 deletions(-) diff --git a/tensorflow/dtensor/python/BUILD b/tensorflow/dtensor/python/BUILD index 2676f040d0cd86..56d789c6c0bc65 100644 --- a/tensorflow/dtensor/python/BUILD +++ b/tensorflow/dtensor/python/BUILD @@ -52,7 +52,7 @@ pytype_strict_library( ":layout", "//tensorflow/python/eager:context", "//tensorflow/python/framework:ops", - "//tensorflow/python/framework:tensor_util", + "//tensorflow/python/framework:tensor", "//tensorflow/python/util:deprecation", "//tensorflow/python/util:tf_export", ], @@ -94,6 +94,7 @@ pytype_strict_library( "//tensorflow/python:_pywrap_dtensor_device", "//tensorflow/python/framework:device", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/util:tf_export", "//third_party/py/numpy", ], @@ -179,7 +180,7 @@ pytype_strict_library( "//tensorflow/python/eager:context", "//tensorflow/python/framework:errors", "//tensorflow/python/framework:ops", - "//tensorflow/python/ops:array_ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:io_ops", "//tensorflow/python/ops:variables", "//tensorflow/python/util:tf_export", @@ -204,6 +205,7 @@ pytype_strict_library( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/util:_pywrap_utils", "//third_party/py/numpy", @@ -310,6 +312,7 @@ pytype_strict_library( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:errors", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:tensor_spec", "//tensorflow/python/ops:array_ops", diff --git a/tensorflow/dtensor/python/api.py b/tensorflow/dtensor/python/api.py index d971a4dfa7bb0c..2f303e9aa3a218 100644 --- a/tensorflow/dtensor/python/api.py +++ b/tensorflow/dtensor/python/api.py @@ -23,6 +23,7 @@ from tensorflow.dtensor.python import layout as layout_lib from tensorflow.python.eager import context from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.util import deprecation from tensorflow.python.util.tf_export import tf_export @@ -167,7 +168,7 @@ def is_dtensor(tensor) -> bool: def copy_to_mesh( tensor: Any, layout: layout_lib.Layout, - source_layout: Optional[layout_lib.Layout] = None) -> ops.Tensor: + source_layout: Optional[layout_lib.Layout] = None) -> tensor_lib.Tensor: """Copies a tf.Tensor onto the DTensor device with the given layout. Copies a regular tf.Tensor onto the DTensor device. Use the mesh attached to @@ -377,7 +378,7 @@ def unpack(tensor: Any) -> Sequence[Any]: @tf_export("experimental.dtensor.fetch_layout", v1=[]) -def fetch_layout(tensor: ops.Tensor) -> layout_lib.Layout: +def fetch_layout(tensor: tensor_lib.Tensor) -> layout_lib.Layout: """Fetches the layout of a DTensor. Args: @@ -393,7 +394,7 @@ def fetch_layout(tensor: ops.Tensor) -> layout_lib.Layout: @tf_export("experimental.dtensor.check_layout", v1=[]) -def check_layout(tensor: ops.Tensor, layout: layout_lib.Layout) -> None: +def check_layout(tensor: tensor_lib.Tensor, layout: layout_lib.Layout) -> None: """Asserts that the layout of the DTensor is `layout`. Args: @@ -410,8 +411,10 @@ def check_layout(tensor: ops.Tensor, layout: layout_lib.Layout) -> None: @tf_export("experimental.dtensor.relayout", v1=[]) def relayout( - tensor: ops.Tensor, layout: layout_lib.Layout, name: Optional[str] = None -) -> ops.Tensor: + tensor: tensor_lib.Tensor, + layout: layout_lib.Layout, + name: Optional[str] = None, +) -> tensor_lib.Tensor: """Changes the layout of `tensor`. Changes the layout of `tensor` to `layout`. This is used to fine-tune the @@ -449,8 +452,10 @@ def relayout( @tf_export("experimental.dtensor.relayout_like", v1=[]) def relayout_like( - tensor: ops.Tensor, layout_tensor: ops.Tensor, name: Optional[str] = None -) -> ops.Tensor: + tensor: tensor_lib.Tensor, + layout_tensor: tensor_lib.Tensor, + name: Optional[str] = None, +) -> tensor_lib.Tensor: """Changes the layout of `tensor` to the same as `layout_tensor`. `relayout_like` is often used inside a `tf.function`, to ensure a tensor is diff --git a/tensorflow/dtensor/python/dtensor_device.py b/tensorflow/dtensor/python/dtensor_device.py index c475b8fb8158ec..8a252e64f1e402 100644 --- a/tensorflow/dtensor/python/dtensor_device.py +++ b/tensorflow/dtensor/python/dtensor_device.py @@ -32,6 +32,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_util from tensorflow.python.util import _pywrap_utils @@ -119,7 +120,7 @@ def _register_mesh(self, mesh: layout_lib.Mesh): def meshes(self) -> Set[layout_lib.Mesh]: return self._meshes - def copy_to_mesh(self, tensor, new_layout) -> ops.Tensor: + def copy_to_mesh(self, tensor, new_layout) -> tensor_lib.Tensor: """Copy `tensor` to `device` with the given layout.""" self._register_mesh(new_layout.mesh) with ops.device(self.name): diff --git a/tensorflow/dtensor/python/input_util.py b/tensorflow/dtensor/python/input_util.py index 230ae42fdcb5fc..a504c3639a1467 100644 --- a/tensorflow/dtensor/python/input_util.py +++ b/tensorflow/dtensor/python/input_util.py @@ -74,6 +74,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_spec from tensorflow.python.ops import array_ops @@ -110,7 +111,7 @@ class _DTensorIterator(iterator_ops.OwnedIterator): def __init__( self, - dtensor_components: Tuple[ops.Tensor], + dtensor_components: Tuple[tensor.Tensor], global_element_spec: tensor_spec.TensorSpec, layouts: Any): """Initializes a distributed iterator for DTensor datasets. @@ -283,7 +284,7 @@ def _shard_counts(layout: layout_lib.Layout, def _index_matrix(layout: layout_lib.Layout, - elem_spec: tensor_spec.TensorSpec) -> ops.Tensor: + elem_spec: tensor_spec.TensorSpec) -> tensor.Tensor: """Computes a utility matrix to derive device-based slice offsets. This function builds a matrix of shape `[mesh.rank, layout.rank]` for each diff --git a/tensorflow/dtensor/python/layout.py b/tensorflow/dtensor/python/layout.py index d17cc1f18ba911..c3b21d744c2e20 100644 --- a/tensorflow/dtensor/python/layout.py +++ b/tensorflow/dtensor/python/layout.py @@ -25,6 +25,7 @@ from tensorflow.python import _pywrap_dtensor_device from tensorflow.python.framework import device as tf_device from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.util.tf_export import tf_export # UNSHARDED indicates a tensor dimension is not sharded over any mesh dimension. @@ -245,7 +246,7 @@ def __reduce__(self): return Mesh.from_string, (self.to_string(),) # TODO(b/242201545): implement this in Mesh C++ class - def coords(self, device_idx: int) -> ops.Tensor: + def coords(self, device_idx: int) -> tensor.Tensor: """Converts the device index into a tensor of mesh coordinates.""" strides = ops.convert_to_tensor(self.strides) shape = ops.convert_to_tensor(self.shape()) diff --git a/tensorflow/dtensor/python/save_restore.py b/tensorflow/dtensor/python/save_restore.py index dde0b00d4d5e27..25bd78cdf00a99 100644 --- a/tensorflow/dtensor/python/save_restore.py +++ b/tensorflow/dtensor/python/save_restore.py @@ -25,6 +25,7 @@ from tensorflow.python.eager import context from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.ops import io_ops from tensorflow.python.ops import variables as tf_variables from tensorflow.python.util.tf_export import tf_export @@ -33,10 +34,10 @@ @tf_export('experimental.dtensor.sharded_save', v1=[]) def sharded_save( mesh: layout_lib.Mesh, - file_prefix: Union[str, ops.Tensor], - tensor_names: Union[List[str], ops.Tensor], - shape_and_slices: Union[List[str], ops.Tensor], - tensors: List[Union[ops.Tensor, tf_variables.Variable]], + file_prefix: Union[str, tensor_lib.Tensor], + tensor_names: Union[List[str], tensor_lib.Tensor], + shape_and_slices: Union[List[str], tensor_lib.Tensor], + tensors: List[Union[tensor_lib.Tensor, tf_variables.Variable]], ): """Saves given named tensor slices in a sharded, multi-client safe fashion. @@ -100,7 +101,8 @@ def enable_save_as_bf16(variables: List[tf_variables.Variable]): def name_based_restore( mesh: layout_lib.Mesh, checkpoint_prefix: str, - name_tensor_dict: Dict[str, Union[ops.Tensor, tf_variables.Variable]], + name_tensor_dict: Dict[ + str, Union[tensor_lib.Tensor, tf_variables.Variable]], ): """Restores from checkpoint_prefix to name based DTensors. @@ -163,17 +165,21 @@ def name_based_restore( shape_and_slices=shape_and_slices, input_shapes=input_shapes, input_layouts=input_layouts, - dtypes=[tensor.dtype for tensor in ordered_name_tensor_dict.values()]) + dtypes=[tensor.dtype for tensor in ordered_name_tensor_dict.values()], + ) return collections.OrderedDict( - zip(ordered_name_tensor_dict.keys(), restored_cpu_tensors)) + zip(ordered_name_tensor_dict.keys(), restored_cpu_tensors) + ) @tf_export('experimental.dtensor.name_based_save', v1=[]) -def name_based_save(mesh: layout_lib.Mesh, checkpoint_prefix: Union[str, - ops.Tensor], - name_tensor_dict: Dict[str, Union[ops.Tensor, - tf_variables.Variable]]): +def name_based_save( + mesh: layout_lib.Mesh, + checkpoint_prefix: Union[str, tensor_lib.Tensor], + name_tensor_dict: Dict[ + str, Union[tensor_lib.Tensor, tf_variables.Variable]], +): """Saves name based Tensor into a Checkpoint. The function prepares the input dictionary to the format of a `sharded_save`, From 67d46ce25fad87e98c3931ae53be142d88e497ca Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Mon, 10 Jul 2023 15:28:59 -0700 Subject: [PATCH 088/376] Remove extraneous `extra_copts` argument to xla_cc_test PiperOrigin-RevId: 546999931 --- tensorflow/compiler/xla/tests/build_defs.bzl | 4 ++-- tensorflow/compiler/xla/xla.bzl | 2 -- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/xla/tests/build_defs.bzl b/tensorflow/compiler/xla/tests/build_defs.bzl index 210422b6009571..425c144d1a9033 100644 --- a/tensorflow/compiler/xla/tests/build_defs.bzl +++ b/tensorflow/compiler/xla/tests/build_defs.bzl @@ -136,8 +136,8 @@ def xla_test( name = test_name, srcs = srcs, tags = tags + backend_tags.get(backend, []) + this_backend_tags, - extra_copts = copts + ["-DXLA_TEST_BACKEND_%s=1" % backend.upper()] + - this_backend_copts, + copts = copts + ["-DXLA_TEST_BACKEND_%s=1" % backend.upper()] + + this_backend_copts, args = args + this_backend_args, deps = deps + backend_deps, data = data + this_backend_data, diff --git a/tensorflow/compiler/xla/xla.bzl b/tensorflow/compiler/xla/xla.bzl index 6259d8a540a098..445c7d4ed46138 100644 --- a/tensorflow/compiler/xla/xla.bzl +++ b/tensorflow/compiler/xla/xla.bzl @@ -79,11 +79,9 @@ def xla_cc_binary(deps = None, copts = tsl_copts(), **kwargs): def xla_cc_test( name, deps = [], - extra_copts = [], **kwargs): native.cc_test( name = name, - copts = extra_copts, deps = deps + if_tsl_link_protobuf( [], [ From 1cbf7f621256a328a7a2ba6927dde665d494ea7b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 10 Jul 2023 15:30:03 -0700 Subject: [PATCH 089/376] AllReduce experimental clustering/grouping optimization on GPU by topological order and distance. Introduce a new environment variable `DTENSOR_ALLREDUCE_COMBINE_OPTIMIZATION_TOPOLOGICAL_DISTANCE` to enable and control AllReduce grouping by their topological level distance on the compute graph. The goal of this optimization is to reflect locality in the grouping algorithm. PiperOrigin-RevId: 547000224 --- tensorflow/dtensor/cc/constants.h | 10 ++ .../dtensor/cc/dtensor_graph_to_mlir_pass.cc | 6 + tensorflow/dtensor/cc/dtensor_utils.cc | 18 +++ tensorflow/dtensor/cc/dtensor_utils.h | 17 ++- .../dtensor_allreduce_combine_optimization.cc | 138 +++++++++++++++++- ...tensor_allreduce_combine_optimization.mlir | 74 ++++++++++ 6 files changed, 254 insertions(+), 9 deletions(-) diff --git a/tensorflow/dtensor/cc/constants.h b/tensorflow/dtensor/cc/constants.h index 9dea0928fe4e1d..3ad6d6c39a706b 100644 --- a/tensorflow/dtensor/cc/constants.h +++ b/tensorflow/dtensor/cc/constants.h @@ -136,6 +136,8 @@ static constexpr int kSparseTensorNum = 3; // Attribute which stores the environment variable value for all_reduce // optimization group size: DTENSOR_ALLREDUCE_COMBINE_OPTIMIZATION_GROUP_SIZE. +// This represents the maximum number of AllReduce ops to merge into one op. It +// is a determining factor used during dtensor_allreduce_combine_optimization. static constexpr char kAllReduceNumOpsInGroup[] = "dtensor.all_reduce_combiner.num_ops_in_group"; @@ -144,6 +146,14 @@ static constexpr char kAllReduceNumOpsInGroup[] = static constexpr char kEnableMultiDeviceMode[] = "dtensor.enable_multi_device_mode"; +// Attribute which stores the environment variable value for all_reduce +// optimization group size: DTENSOR_ALLREDUCE_COMBINE_OPTIMIZATION_GROUP_SIZE. +// This represents the maximum distance between two AllReduce on the compute +// graph in terms of topological level. It is a determining factor used during +// dtensor_allreduce_combine_optimization. +static constexpr char kAllReduceTopologicalDistance[] = + "dtensor.all_reduce_combiner.topological_distance"; + } // namespace dtensor } // namespace tensorflow diff --git a/tensorflow/dtensor/cc/dtensor_graph_to_mlir_pass.cc b/tensorflow/dtensor/cc/dtensor_graph_to_mlir_pass.cc index 1e207cbca6fe37..28691771e2376e 100644 --- a/tensorflow/dtensor/cc/dtensor_graph_to_mlir_pass.cc +++ b/tensorflow/dtensor/cc/dtensor_graph_to_mlir_pass.cc @@ -128,6 +128,12 @@ DTensorMlirPassRunner::ImportGraphToMlir( mlir::IntegerAttr::get(mlir::IntegerType::get(&context_, /*width=*/64), group_size)); + int topo_dist = dtensor::AllReduceCombineOptimizationTopologicalDistance(); + module->setAttr( + dtensor::kAllReduceTopologicalDistance, + mlir::IntegerAttr::get(mlir::IntegerType::get(&context_, /*width=*/64), + topo_dist)); + if (dtensor::EnableMultiDeviceMode()) { module->setAttr(dtensor::kEnableMultiDeviceMode, mlir::BoolAttr::get(&context_, true)); diff --git a/tensorflow/dtensor/cc/dtensor_utils.cc b/tensorflow/dtensor/cc/dtensor_utils.cc index 3c915c4b25c845..b22c9ad5b0023d 100644 --- a/tensorflow/dtensor/cc/dtensor_utils.cc +++ b/tensorflow/dtensor/cc/dtensor_utils.cc @@ -158,6 +158,24 @@ int AllReduceCombineOptimizationGroupSize() { return 0; } +int AllReduceCombineOptimizationTopologicalDistance() { + int64_t topo_dist; + absl::Status status = tsl::ReadInt64FromEnvVar( + "DTENSOR_ALLREDUCE_COMBINE_OPTIMIZATION_TOPOLOGICAL_DISTANCE", + /*default_val=*/0, &topo_dist); + if (!status.ok()) { + LOG(WARNING) << "Invalid DTENSOR_ALLREDUCE_COMBINE_OPTIMIZATION_TOPOLOGICAL" + "_DISTANCE, using the default value 0."; + return 0; + } else if (topo_dist < 0) { + LOG(WARNING) << "Invalid DTENSOR_ALLREDUCE_COMBINE_OPTIMIZATION_TOPOLOGICAL" + "_DISTANCE, value must be a positive integer, using the " + "default value 0."; + return 0; + } + return topo_dist; +} + bool EnableMultiDeviceMode() { bool multi_device_mode; absl::Status status = tsl::ReadBoolFromEnvVar( diff --git a/tensorflow/dtensor/cc/dtensor_utils.h b/tensorflow/dtensor/cc/dtensor_utils.h index 06626ddcda6bd1..b89a85cd5ec5b6 100644 --- a/tensorflow/dtensor/cc/dtensor_utils.h +++ b/tensorflow/dtensor/cc/dtensor_utils.h @@ -66,9 +66,24 @@ bool EnableReplicatedSpmdAsDefault(const std::string& op_name); // Returns whether to use all-to-all collective for relayout when possible. bool EnableAllToAllForRelayout(); -// Returns the maximum number of AllReduce ops to merge into a group. +// Returns the maximum number of AllReduce ops to merge into a group. This value +// determines the AllReduce grouping in dtensor_allreduce_combine_optimization. +// The input value should be in range of [0, INT_MAX]. It is advised to pick +// a value based on knowledge of the total number of AllReduces. When the value +// is too big, the behaviour will act as aggressive grouping. When the value is +// too small, the behaviour will act as having no extended grouping. int AllReduceCombineOptimizationGroupSize(); +// Returns the maximum topological distance between two AllReduce ops to merge +// into a single AllReduce. This value is used to determine AllReduce grouping +// in dtensor_allreduce_combine_optimization. The input value should be in range +// of [0, INT_MAX]. However, it is advised to select a value based on knowledge +// of the compute graph, such as the minimum distance between two model layers. +// When the input value is too big, the behaviour will act as aggressive group- +// ing. When the input value is too small, the behaviour will act as having no +// extended grouping. +int AllReduceCombineOptimizationTopologicalDistance(); + // Returns whether to perform multi-device expansion. bool EnableMultiDeviceMode(); } // namespace dtensor diff --git a/tensorflow/dtensor/mlir/dtensor_allreduce_combine_optimization.cc b/tensorflow/dtensor/mlir/dtensor_allreduce_combine_optimization.cc index f4e9ee9d3beed7..11099801b45ae7 100644 --- a/tensorflow/dtensor/mlir/dtensor_allreduce_combine_optimization.cc +++ b/tensorflow/dtensor/mlir/dtensor_allreduce_combine_optimization.cc @@ -546,6 +546,10 @@ createSubgroupsByGroupAssignment( // Experimental extended grouping logics to avoid aggressive grouping. // This function performs the same grouping method as tf.distribute, which group // all reduce ops by user defined group size (number of ops) in the input order. +// Note that group_size will be in range of [0, INT_MAX]. It is advised to pick +// a value based on knowledge of the total number of AllReduces. When group_size +// is too big, the function will act as aggressive grouping. When group_size is +// too small, the function will act as having no extended grouping. std::vector> createSubgroupsByExtendedNumOps( std::vector> all_reduce_groups, @@ -580,6 +584,103 @@ createSubgroupsByExtendedNumOps( return all_reduce_new_groups; } +// Experimental grouping logics to optimize from aggressive grouping. +// This function first sort by topological level, then create AllReduce sub- +// groups by accessing each topological distance from its previous AllReduce. +// Note that topo_dist will be in range of [0, INT_MAX]. It is advised to select +// a value based on knowledge of the compute graph, such as the minimum distance +// between two model layers. When topo_dist is too big, the function will act +// as aggressive grouping. When topo_dist is too small, the function will act as +// having no extended grouping. +StatusOr>> +createSubgroupsByTopoDist( + std::vector> all_reduce_groups, + llvm::DenseMap all_reduce_topo, + int topo_dist) { + // Disable extended grouping if topological distance is set to zero or less + if (topo_dist <= 0) return all_reduce_groups; + VLOG(4) << "current number of groups: " << all_reduce_groups.size(); + std::vector> all_reduce_new_groups; + + // Further break down the current all_reduced_groups by topological distance + // between two ops + for (auto& all_reduce_group : all_reduce_groups) { + std::vector new_group; + Status status = absl::OkStatus(); + + // Sort AllReduces by topological level as the input order may not reflect + // their dependencies on the operands in the compute graph. + std::sort(all_reduce_group.begin(), all_reduce_group.end(), + [&all_reduce_topo, &status](mlir::TF::DTensorAllReduceOp& lhs, + mlir::TF::DTensorAllReduceOp& rhs) { + if ((all_reduce_topo.find(lhs) == all_reduce_topo.end()) || + (all_reduce_topo.find(rhs) == all_reduce_topo.end())) { + status = absl::InternalError( + "Error: encounter AllReduce op with no topological level" + " assignment."); + return false; + } + return all_reduce_topo[lhs] < all_reduce_topo[rhs]; + }); + // Unable to sort AllReduces based on topological level due to error. Return + // directly as we are not able to group based on incorrect/partial topology. + if (!status.ok()) return status; + + // Form AllReduce groups based on the topological distance between ops + DCHECK(!all_reduce_group.empty()); + int prev_topo_level = all_reduce_topo[all_reduce_group[0]]; + for (const auto& all_reduce : all_reduce_group) { + DCHECK(all_reduce_topo.find(all_reduce) != all_reduce_topo.end()); + int cur_topo_level = all_reduce_topo[all_reduce]; + if (abs(cur_topo_level - prev_topo_level) <= topo_dist) { + new_group.push_back(all_reduce); + } else { + // Start a new group + all_reduce_new_groups.push_back( + std::vector(new_group.begin(), + new_group.end())); + new_group.clear(); + new_group.push_back(all_reduce); + } + prev_topo_level = cur_topo_level; + } + all_reduce_new_groups.push_back(new_group); + } + VLOG(4) << "new number of groups: " << all_reduce_new_groups.size(); + return all_reduce_new_groups; +} + +// Compute the topological level for each AllReduce op in a cluster. The level +// is defined as 1 + max operands' depth in the compute graph. If an op do not +// depend on any input/operand, then it is level 0. +llvm::DenseMap computeAllReduceTopoLevel( + mlir::tf_device::ClusterOp cluster) { + llvm::DenseMap op_topo_level; + llvm::DenseMap all_reduce_topo; + + // Compute topological level for each op. + cluster.getBody().walk([&](mlir::Operation* op) { + int max_depth = 0; + for (mlir::Value operand : op->getOperands()) { + if (mlir::Operation* operand_op = operand.getDefiningOp()) { + if (op_topo_level.find(operand_op) != op_topo_level.end()) { + max_depth = fmax(max_depth, op_topo_level[operand_op]); + } + } + } + op_topo_level[op] = max_depth + 1; + + // Save the AllReduce topological level + mlir::TF::DTensorAllReduceOp all_reduce = + llvm::dyn_cast(op); + if (all_reduce && !all_reduce.getDeviceType().contains("TPU")) { + all_reduce_topo[all_reduce] = op_topo_level[op]; + } + }); + + return all_reduce_topo; +} + struct DTensorAllReduceCombineOptimization : public impl::DTensorAllReduceCombineOptimizationBase< DTensorAllReduceCombineOptimization> { @@ -590,11 +691,10 @@ struct DTensorAllReduceCombineOptimization std::vector ordered_all_reduces; std::vector ordered_blocks; llvm::DenseSet blocks; - cluster.GetBody().walk([&](mlir::TF::DTensorAllReduceOp all_reduce) { if (!all_reduce.getDeviceType().contains("TPU")) { // Only combine all reduces for GPU and CPU - auto all_reduce_ranked_type = + mlir::RankedTensorType all_reduce_ranked_type = all_reduce.getType().dyn_cast(); if (all_reduce_ranked_type && @@ -621,15 +721,37 @@ struct DTensorAllReduceCombineOptimization all_reduce_groups = createSubgroupsByReductionAttr(all_reduce_groups); all_reduce_groups = createSubgroupsByGroupAssignment(all_reduce_groups); - // Experimental extended grouping - int group_size = 0; + // Experimental extended grouping: topological distance + if (module->hasAttrOfType( + kAllReduceTopologicalDistance)) { + llvm::DenseMap all_reduce_topo = + computeAllReduceTopoLevel(cluster); + + StatusOr>> + group = createSubgroupsByTopoDist( + all_reduce_groups, all_reduce_topo, + module + ->getAttrOfType( + kAllReduceTopologicalDistance) + .getInt()); + if (!group.ok()) { + // This is a non-fatal error since topological level distance is one + // of the optimizations in this combiner pass. Output an error and + // continue with the rest of the grouping optimization. + LOG(WARNING) << "Failed to create subgroups using topological " + << "level distance: " << group.status(); + } else { + all_reduce_groups = group.value(); + } + } + + // Experimental extended grouping: fixed number of AllReduce ops if (module->hasAttrOfType(kAllReduceNumOpsInGroup)) { - group_size = + all_reduce_groups = createSubgroupsByExtendedNumOps( + all_reduce_groups, module->getAttrOfType(kAllReduceNumOpsInGroup) - .getInt(); + .getInt()); } - all_reduce_groups = - createSubgroupsByExtendedNumOps(all_reduce_groups, group_size); // Maintain relative order of ALLReduces within the block. std::sort(all_reduce_groups.begin(), all_reduce_groups.end(), diff --git a/tensorflow/dtensor/mlir/tests/dtensor_allreduce_combine_optimization.mlir b/tensorflow/dtensor/mlir/tests/dtensor_allreduce_combine_optimization.mlir index 5da8b06d804c5f..de72a60589b4dc 100644 --- a/tensorflow/dtensor/mlir/tests/dtensor_allreduce_combine_optimization.mlir +++ b/tensorflow/dtensor/mlir/tests/dtensor_allreduce_combine_optimization.mlir @@ -198,4 +198,78 @@ module attributes {dtensor.all_reduce_combiner.num_ops_in_group = 2} { }) : () -> tensor<4x4xf32> "func.return"() : () -> () } +} + +// ----- +module attributes {dtensor.all_reduce_combiner.topological_distance = 2} { + // Check that when topologicial grouping is enabled in AllReduce combiner, the + // independent DTensorAllReduce ops of the same element type and group assign- + // ment are combined according to the topological distance between two ops. + // + // The following scenario would have 1 group of 7 AllReduces when topological + // distance is *not* set. + // - level 1: %4, %5 (case: <= topo_dist, simple case with same level) + // - level 2: %7 (case: <= topo_dist, simple case for eligible to group) + // - level 4: %16 (case: <= topo_dist, out of order, test for topo sort) + // - level 5: %15 (case: < topo_dist, out of order, test for topo sort) + // - level 8: %14 (case: > topo_dist, ineligible to group and out of order), + // %17 (case: > topo_dist, ineligible to group with 1st group, + // but should get grouped with %14) + // + // Detailed level computations are listed in the test below. + // + // With topological_distance set to 2, we expect the following grouping result + // - group 1: %4, %5, %7, %15, %16 + // - group 2: %14, %17 + // + // Note use of dummy AllReduces (with the same input) gaurantees ops to be + // grouped together if topologicial grouping is not enabled. + // + // CHECK-LABEL: func @main + func.func @main() { + // CHECK: %[[ALL_REDUCE_1:.*]] = "tf.DTensorAllReduce" + // CHECK-SAME: (tensor<1024xf32>, tensor<2x2xi32>) -> tensor<1024xf32> + // CHECK: %[[ALL_REDUCE_2:.*]] = "tf.DTensorAllReduce" + // CHECK-SAME: (tensor<1024xf32>, tensor<2x2xi32>) -> tensor<1024xf32> + // CHECK: %[[ADD:.*]] = "tf.Add" + // CHECK-SAME: (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + %0 = "tf_device.cluster"() ({ + // topological level 0 for all tf.Const + %1 = "tf.Const"() {value = dense<0.0> : tensor<4x4xf32>} : () -> tensor<4x4xf32> + %2 = "tf.Const"() {value = dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32> + %3 = "tf.Const"() {value = dense<1.0> : tensor<4x4xf32>} : () -> tensor<4x4xf32> + // %4 topological_level: 1 = max(0, 0) + 1 + %4 = "tf.DTensorAllReduce"(%1, %2) {_layout = ["sharding_specs:x,y, mesh:|x=2,y=2|*GPU"], device_type = "GPU", reduce_op = "Add"} : (tensor<4x4xf32>, tensor<2x2xi32>) -> tensor<4x4xf32> + // %5 topological_level: 1 = max(0, 0) + 1 + %5 = "tf.DTensorAllReduce"(%3, %2) {_layout = ["sharding_specs:x,y, mesh:|x=2,y=2|*GPU"], device_type = "GPU", reduce_op = "Add"} : (tensor<4x4xf32>, tensor<2x2xi32>) -> tensor<4x4xf32> + // %6 topological_level: 1 = max(0, 0) + 1 + %6 = "tf.Add"(%1, %3) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + // %7 topological_level: 2 = max(1, 0) + 1 + %7 = "tf.DTensorAllReduce"(%6, %2) {_layout = ["sharding_specs:x,y, mesh:|x=2,y=2|*GPU"], device_type = "GPU", reduce_op = "Add"} : (tensor<4x4xf32>, tensor<2x2xi32>) -> tensor<4x4xf32> + // Dummy Adds to construct depth in compute graph + // %8 topological_level: 2 = max(1, 0) + 1 + // %9 topological_level: 3 = max(2, 0) + 1 + // %10 topological_level: 4 = max(3, 0) + 1 + // %11 topological_level: 5 = max(4, 0) + 1 + // %12 topological_level: 6 = max(5, 0) + 1 + // %13 topological_level: 7 = max(6, 0) + 1 + %8 = "tf.Add"(%6, %3) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + %9 = "tf.Add"(%8, %1) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + %10 = "tf.Add"(%9, %3) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + %11 = "tf.Add"(%10, %1) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + %12 = "tf.Add"(%11, %3) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + %13 = "tf.Add"(%12, %1) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + // %14 topological_level: 8 = max(7, 0) + 1 + %14 = "tf.DTensorAllReduce"(%13, %2) {_layout = ["sharding_specs:x,y, mesh:|x=2,y=2|*GPU"], device_type = "GPU", reduce_op = "Add"} : (tensor<4x4xf32>, tensor<2x2xi32>) -> tensor<4x4xf32> + // %15 topological_level: 5 = max(4, 0) + 1 + %15 = "tf.DTensorAllReduce"(%10, %2) {_layout = ["sharding_specs:x,y, mesh:|x=2,y=2|*GPU"], device_type = "GPU", reduce_op = "Add"} : (tensor<4x4xf32>, tensor<2x2xi32>) -> tensor<4x4xf32> + // %16 topological_level: 4 = max(3, 0) + 1 + %16 = "tf.DTensorAllReduce"(%9, %2) {_layout = ["sharding_specs:x,y, mesh:|x=2,y=2|*GPU"], device_type = "GPU", reduce_op = "Add"} : (tensor<4x4xf32>, tensor<2x2xi32>) -> tensor<4x4xf32> + // %17 topological_level: 8 = max(7, 0) + 1 + %17 = "tf.DTensorAllReduce"(%13, %2) {_layout = ["sharding_specs:x,y, mesh:|x=2,y=2|*GPU"], device_type = "GPU", reduce_op = "Add"} : (tensor<4x4xf32>, tensor<2x2xi32>) -> tensor<4x4xf32> + %18 = "tf.Add"(%15, %7) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + "tf_device.return"(%18) : (tensor<4x4xf32>) -> () + }) : () -> tensor<4x4xf32> + "func.return"() : () -> () + } } \ No newline at end of file From 84d2f56ff62ba208d4ed4e51c5402758189c2215 Mon Sep 17 00:00:00 2001 From: Yu Feng Date: Mon, 10 Jul 2023 15:37:55 -0700 Subject: [PATCH 090/376] Internal build system changes. PiperOrigin-RevId: 547002118 --- tensorflow/dtensor/python/tests/BUILD | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/tensorflow/dtensor/python/tests/BUILD b/tensorflow/dtensor/python/tests/BUILD index 27a60cadc3d082..00b4303a37c360 100644 --- a/tensorflow/dtensor/python/tests/BUILD +++ b/tensorflow/dtensor/python/tests/BUILD @@ -305,23 +305,6 @@ pytype_strict_library( ], ) -pytype_strict_library( - name = "test_backend_util_oss", - srcs = ["test_backend_util.oss.py"], - deps = [ - ":test_util", - "//tensorflow/dtensor/python:config", - "//tensorflow/dtensor/python:layout", - "//tensorflow/dtensor/python:tpu_util", - "//tensorflow/python/platform:client_testlib", - ], -) - -pytype_strict_library( - name = "test_backend_name_oss", - srcs = ["test_backend_name.oss.py"], -) - dtensor_test( name = "multi_client_test", srcs = ["multi_client_test.py"], From 342c297be9b8750dd88b15e166f08b9a672c94ac Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Mon, 10 Jul 2023 15:50:04 -0700 Subject: [PATCH 091/376] [NFC] Change uses of get_compatible_with_cloud to get_compatible_with_portable. PiperOrigin-RevId: 547005099 --- tensorflow/compiler/mlir/lite/BUILD | 38 +++++++-------- .../mlir/lite/experimental/common/BUILD | 4 +- .../compiler/mlir/lite/experimental/tac/BUILD | 10 ++-- .../compiler/mlir/lite/quantization/BUILD | 6 +-- .../compiler/mlir/lite/quantization/ir/BUILD | 10 ++-- .../mlir/lite/quantization/tensorflow/BUILD | 6 +-- .../mlir/lite/stablehlo/serializer/BUILD | 6 +-- .../mlir/quantization/stablehlo/BUILD | 16 +++---- .../mlir/quantization/tensorflow/BUILD | 48 +++++++++---------- .../quantization/tensorflow/calibrator/BUILD | 8 ++-- .../mlir/quantization/tensorflow/cc/BUILD | 14 +++--- .../quantization/tensorflow/debugging/BUILD | 6 +-- .../mlir/quantization/tensorflow/python/BUILD | 6 +-- .../mlir/quantization/tensorflow/utils/BUILD | 12 ++--- tensorflow/compiler/mlir/tensorflow/BUILD | 44 ++++++++--------- tensorflow/compiler/mlir/tfr/BUILD | 8 ++-- tensorflow/compiler/mlir/tfrt/BUILD | 4 +- .../compiler/mlir/tfrt/jit/transforms/BUILD | 8 ++-- .../compiler/mlir/tools/kernel_gen/ir/BUILD | 8 ++-- .../mlir/tools/kernel_gen/transforms/BUILD | 14 +++--- tensorflow/compiler/mlir/tosa/BUILD | 20 ++++---- 21 files changed, 148 insertions(+), 148 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 429d1da9754f4a..af0cf97c31b64d 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -1,4 +1,4 @@ -load("//tensorflow:tensorflow.default.bzl", "filegroup", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "filegroup", "get_compatible_with_portable") load("@bazel_skylib//rules:build_test.bzl", "build_test") load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test", "tf_native_cc_binary") @@ -39,7 +39,7 @@ td_library( "ir/tfl_op_interfaces.td", "ir/tfl_ops.td", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/mlir/lite/quantization:quantization_td_files", "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files", @@ -63,7 +63,7 @@ td_library( "transforms/tensorlist_patterns.td", "utils/utils.td", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), visibility = ["//visibility:private"], deps = [ ":tensorflow_lite_ops_td_files", @@ -76,7 +76,7 @@ td_library( gentbl_cc_library( name = "tensorflow_lite_passes_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( [ @@ -95,7 +95,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_lite_ops_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-op-decls"], @@ -122,7 +122,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_lite_op_interfaces_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-op-interface-decls"], @@ -150,7 +150,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_lite_op_enums_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-enum-decls"], @@ -178,7 +178,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_lite_prepare_tf_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], @@ -192,7 +192,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_lite_lower_static_tensor_list_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], @@ -206,7 +206,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_lite_legalize_tf_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], @@ -220,7 +220,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_lite_legalize_variables_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], @@ -234,7 +234,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_lite_optimize_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], @@ -248,7 +248,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_lite_quantize_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], @@ -262,7 +262,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_lite_post_quantize_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], @@ -276,7 +276,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_lite_legalize_tensorlist_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], @@ -852,7 +852,7 @@ filegroup( gentbl_cc_library( name = "op_quant_spec_getters_inc", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [([], "utils/generated_op_quant_spec_getters.inc")], tblgen = "//tensorflow/compiler/mlir/lite/quantization:op_quant_spec_getters_gen", td_file = "ir/tfl_ops.td", @@ -863,7 +863,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tflite_op_coverage_spec_inc", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [([], "utils/tflite_op_coverage_spec.inc")], tblgen = "//tensorflow/compiler/mlir/lite/quantization:tflite_op_coverage_spec_getters_gen", td_file = "ir/tfl_ops.td", @@ -878,7 +878,7 @@ tf_native_cc_binary( srcs = [ "converter_gen.cc", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "@llvm-project//llvm:Support", "@llvm-project//llvm:TableGen", @@ -888,7 +888,7 @@ tf_native_cc_binary( gentbl_cc_library( name = "converter_inc", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["--gen-operator-converters"], diff --git a/tensorflow/compiler/mlir/lite/experimental/common/BUILD b/tensorflow/compiler/mlir/lite/experimental/common/BUILD index 02fab009fda976..c7e4fa006d868a 100644 --- a/tensorflow/compiler/mlir/lite/experimental/common/BUILD +++ b/tensorflow/compiler/mlir/lite/experimental/common/BUILD @@ -1,4 +1,4 @@ -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") # copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) @@ -6,7 +6,7 @@ cc_library( name = "outline_operations", srcs = ["outline_operations.cc"], hdrs = ["outline_operations.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), visibility = ["//visibility:public"], deps = [ "//tensorflow/compiler/mlir/lite:tensorflow_lite", diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/BUILD b/tensorflow/compiler/mlir/lite/experimental/tac/BUILD index b012bd60b154ed..b0514e13937956 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/BUILD +++ b/tensorflow/compiler/mlir/lite/experimental/tac/BUILD @@ -2,7 +2,7 @@ load("//tensorflow:strict.default.bzl", "py_strict_library") load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("@flatbuffers//:build_defs.bzl", "flatbuffer_cc_library") load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test") -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load( "@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", @@ -19,7 +19,7 @@ package( flatbuffer_cc_library( name = "runtime_metadata_fbs", srcs = ["runtime_metadata.fbs"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), ) cc_library( @@ -88,7 +88,7 @@ cc_library( gentbl_cc_library( name = "transform_patterns_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], @@ -386,11 +386,11 @@ py_strict_library( proto_library( name = "tac_filter_proto", srcs = ["tac_filter.proto"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), ) cc_proto_library( name = "tac_filter_cc_proto", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [":tac_filter_proto"], ) diff --git a/tensorflow/compiler/mlir/lite/quantization/BUILD b/tensorflow/compiler/mlir/lite/quantization/BUILD index 33d54e4e449d60..ec839967182c61 100644 --- a/tensorflow/compiler/mlir/lite/quantization/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/BUILD @@ -1,6 +1,6 @@ load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_native_cc_binary") -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow/core/platform:build_config.bzl", "tf_proto_library") load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") @@ -33,7 +33,7 @@ td_library( srcs = [ "quantization.td", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/mlir/lite/quantization/ir:QuantizationOpsTdFiles", "@llvm-project//mlir:OpBaseTdFiles", @@ -42,7 +42,7 @@ td_library( gentbl_cc_library( name = "quantization_interfaces_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-op-interface-decls"], diff --git a/tensorflow/compiler/mlir/lite/quantization/ir/BUILD b/tensorflow/compiler/mlir/lite/quantization/ir/BUILD index dc1c7d841b5d05..727fb03d833964 100644 --- a/tensorflow/compiler/mlir/lite/quantization/ir/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/ir/BUILD @@ -1,6 +1,6 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -14,7 +14,7 @@ td_library( "QuantOps.td", "QuantOpsBase.td", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", "@llvm-project//mlir:OpBaseTdFiles", @@ -25,7 +25,7 @@ td_library( gentbl_cc_library( name = "QuantOpsIncGen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-op-decls"], @@ -57,7 +57,7 @@ gentbl_cc_library( gentbl_cc_library( name = "QuantPassIncGen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( [ @@ -89,7 +89,7 @@ cc_library( "QuantizeUtils.h", "UniformSupport.h", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":QuantOpsIncGen", ":QuantPassIncGen", diff --git a/tensorflow/compiler/mlir/lite/quantization/tensorflow/BUILD b/tensorflow/compiler/mlir/lite/quantization/tensorflow/BUILD index 9652196367f398..cdefbdb1e28a4e 100644 --- a/tensorflow/compiler/mlir/lite/quantization/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/tensorflow/BUILD @@ -1,6 +1,6 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -24,7 +24,7 @@ td_library( srcs = [ "fallback_to_flex_patterns.td", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), visibility = ["//visibility:private"], deps = [ "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files", @@ -35,7 +35,7 @@ td_library( gentbl_cc_library( name = "ptq_fallback_to_flex_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], diff --git a/tensorflow/compiler/mlir/lite/stablehlo/serializer/BUILD b/tensorflow/compiler/mlir/lite/stablehlo/serializer/BUILD index c7e8f44563dd42..8910f55f81eb8b 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/serializer/BUILD +++ b/tensorflow/compiler/mlir/lite/stablehlo/serializer/BUILD @@ -1,4 +1,4 @@ -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -17,7 +17,7 @@ cc_library( "flatbuffer_operator.h", "flatbuffer_translator.h", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/mlir:op_or_arg_name_mapper", "//tensorflow/compiler/mlir/tensorflow", @@ -48,7 +48,7 @@ cc_library( "flatbuffer_export.cc", ], hdrs = ["flatbuffer_export.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":flatbuffer_translator", "//tensorflow/compiler/mlir:op_or_arg_name_mapper", diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD index 50404aab815c4a..c3b045f19a86b0 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD @@ -1,4 +1,4 @@ -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test") load("//tensorflow/core/platform:build_config.bzl", "tf_proto_library") load("//tensorflow/compiler/mlir/quantization/stablehlo:internal_visibility_allowlist.bzl", "internal_visibility_allowlist") @@ -36,7 +36,7 @@ cc_library( hdrs = [ "passes/passes.h", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":fill_quantization_options", ":quantization_options_proto_cc", @@ -63,7 +63,7 @@ cc_library( gentbl_cc_library( name = "bridge_passes_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( [ @@ -89,7 +89,7 @@ cc_library( hdrs = [ "passes/bridge/passes.h", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), visibility = [ "//tensorflow/compiler/mlir/lite:__subpackages__", "//tensorflow/compiler/mlir/tf2xla:__subpackages__", @@ -162,7 +162,7 @@ cc_library( hdrs = [ "quantize_passes.h", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), visibility = [":internal_visibility_allowlist_package"], deps = [ ":fill_quantization_options", @@ -180,7 +180,7 @@ cc_library( gentbl_cc_library( name = "stablehlo_passes_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( [ @@ -200,7 +200,7 @@ cc_library( name = "fill_quantization_options", srcs = ["utils/fill_quantization_options.cc"], hdrs = ["utils/fill_quantization_options.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":quantization_options_proto_cc", "//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes", @@ -213,7 +213,7 @@ cc_library( name = "math_utils", srcs = ["utils/math_utils.cc"], hdrs = ["utils/math_utils.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = ["@llvm-project//mlir:Support"], ) diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/BUILD index 1c9e70d1a260a0..d32fbcc7ed853c 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/BUILD @@ -1,6 +1,6 @@ load("//tensorflow:strict.default.bzl", "py_strict_binary") load("//tensorflow:tensorflow.bzl", "tf_cc_binary") -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow/core/platform:build_config.bzl", "tf_proto_library") load("//tensorflow/compiler/mlir/quantization/tensorflow:internal_visibility_allowlist.bzl", "internal_visibility_allowlist") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") @@ -45,7 +45,7 @@ genrule( "passes/quantized_function_library.h", ], cmd = "$(location gen_quantized_function_library) --output_file $(RULEDIR)/passes/quantized_function_library.h --src '$(SRCS)'", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tools = ["gen_quantized_function_library"], ) @@ -57,7 +57,7 @@ cc_library( hdrs = [ "passes/utils.h", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":quantization_options_proto_cc", "//tensorflow/compiler/mlir/lite/quantization:quantization_lib", @@ -76,7 +76,7 @@ cc_library( hdrs = [ "passes/manipulate_model_attr.h", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/mlir/tensorflow", "@llvm-project//llvm:Support", @@ -93,7 +93,7 @@ cc_library( hdrs = [ "passes/remove_identity_op_pattern.h", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/mlir/tensorflow", "@llvm-project//mlir:IR", @@ -118,7 +118,7 @@ td_library( "passes/tf_quant_ops.td", "passes/utils.td", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/mlir/lite:tensorflow_lite_ops_td_files", "//tensorflow/compiler/mlir/quantization/tensorflow/utils:lift_as_function_call_utils_td_files", @@ -130,7 +130,7 @@ td_library( gentbl_cc_library( name = "cast_bf16_ops_to_f32_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], @@ -144,7 +144,7 @@ gentbl_cc_library( gentbl_cc_library( name = "prepare_lifting_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], @@ -158,7 +158,7 @@ gentbl_cc_library( gentbl_cc_library( name = "lift_quantizable_spots_as_functions_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], @@ -172,7 +172,7 @@ gentbl_cc_library( gentbl_cc_library( name = "lift_quantizable_spots_as_functions_drq_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], @@ -186,7 +186,7 @@ gentbl_cc_library( gentbl_cc_library( name = "prepare_quantize_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], @@ -200,7 +200,7 @@ gentbl_cc_library( gentbl_cc_library( name = "quantize_composite_functions_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], @@ -214,7 +214,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tf_quant_ops_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-op-decls"], @@ -234,7 +234,7 @@ gentbl_cc_library( gentbl_cc_library( name = "optimize_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], @@ -248,7 +248,7 @@ gentbl_cc_library( gentbl_cc_library( name = "convert_tpu_model_to_cpu_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], @@ -262,7 +262,7 @@ gentbl_cc_library( gentbl_cc_library( name = "post_quantize_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], @@ -276,7 +276,7 @@ gentbl_cc_library( gentbl_cc_library( name = "preprocess_op_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], @@ -296,7 +296,7 @@ cc_library( "passes/tf_quant_ops.h.inc", ], hdrs = ["passes/tf_quant_ops.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes", "//tensorflow/compiler/mlir/tensorflow:tensorflow_op_interfaces", @@ -324,7 +324,7 @@ cc_library( "ops/tf_op_quant_spec.cc", ], hdrs = ["ops/tf_op_quant_spec.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/mlir/lite/quantization:quantization_config", "//tensorflow/compiler/mlir/lite/quantization:quantization_lib", @@ -340,7 +340,7 @@ cc_library( "ops/uniform_op_quant_spec.cc", ], hdrs = ["ops/uniform_op_quant_spec.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":tf_quant_ops", "//tensorflow/compiler/mlir/lite/quantization:quantization_config", @@ -353,7 +353,7 @@ cc_library( gentbl_cc_library( name = "replace_cast_hacks_with_tf_xla_ops_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], @@ -414,7 +414,7 @@ cc_library( "passes/constants.h", "passes/passes.h", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":manipulate_model_attr", ":pass_utils", @@ -489,7 +489,7 @@ cc_library( hdrs = [ "quantize_preprocess.h", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":passes", "//tensorflow/compiler/mlir/quantization/tensorflow/cc:run_passes", @@ -517,7 +517,7 @@ cc_library( hdrs = [ "quantize_passes.h", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":pass_utils", ":passes", diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/BUILD index d8a0c975aac48a..97553651573e31 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/BUILD @@ -5,7 +5,7 @@ load( ) load( "//tensorflow:tensorflow.default.bzl", - "get_compatible_with_cloud", + "get_compatible_with_portable", "tf_kernel_library", "tf_py_strict_test", ) @@ -29,7 +29,7 @@ cc_library( name = "calibrator_singleton_impl", srcs = ["calibrator_singleton.cc"], hdrs = ["calibrator_singleton.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), visibility = ["//visibility:private"], deps = [ "@com_google_absl//absl/container:flat_hash_map", @@ -42,7 +42,7 @@ cc_library( cc_library( name = "calibrator_singleton", hdrs = ["calibrator_singleton.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = if_static([":calibrator_singleton_impl"]) + [ "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", @@ -65,7 +65,7 @@ tf_cc_test( tf_kernel_library( name = "custom_aggregator_op", srcs = ["custom_aggregator_op.cc"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), visibility = [ "//tensorflow:__pkg__", "//tensorflow/compiler/mlir/quantization/tensorflow/python:__pkg__", diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/cc/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/cc/BUILD index f121b219d2ed4a..86e6efc5ec43a2 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/cc/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/cc/BUILD @@ -1,6 +1,6 @@ load( "//tensorflow:tensorflow.default.bzl", - "get_compatible_with_cloud", + "get_compatible_with_portable", ) load( "//tensorflow:tensorflow.bzl", @@ -21,7 +21,7 @@ cc_library( name = "save_variables", srcs = ["save_variables.cc"], hdrs = ["save_variables.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/core:framework", @@ -68,7 +68,7 @@ cc_library( name = "const_op_size", srcs = ["const_op_size.cc"], hdrs = ["const_op_size.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:tensorflow_remaining_ops", @@ -97,7 +97,7 @@ cc_library( name = "convert_asset_args", srcs = ["convert_asset_args.cc"], hdrs = ["convert_asset_args.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:import_model", @@ -128,7 +128,7 @@ tf_cc_test( cc_library( name = "status_macro", hdrs = ["status_macro.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/tsl/platform:macros", "@com_google_absl//absl/status", @@ -150,7 +150,7 @@ cc_library( name = "run_passes", srcs = ["run_passes.cc"], hdrs = ["run_passes.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/mlir/quantization/tensorflow/debugging:mlir_dump", "//tensorflow/compiler/mlir/tensorflow:error_util", @@ -170,7 +170,7 @@ cc_library( hdrs = [ "constant_fold.h", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/mlir/quantization/tensorflow/utils:lift_as_function_call_utils", "//tensorflow/compiler/mlir/tensorflow", diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/debugging/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/debugging/BUILD index 879ccc88de0583..cd755f83e2d963 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/debugging/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/debugging/BUILD @@ -1,5 +1,5 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_test") -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -13,7 +13,7 @@ cc_library( name = "mlir_dump", srcs = ["mlir_dump.cc"], hdrs = ["mlir_dump.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/tsl/platform:env", "//tensorflow/tsl/platform:path", @@ -31,7 +31,7 @@ cc_library( tf_cc_test( name = "mlir_dump_test", srcs = ["mlir_dump_test.cc"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":mlir_dump", "//tensorflow/tsl/platform:path", diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD index 38363200a62b4a..8faafa1c17770c 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD @@ -5,7 +5,7 @@ load( ) load( "//tensorflow:tensorflow.default.bzl", - "get_compatible_with_cloud", + "get_compatible_with_portable", "tf_py_test", "tf_python_pybind_extension", ) @@ -25,7 +25,7 @@ cc_library( name = "quantize_model_cc_impl", srcs = ["quantize_model.cc"], hdrs = ["quantize_model.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), visibility = [ # Directly linked to `libtensorflow_cc.so` or # `_pywrap_tensorflow_internal.so` if static build. @@ -82,7 +82,7 @@ cc_library( cc_library( name = "quantize_model_cc", hdrs = ["quantize_model.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = if_static([":quantize_model_cc_impl"]) + [ "//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_cc", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/utils/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/utils/BUILD index a01f881d88e04c..80955803f1cc5b 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/utils/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/utils/BUILD @@ -1,5 +1,5 @@ load("@llvm-project//mlir:tblgen.bzl", "td_library") -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow:tensorflow.bzl", "tf_cc_test") package( @@ -14,7 +14,7 @@ cc_library( name = "fake_quant_utils", srcs = ["fake_quant_utils.cc"], hdrs = ["fake_quant_utils.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/mlir/lite/quantization:quantization_lib", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", @@ -29,7 +29,7 @@ td_library( srcs = [ "lift_as_function_call_utils.td", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "@llvm-project//mlir:FuncTdFiles", ], @@ -39,7 +39,7 @@ cc_library( name = "lift_as_function_call_utils", srcs = ["lift_as_function_call_utils.cc"], hdrs = ["lift_as_function_call_utils.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/mlir/lite/quantization:quantization_lib", "//tensorflow/compiler/mlir/quantization/tensorflow:pass_utils", @@ -56,7 +56,7 @@ cc_library( name = "tf_to_uniform_attribute_utils", srcs = ["tf_to_uniform_attribute_utils.cc"], hdrs = ["tf_to_uniform_attribute_utils.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/mlir/lite/quantization:quantization_lib", "//tensorflow/compiler/mlir/quantization/tensorflow:pass_utils", @@ -73,7 +73,7 @@ cc_library( name = "tf_to_xla_attribute_utils", srcs = ["tf_to_xla_attribute_utils.cc"], hdrs = ["tf_to_xla_attribute_utils.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/mlir/quantization/tensorflow:pass_utils", "//tensorflow/compiler/mlir/quantization/tensorflow/cc:constant_fold", diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 60dcd27cff61d4..207ec7970aa1c1 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -1,4 +1,4 @@ -load("//tensorflow:tensorflow.default.bzl", "filegroup", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "filegroup", "get_compatible_with_portable") load("@bazel_skylib//rules:build_test.bzl", "build_test") load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") @@ -30,7 +30,7 @@ td_library( "ir/tf_ops.td", "ir/tfrt_ops.td", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "@llvm-project//mlir:CallInterfacesTdFiles", "@llvm-project//mlir:ControlFlowInterfacesTdFiles", @@ -43,7 +43,7 @@ td_library( gentbl_cc_library( name = "tensorflow_op_interfaces_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-op-interface-decls"], @@ -64,7 +64,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_struct_doc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-dialect-doc"], @@ -103,7 +103,7 @@ cc_library( gentbl_cc_library( name = "tensorflow_all_ops_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-op-decls"], @@ -123,7 +123,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_tfrt_ops_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-op-decls"], @@ -156,7 +156,7 @@ tf_ops_category_list = [ [[ gentbl_cc_library( name = "tensorflow_" + target["name"] + "_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( [ @@ -183,7 +183,7 @@ tf_ops_category_list = [ gentbl_cc_library( name = "tensorflow_remaining_ops_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( [ @@ -209,7 +209,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tf_saved_model_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-op-decls"], @@ -235,7 +235,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_executor_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-op-decls"], @@ -266,7 +266,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_device_ops_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-op-decls"], @@ -294,7 +294,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_canonicalize_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], @@ -326,7 +326,7 @@ gentbl_cc_library( gentbl_cc_library( name = "hlo_legalize_tf_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], @@ -753,7 +753,7 @@ cc_library( gentbl_cc_library( name = "decompose_resource_ops_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], @@ -807,7 +807,7 @@ td_library( srcs = [ "transforms/rewrite_util.td", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "@llvm-project//mlir:OpBaseTdFiles", ], @@ -829,7 +829,7 @@ cc_library( gentbl_cc_library( name = "tf_data_optimization_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], @@ -1109,7 +1109,7 @@ cc_library( gentbl_cc_library( name = "tf_pass_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( [ @@ -1132,7 +1132,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tf_device_pass_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( [ @@ -1155,7 +1155,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tf_savedmodel_pass_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( [ @@ -1178,7 +1178,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tf_test_passes_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( [ @@ -2233,7 +2233,7 @@ filegroup( gentbl_cc_library( name = "tensorflow_optimize_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], @@ -2338,7 +2338,7 @@ tf_gen_op_wrapper_py( # without linking any of the other tensorflow passes. gentbl_cc_library( name = "lower_tf_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], diff --git a/tensorflow/compiler/mlir/tfr/BUILD b/tensorflow/compiler/mlir/tfr/BUILD index 7fd0a0446d2e77..b18a7872d5e0db 100644 --- a/tensorflow/compiler/mlir/tfr/BUILD +++ b/tensorflow/compiler/mlir/tfr/BUILD @@ -4,7 +4,7 @@ load( "tf_cc_binary", "tf_cc_test", ) -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud", "tf_py_strict_test", "tf_python_pybind_extension") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable", "tf_py_strict_test", "tf_python_pybind_extension") load("//tensorflow/compiler/mlir/tfr:build_defs.bzl", "gen_op_libraries") load( "@llvm-project//mlir:tblgen.bzl", @@ -37,7 +37,7 @@ td_library( srcs = [ "ir/tfr_ops.td", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/mlir/lite/quantization/ir:QuantizationOpsTdFiles", "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files", @@ -52,7 +52,7 @@ td_library( gentbl_cc_library( name = "tfr_ops_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-op-decls"], @@ -72,7 +72,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tfr_decompose_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], diff --git a/tensorflow/compiler/mlir/tfrt/BUILD b/tensorflow/compiler/mlir/tfrt/BUILD index 9032b75d11ee44..e76b65b6172349 100644 --- a/tensorflow/compiler/mlir/tfrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/BUILD @@ -8,7 +8,7 @@ load("//tensorflow:tensorflow.bzl", "if_google", "tf_cc_binary") load("@tf_runtime//:build_defs.bzl", "tfrt_cc_library", "tfrt_cc_test") # Note: keep the following lines separate due to the way copybara works -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud", "get_compatible_with_portable") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") # TF to TFRT kernels conversion. package( @@ -108,7 +108,7 @@ cc_library( name = "tf_jitrt_pipeline", srcs = ["jit/tf_jitrt_pipeline.cc"], hdrs = ["jit/tf_jitrt_pipeline.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/jit:flags", "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", diff --git a/tensorflow/compiler/mlir/tfrt/jit/transforms/BUILD b/tensorflow/compiler/mlir/tfrt/jit/transforms/BUILD index 78b4f9d64d4a13..cffa4511b88d13 100644 --- a/tensorflow/compiler/mlir/tfrt/jit/transforms/BUILD +++ b/tensorflow/compiler/mlir/tfrt/jit/transforms/BUILD @@ -1,6 +1,6 @@ load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") load("@tf_runtime//:build_defs.bzl", "tfrt_cc_library") -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") # TF to TFRT kernels conversion. package( @@ -13,7 +13,7 @@ tfrt_cc_library( name = "tf_jitrt_clustering", srcs = ["tf_jitrt_clustering.cc"], hdrs = ["tf_jitrt_clustering.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/jit:flags", "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", @@ -27,7 +27,7 @@ tfrt_cc_library( gentbl_cc_library( name = "tf_jitrt_passes_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( [ @@ -55,7 +55,7 @@ cc_library( "tf_jitrt_passes.cc", ], hdrs = ["tf_jitrt_passes.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":tf_jitrt_clustering", ":tf_jitrt_passes_inc_gen", diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD index e4892ba1d2e3ce..05dbdcc437d562 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD @@ -1,6 +1,6 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -18,7 +18,7 @@ td_library( "tf_framework_ops.td", "tf_status.td", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "@llvm-project//mlir:AllocationOpInterfaceTdFiles", "@llvm-project//mlir:ControlFlowInterfacesTdFiles", @@ -29,7 +29,7 @@ td_library( gentbl_cc_library( name = "tf_framework_ops_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-op-decls"], @@ -55,7 +55,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tf_status_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-enum-decls"], diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD index 3aeaca1a5a1b48..1f29b8555cab9f 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD @@ -7,7 +7,7 @@ load( "@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured", ) -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -19,7 +19,7 @@ cc_library( name = "utils", srcs = ["utils.cc"], hdrs = ["utils.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", @@ -31,7 +31,7 @@ cc_library( name = "tf_framework_legalize_to_llvm", srcs = ["tf_framework_legalize_to_llvm.cc"], hdrs = ["rewriters.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":utils", "//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_ops", @@ -47,7 +47,7 @@ cc_library( name = "bufferize", srcs = ["bufferize.cc"], hdrs = ["rewriters.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_ops", "@llvm-project//mlir:ArithDialect", @@ -66,7 +66,7 @@ cc_library( name = "embed_tf_framework", srcs = ["embed_tf_framework.cc"], hdrs = ["rewriters.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_ops", "@llvm-project//mlir:ControlFlowDialect", @@ -79,7 +79,7 @@ cc_library( gentbl_cc_library( name = "kernel_gen_passes_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [( [ "-gen-pass-decls", @@ -187,7 +187,7 @@ cc_library( "tf_to_jit_invocations.cc", ], hdrs = ["passes.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":bufferize", # buildcleaner: keep ":embed_tf_framework", # buildcleaner: keep diff --git a/tensorflow/compiler/mlir/tosa/BUILD b/tensorflow/compiler/mlir/tosa/BUILD index 586204d9594062..33eb74dc1f9560 100644 --- a/tensorflow/compiler/mlir/tosa/BUILD +++ b/tensorflow/compiler/mlir/tosa/BUILD @@ -3,7 +3,7 @@ # https://developer.mlplatform.org/w/tosa/ # https://github.com/llvm/llvm-project/blob/main/mlir/docs/Dialects/TOSA.md -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") # TODO: Tighten visibility once targets are at the right granularity. @@ -35,12 +35,12 @@ filegroup( srcs = [ "@llvm-project//mlir:TosaDialectTdFiles", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), ) gentbl_cc_library( name = "tosa_passes_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( [ @@ -63,7 +63,7 @@ cc_library( "transforms/passes.h", "transforms/passes.h.inc", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/mlir/lite:tensorflow_lite", "@llvm-project//mlir:FuncDialect", @@ -82,7 +82,7 @@ cc_library( "transforms/legalize_common.h", "transforms/legalize_utils.h", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/mlir/lite:tensorflow_lite", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", @@ -105,7 +105,7 @@ cc_library( gentbl_cc_library( name = "tosa_legalize_tf_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], @@ -135,7 +135,7 @@ cc_library( "tf_passes.h", "transforms/passes.h", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), visibility = [":friends"], deps = [ ":legalize_common", @@ -158,7 +158,7 @@ cc_library( gentbl_cc_library( name = "tosa_legalize_tfl_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-rewriters"], @@ -194,7 +194,7 @@ cc_library( "tfl_passes.h", "transforms/passes.h", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), visibility = [":friends"], deps = [ ":legalize_common", @@ -228,7 +228,7 @@ cc_library( "tf_tfl_passes.h", "transforms/passes.h", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), visibility = [":friends"], deps = [ ":legalize_common", From 2d2fdf5ed602c665fa86713a7af2abac7c7c2b63 Mon Sep 17 00:00:00 2001 From: Jie Sun Date: Mon, 10 Jul 2023 16:01:00 -0700 Subject: [PATCH 092/376] remove legacy code for handling legacy trace. PiperOrigin-RevId: 547007950 --- .../core/profiler/utils/hlo_proto_map.cc | 53 +++++++------------ 1 file changed, 18 insertions(+), 35 deletions(-) diff --git a/tensorflow/core/profiler/utils/hlo_proto_map.cc b/tensorflow/core/profiler/utils/hlo_proto_map.cc index 1cca0c182ede14..a0f90aaecb0d78 100644 --- a/tensorflow/core/profiler/utils/hlo_proto_map.cc +++ b/tensorflow/core/profiler/utils/hlo_proto_map.cc @@ -53,41 +53,24 @@ ParseHloProtosFromXSpace(const XSpace& space) { const XPlane* raw_plane = FindPlaneWithName(space, kMetadataPlaneName); if (raw_plane != nullptr) { XPlaneVisitor plane = tsl::profiler::CreateTfXPlaneVisitor(raw_plane); - if (raw_plane->stats_size() > 0) { - // Fallback for legacy aggregated XPlane. - // TODO(b/235990417): Remove after 06/14/2023. - plane.ForEachStat([&](const XStatVisitor& stat) { - if (stat.ValueCase() != XStat::kBytesValue) return; - auto hlo_proto = std::make_unique(); - absl::string_view byte_value = stat.BytesValue(); - if (hlo_proto->ParseFromArray(byte_value.data(), byte_value.size())) { - hlo_protos.emplace_back(stat.Id(), std::move(hlo_proto)); - } - }); - } else { - const XStatMetadata* hlo_proto_stat_metadata = - plane.GetStatMetadataByType(StatType::kHloProto); - if (hlo_proto_stat_metadata == nullptr) { - // Fallback for legacy XPlane. - // TODO(b/235990417): Remove after 06/14/2023. - hlo_proto_stat_metadata = plane.GetStatMetadata(StatType::kHloProto); - } - if (hlo_proto_stat_metadata != nullptr) { - plane.ForEachEventMetadata( - [&](const XEventMetadataVisitor& event_metadata) { - auto hlo_proto_stat = event_metadata.GetStat( - StatType::kHloProto, *hlo_proto_stat_metadata); - if (!hlo_proto_stat) return; - if (hlo_proto_stat->ValueCase() != XStat::kBytesValue) return; - auto hlo_proto = std::make_unique(); - absl::string_view byte_value = hlo_proto_stat->BytesValue(); - if (hlo_proto->ParseFromArray(byte_value.data(), - byte_value.size())) { - hlo_protos.emplace_back(event_metadata.Id(), - std::move(hlo_proto)); - } - }); - } + + const XStatMetadata* hlo_proto_stat_metadata = + plane.GetStatMetadataByType(StatType::kHloProto); + if (hlo_proto_stat_metadata != nullptr) { + plane.ForEachEventMetadata( + [&](const XEventMetadataVisitor& event_metadata) { + auto hlo_proto_stat = event_metadata.GetStat( + StatType::kHloProto, *hlo_proto_stat_metadata); + if (!hlo_proto_stat) return; + if (hlo_proto_stat->ValueCase() != XStat::kBytesValue) return; + auto hlo_proto = std::make_unique(); + absl::string_view byte_value = hlo_proto_stat->BytesValue(); + if (hlo_proto->ParseFromArray(byte_value.data(), + byte_value.size())) { + hlo_protos.emplace_back(event_metadata.Id(), + std::move(hlo_proto)); + } + }); } } return hlo_protos; From bb255ef0e62146bcd1b032a296fd2c4a8ad835a0 Mon Sep 17 00:00:00 2001 From: Fiona Lang Date: Mon, 10 Jul 2023 16:04:54 -0700 Subject: [PATCH 093/376] Update ops.Tensor references to //third_party/tensorflow/python/framework/tensor.py. PiperOrigin-RevId: 547009218 --- tensorflow/python/distribute/BUILD | 17 +- .../python/distribute/cross_device_ops.py | 10 +- .../distribute/cross_device_ops_test.py | 3 +- .../python/distribute/distribute_lib.py | 3 +- .../python/distribute/input_lib_test.py | 3 +- .../distribute/mirrored_strategy_test.py | 5 +- .../distribute/mirrored_variable_test.py | 3 +- .../parameter_server_strategy_test.py | 3 +- tensorflow/python/distribute/ps_values.py | 14 +- .../python/distribute/sharded_variable.py | 281 +++++++++++------- .../python/distribute/summary_op_util.py | 4 +- tensorflow/python/distribute/test_util.py | 3 +- tensorflow/python/distribute/values.py | 3 +- tensorflow/python/distribute/values_test.py | 7 +- .../python/distribute/values_v2_test.py | 3 +- 15 files changed, 231 insertions(+), 131 deletions(-) diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD index 71e67c5381d9a3..e61e481e4d3814 100644 --- a/tensorflow/python/distribute/BUILD +++ b/tensorflow/python/distribute/BUILD @@ -48,6 +48,7 @@ py_strict_library( "//tensorflow/python/framework:indexed_slices", "//tensorflow/python/framework:kernels", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:math_ops", @@ -158,6 +159,7 @@ py_strict_library( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:indexed_slices", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/ops:array_ops", @@ -172,7 +174,6 @@ py_strict_library( "//tensorflow/python/trackable:base", "//tensorflow/python/types:distribute", "//tensorflow/python/util:deprecation", - "//tensorflow/python/util:lazy_loader", "//tensorflow/python/util:nest", "//tensorflow/python/util:tf_decorator", "//tensorflow/python/util:tf_export", @@ -893,7 +894,7 @@ py_strict_library( srcs_version = "PY3", deps = [ ":distribute_lib", - "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_util", ], ) @@ -928,6 +929,7 @@ py_strict_library( "//tensorflow/python/framework:composite_tensor", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_conversion_registry", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/framework:type_spec", @@ -978,6 +980,7 @@ distribute_py_strict_test( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:indexed_slices", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:resource_variable_ops", "//tensorflow/python/ops:variables", @@ -998,8 +1001,8 @@ py_strict_library( "//tensorflow/python/eager:context", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_conversion_registry", - "//tensorflow/python/framework:tensor_spec", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:handle_data_util", "//tensorflow/python/ops:lookup_ops", @@ -1323,6 +1326,7 @@ distribute_py_strict_test( "//tensorflow/python/framework:extension_type", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:control_flow_ops", @@ -1421,6 +1425,7 @@ cuda_py_strict_test( "//tensorflow/python/framework:errors", "//tensorflow/python/framework:indexed_slices", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:collective_ops", "//tensorflow/python/ops:cond", @@ -1460,6 +1465,7 @@ py_strict_library( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:indexed_slices", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_conversion_registry", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:type_spec", @@ -1598,6 +1604,7 @@ distribute_py_strict_test( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:math_ops", @@ -2001,6 +2008,7 @@ cuda_py_strict_test( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:func_graph", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/framework:test_lib", @@ -2044,6 +2052,7 @@ cuda_py_strict_test( "//tensorflow/python/framework:errors", "//tensorflow/python/framework:func_graph", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:custom_gradient", "//tensorflow/python/ops:math_ops", @@ -2238,6 +2247,7 @@ cuda_py_strict_test( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:errors", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:control_flow_ops", @@ -2452,6 +2462,7 @@ py_strict_library( "//tensorflow/python/eager:context", "//tensorflow/python/framework:config", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:array_ops_stack", "//tensorflow/python/util:nest", diff --git a/tensorflow/python/distribute/cross_device_ops.py b/tensorflow/python/distribute/cross_device_ops.py index 0e60c337a8dbfd..002cb1d41e8070 100644 --- a/tensorflow/python/distribute/cross_device_ops.py +++ b/tensorflow/python/distribute/cross_device_ops.py @@ -38,6 +38,7 @@ from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import kernels from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -59,7 +60,8 @@ def check_destinations(destinations): """ # Calling bool() on a ResourceVariable is not allowed. if isinstance(destinations, - (resource_variable_ops.BaseResourceVariable, ops.Tensor)): + (resource_variable_ops.BaseResourceVariable, + tensor_lib.Tensor)): return bool(destinations.device) return bool(destinations) @@ -68,9 +70,9 @@ def validate_destinations(destinations): """Validates the `destination` is one of expected types.""" if not isinstance( destinations, - (value_lib.DistributedValues, ops.Tensor, indexed_slices.IndexedSlices, - ps_values.AggregatingVariable, six.string_types, - tpu_values.TPUMirroredVariable + (value_lib.DistributedValues, tensor_lib.Tensor, + indexed_slices.IndexedSlices, ps_values.AggregatingVariable, + six.string_types, tpu_values.TPUMirroredVariable )) and not resource_variable_ops.is_resource_variable(destinations): raise ValueError("destinations must be one of a `DistributedValues` object," " a tf.Variable object, or a device string.") diff --git a/tensorflow/python/distribute/cross_device_ops_test.py b/tensorflow/python/distribute/cross_device_ops_test.py index dc0c5aad701ba9..dca6886ba25619 100644 --- a/tensorflow/python/distribute/cross_device_ops_test.py +++ b/tensorflow/python/distribute/cross_device_ops_test.py @@ -42,6 +42,7 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.ops import array_ops from tensorflow.python.ops import collective_ops from tensorflow.python.ops import cond @@ -204,7 +205,7 @@ def as_list(self, value): Returns: A list of `Tensor` or `IndexedSlices`. """ - if isinstance(value, ops.Tensor): + if isinstance(value, tensor_lib.Tensor): return [value] elif isinstance(value, IndexedSlices): return [value] diff --git a/tensorflow/python/distribute/distribute_lib.py b/tensorflow/python/distribute/distribute_lib.py index 66d1bb2ac4b1ff..4bd0cb86e345ed 100644 --- a/tensorflow/python/distribute/distribute_lib.py +++ b/tensorflow/python/distribute/distribute_lib.py @@ -214,6 +214,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops @@ -3926,7 +3927,7 @@ class ReplicaContextV1(ReplicaContextBase): def _batch_reduce_destination(x): """Returns the destinations for batch all-reduce.""" - if isinstance(x, ops.Tensor): + if isinstance(x, tensor_lib.Tensor): # If this is a one device strategy. return x.device else: diff --git a/tensorflow/python/distribute/input_lib_test.py b/tensorflow/python/distribute/input_lib_test.py index 38b8c95529af1a..9b999e882184ee 100644 --- a/tensorflow/python/distribute/input_lib_test.py +++ b/tensorflow/python/distribute/input_lib_test.py @@ -47,6 +47,7 @@ from tensorflow.python.framework import extension_type from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor from tensorflow.python.framework import test_util as framework_test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -1982,7 +1983,7 @@ def testDistributeDatasetFromFunctionNested(self, distribution): num_replicas_in_sync=num_workers)) class InnerType(extension_type.ExtensionType): - tensor: ops.Tensor + tensor: tensor.Tensor class OuterType(extension_type.ExtensionType): inner: InnerType diff --git a/tensorflow/python/distribute/mirrored_strategy_test.py b/tensorflow/python/distribute/mirrored_strategy_test.py index 39912b62cb4b55..584565e38b7b5d 100644 --- a/tensorflow/python/distribute/mirrored_strategy_test.py +++ b/tensorflow/python/distribute/mirrored_strategy_test.py @@ -46,6 +46,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import func_graph from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.framework import test_util as util @@ -1548,14 +1549,14 @@ def f(): def _replica_id(): replica_id = distribute_lib.get_replica_context().replica_id_in_sync_group - if not isinstance(replica_id, ops.Tensor): + if not isinstance(replica_id, tensor_lib.Tensor): replica_id = constant_op.constant(replica_id) return array_ops.identity(replica_id) def _replica_id_as_int(): replica_id = distribute_lib.get_replica_context().replica_id_in_sync_group - if isinstance(replica_id, ops.Tensor): + if isinstance(replica_id, tensor_lib.Tensor): replica_id = tensor_util.constant_value(replica_id) return replica_id diff --git a/tensorflow/python/distribute/mirrored_variable_test.py b/tensorflow/python/distribute/mirrored_variable_test.py index 169ba5211397c0..89bae84219418c 100644 --- a/tensorflow/python/distribute/mirrored_variable_test.py +++ b/tensorflow/python/distribute/mirrored_variable_test.py @@ -31,6 +31,7 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import func_graph from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import custom_gradient from tensorflow.python.ops import math_ops @@ -46,7 +47,7 @@ def _replica_id(): replica_id = distribute_lib.get_replica_context().replica_id_in_sync_group - if not isinstance(replica_id, ops.Tensor): + if not isinstance(replica_id, tensor.Tensor): replica_id = constant_op.constant(replica_id) return replica_id diff --git a/tensorflow/python/distribute/parameter_server_strategy_test.py b/tensorflow/python/distribute/parameter_server_strategy_test.py index d1818933264f52..3ce0aaa3e2d279 100644 --- a/tensorflow/python/distribute/parameter_server_strategy_test.py +++ b/tensorflow/python/distribute/parameter_server_strategy_test.py @@ -41,6 +41,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -60,7 +61,7 @@ def _get_replica_id_integer(): replica_id = distribute_lib.get_replica_context().replica_id_in_sync_group - if isinstance(replica_id, ops.Tensor): + if isinstance(replica_id, tensor.Tensor): replica_id = tensor_util.constant_value(replica_id) return replica_id diff --git a/tensorflow/python/distribute/ps_values.py b/tensorflow/python/distribute/ps_values.py index 3a866516b2faa1..73b49c8937fc3e 100644 --- a/tensorflow/python/distribute/ps_values.py +++ b/tensorflow/python/distribute/ps_values.py @@ -30,8 +30,8 @@ from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_conversion_registry -from tensorflow.python.framework import tensor_spec from tensorflow.python.ops import array_ops from tensorflow.python.ops import handle_data_util from tensorflow.python.ops import lookup_ops @@ -494,7 +494,7 @@ def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False): @classmethod def _overload_overloadable_operators(cls): """Register overloads for all operators.""" - for operator in ops.Tensor.OVERLOADABLE_OPERATORS: + for operator in tensor.Tensor.OVERLOADABLE_OPERATORS: # Overloading __eq__ or __ne__ does not work as expected. if operator == "__eq__" or operator == "__ne__": continue @@ -502,8 +502,8 @@ def _overload_overloadable_operators(cls): @classmethod def _tensor_overload_operator(cls, operator): - """Delegate an operator overload to `ops.Tensor`.""" - tensor_operator = getattr(ops.Tensor, operator) + """Delegate an operator overload to `tensor.Tensor`.""" + tensor_operator = getattr(tensor.Tensor, operator) def _operator(v, *args, **kwargs): return tensor_operator(v.value(), *args, **kwargs) # pylint: disable=protected-access @@ -655,7 +655,7 @@ def read_all(self): return [wv.get() for wv in self._per_worker_vars._values] # pylint: disable=protected-access -class PerWorkerVariableSpec(tensor_spec.TensorSpec): +class PerWorkerVariableSpec(tensor.TensorSpec): def __init__(self, value=None, name=None): super().__init__(value.shape, value.dtype, name=name) self._value = value @@ -745,7 +745,7 @@ def closure(): else: return self._coordinator_instance.resource_handle - return closure, tensor_spec.TensorSpec([], dtype=dtypes.resource) + return closure, tensor.TensorSpec([], dtype=dtypes.resource) def _maybe_build_distributed_table(self): """Create table objects and resources on each worker if hasn't been created.""" @@ -871,7 +871,7 @@ def closure(): return self._coordinator_instance.resource_handle - return closure, tensor_spec.TensorSpec(shape=(), dtype=dtypes.resource) + return closure, tensor.TensorSpec(shape=(), dtype=dtypes.resource) def __setattr__(self, name, value): if name in TRACKABLE_RESOURCE_METHODS: diff --git a/tensorflow/python/distribute/sharded_variable.py b/tensorflow/python/distribute/sharded_variable.py index f65677920f4579..68c05c4d867fb3 100644 --- a/tensorflow/python/distribute/sharded_variable.py +++ b/tensorflow/python/distribute/sharded_variable.py @@ -25,6 +25,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import indexed_slices as indexed_slices_lib from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_conversion_registry from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import type_spec @@ -93,7 +94,6 @@ class FixedShardsPartitioner(Partitioner): >>> # use in ParameterServerStrategy >>> # strategy = tf.distribute.experimental.ParameterServerStrategy( >>> # cluster_resolver=cluster_resolver, variable_partitioner=partitioner) - """ def __init__(self, num_shards): @@ -134,10 +134,9 @@ class MinSizePartitioner(Partitioner): >>> # cluster_resolver=cluster_resolver, variable_partitioner=partitioner) """ - def __init__(self, - min_shard_bytes=256 << 10, - max_shards=1, - bytes_per_string=16): + def __init__( + self, min_shard_bytes=256 << 10, max_shards=1, bytes_per_string=16 + ): """Creates a new `MinSizePartitioner`. Args: @@ -147,14 +146,19 @@ def __init__(self, an estimate of how large each string is. """ if min_shard_bytes < 1: - raise ValueError('Argument `min_shard_bytes` must be positive. ' - f'Received: {min_shard_bytes}') + raise ValueError( + 'Argument `min_shard_bytes` must be positive. ' + f'Received: {min_shard_bytes}' + ) if max_shards < 1: - raise ValueError('Argument `max_shards` must be positive. ' - f'Received: {max_shards}') + raise ValueError( + f'Argument `max_shards` must be positive. Received: {max_shards}' + ) if bytes_per_string < 1: - raise ValueError('Argument `bytes_per_string` must be positive. ' - f'Received: {bytes_per_string}') + raise ValueError( + 'Argument `bytes_per_string` must be positive. ' + f'Received: {bytes_per_string}' + ) self._min_shard_bytes = min_shard_bytes self._max_shards = max_shards self._bytes_per_string = bytes_per_string @@ -164,7 +168,8 @@ def __call__(self, shape, dtype, axis=0): max_partitions=self._max_shards, axis=axis, min_slice_size=self._min_shard_bytes, - bytes_per_string_element=self._bytes_per_string)(shape, dtype) + bytes_per_string_element=self._bytes_per_string, + )(shape, dtype) @tf_export('distribute.experimental.partitioners.MaxSizePartitioner', v1=[]) @@ -207,14 +212,19 @@ def __init__(self, max_shard_bytes, max_shards=None, bytes_per_string=16): an estimate of how large each string is. """ if max_shard_bytes < 1: - raise ValueError('Argument `max_shard_bytes` must be positive. ' - f'Received {max_shard_bytes}') + raise ValueError( + 'Argument `max_shard_bytes` must be positive. ' + f'Received {max_shard_bytes}' + ) if max_shards and max_shards < 1: - raise ValueError('Argument `max_shards` must be positive. ' - f'Received {max_shards}') + raise ValueError( + f'Argument `max_shards` must be positive. Received {max_shards}' + ) if bytes_per_string < 1: - raise ValueError('Argument `bytes_per_string` must be positive. ' - f'Received: {bytes_per_string}') + raise ValueError( + 'Argument `bytes_per_string` must be positive. ' + f'Received: {bytes_per_string}' + ) self._max_shard_bytes = max_shard_bytes self._max_shards = max_shards @@ -225,7 +235,8 @@ def __call__(self, shape, dtype, axis=0): max_shard_bytes=self._max_shard_bytes, max_shards=self._max_shards, bytes_per_string_element=self._bytes_per_string, - axis=axis)(shape, dtype) + axis=axis, + )(shape, dtype) class ShardedVariableSpec(type_spec.TypeSpec): @@ -264,7 +275,6 @@ class ShardedVariableMixin(trackable.Trackable): def __init__(self, variables, name='ShardedVariable'): """Treats `variables` as shards of a larger Variable. - Example: ``` @@ -287,16 +297,22 @@ def __init__(self, variables, name='ShardedVariable'): self._variables = variables self._name = name - if not isinstance(variables, Sequence) or not variables or any( - not isinstance(v, variables_lib.Variable) for v in variables): - raise TypeError('Argument `variables` should be a non-empty list of ' - f'`variables.Variable`s. Received {variables}') + if ( + not isinstance(variables, Sequence) + or not variables + or any(not isinstance(v, variables_lib.Variable) for v in variables) + ): + raise TypeError( + 'Argument `variables` should be a non-empty list of ' + f'`variables.Variable`s. Received {variables}' + ) var_dtypes = {v.dtype for v in variables} if len(var_dtypes) > 1: raise ValueError( 'All elements in argument `variables` must have the same dtype. ' - f'Received dtypes: {[v.dtype for v in variables]}') + f'Received dtypes: {[v.dtype for v in variables]}' + ) first_var = variables[0] self._dtype = first_var.dtype @@ -307,10 +323,12 @@ def __init__(self, variables, name='ShardedVariable'): raise ValueError( 'All elements in argument `variables` must have the same shapes ' 'except for the first axis. ' - f'Received shapes: {[v.shape for v in variables]}') + f'Received shapes: {[v.shape for v in variables]}' + ) first_dim = sum(int(v.shape.as_list()[0]) for v in variables) - self._shape = tensor_shape.TensorShape([first_dim] + - first_var.shape.as_list()[1:]) + self._shape = tensor_shape.TensorShape( + [first_dim] + first_var.shape.as_list()[1:] + ) for v in variables: v._sharded_container = weakref.ref(self) @@ -321,7 +339,8 @@ def __init__(self, variables, name='ShardedVariable'): for i in range(1, len(variables)): # Always partition on the first axis. Offsets on other axes are 0. self._var_offsets[i][0] += ( - self._var_offsets[i - 1][0] + variables[i - 1].shape.as_list()[0]) + self._var_offsets[i - 1][0] + variables[i - 1].shape.as_list()[0] + ) save_slice_info = [v._get_save_slice_info() for v in variables] # pylint: disable=protected-access if any(slice_info is not None for slice_info in save_slice_info): @@ -329,16 +348,20 @@ def __init__(self, variables, name='ShardedVariable'): '`SaveSliceInfo` should not be set for all elements in argument ' '`variables`. `ShardedVariable` will infer `SaveSliceInfo` according ' 'to the order of the elements `variables`. ' - f'Received save slice info {save_slice_info}') + f'Received save slice info {save_slice_info}' + ) # We create an uninitialized saving_variable with the full shape, which can # be later captured in signatures so that the signatures can treat this # ShardedVariable as one single variable. self._saving_variable = resource_variable_ops.UninitializedVariable( - shape=self._shape, dtype=self._dtype, name=self._name, + shape=self._shape, + dtype=self._dtype, + name=self._name, trainable=self._variables[0].trainable, synchronization=variables_lib.VariableSynchronization.NONE, - aggregation=variables_lib.VariableAggregation.NONE) + aggregation=variables_lib.VariableAggregation.NONE, + ) def __iter__(self): """Return an iterable for accessing the underlying sharded variables.""" @@ -365,9 +388,14 @@ def __getitem__(self, slice_spec): # TODO(b/177482728): Support tensor input. # TODO(b/177482728): Support slice assign, similar to variable slice assign. - if (isinstance(slice_spec, bool) or (isinstance(slice_spec, ops.Tensor) and - slice_spec.dtype == dtypes.bool) or - (isinstance(slice_spec, np.ndarray) and slice_spec.dtype == bool)): + if ( + isinstance(slice_spec, bool) + or ( + isinstance(slice_spec, tensor_lib.Tensor) + and slice_spec.dtype == dtypes.bool + ) + or (isinstance(slice_spec, np.ndarray) and slice_spec.dtype == bool) + ): tensor = _var_to_tensor(self) return array_ops.boolean_mask(tensor=tensor, mask=slice_spec) @@ -385,30 +413,36 @@ def __getitem__(self, slice_spec): if s.step is not None and s.step < 0: values.reverse() if not values: - return constant_op.constant([], - dtype=self._dtype, - shape=((0,) + self._shape[1:])) + return constant_op.constant( + [], dtype=self._dtype, shape=((0,) + self._shape[1:]) + ) return array_ops.concat(values, axis=0) elif s is Ellipsis: - return array_ops.concat([var[slice_spec] for var in self._variables], - axis=0) + return array_ops.concat( + [var[slice_spec] for var in self._variables], axis=0 + ) elif s is array_ops.newaxis: - return array_ops.concat([var[slice_spec[1:]] for var in self._variables], - axis=0)[array_ops.newaxis] + return array_ops.concat( + [var[slice_spec[1:]] for var in self._variables], axis=0 + )[array_ops.newaxis] else: - if isinstance(s, ops.Tensor): + if isinstance(s, tensor_lib.Tensor): raise TypeError( - 'ShardedVariable: using Tensor for indexing is not allowed.') + 'ShardedVariable: using Tensor for indexing is not allowed.' + ) if s < 0: s += self._shape[0] if s < 0 or s >= self._shape[0]: raise IndexError( - f'ShardedVariable: slice index {s} of dimension 0 out of bounds.') + f'ShardedVariable: slice index {s} of dimension 0 out of bounds.' + ) for i in range(len(self._variables)): - if i == len(self._variables) - 1 or (s > self._var_offsets[i][0] and - s < self._var_offsets[i + 1][0]): - return self._variables[i][(s - self._var_offsets[i][0],) + - slice_spec[1:]] + if i == len(self._variables) - 1 or ( + s > self._var_offsets[i][0] and s < self._var_offsets[i + 1][0] + ): + return self._variables[i][ + (s - self._var_offsets[i][0],) + slice_spec[1:] + ] def _decompose_slice_spec(self, slice_spec): """Decompose a global slice_spec into a list of per-variable slice_spec. @@ -441,11 +475,15 @@ def _decompose_slice_spec(self, slice_spec): v1[returned[1]] = [5] v2[returned[2]] = [9, 7] """ - if isinstance(slice_spec.start, ops.Tensor) or isinstance( - slice_spec.stop, ops.Tensor) or isinstance(slice_spec.step, ops.Tensor): + if ( + isinstance(slice_spec.start, tensor_lib.Tensor) + or isinstance(slice_spec.stop, tensor_lib.Tensor) + or isinstance(slice_spec.step, tensor_lib.Tensor) + ): raise TypeError( 'ShardedVariable: using Tensor in slice_spec is not allowed. Please ' - 'file a feature request with the TensorFlow team.') + 'file a feature request with the TensorFlow team.' + ) result = [] # Normalize start, end and stop. @@ -479,7 +517,9 @@ def _decompose_slice_spec(self, slice_spec): var_start = self._var_offsets[i][0] var_end = ( self._var_offsets[i + 1][0] - if i < len(self._var_offsets) - 1 else self._shape[0]) + if i < len(self._var_offsets) - 1 + else self._shape[0] + ) if cur < var_start: cur += slice_step * int(math.ceil((var_start - cur) / slice_step)) if cur >= var_end or cur >= slice_end: @@ -493,7 +533,9 @@ def _decompose_slice_spec(self, slice_spec): var_start = self._var_offsets[i][0] var_end = ( self._var_offsets[i + 1][0] - if i < len(self._var_offsets) - 1 else self._shape[0]) + if i < len(self._var_offsets) - 1 + else self._shape[0] + ) if cur >= var_end: cur += slice_step * int(math.ceil((var_end - cur - 1) / slice_step)) if cur < var_start or cur <= slice_end: @@ -513,8 +555,11 @@ def _decompose_slice_spec(self, slice_spec): @property def _type_spec(self): return ShardedVariableSpec( - *(resource_variable_ops.VariableSpec(v.shape, v.dtype) - for v in self._variables)) + *( + resource_variable_ops.VariableSpec(v.shape, v.dtype) + for v in self._variables + ) + ) @property def variables(self): @@ -546,13 +591,15 @@ def assign(self, value, use_locking=None, name=None, read_value=True): def assign_add(self, delta, use_locking=False, name=None, read_value=True): for i, v in enumerate(self._variables): v.assign_add( - array_ops.slice(delta, self._var_offsets[i], v.shape.as_list())) + array_ops.slice(delta, self._var_offsets[i], v.shape.as_list()) + ) return self def assign_sub(self, delta, use_locking=False, name=None, read_value=True): for i, v in enumerate(self._variables): v.assign_sub( - array_ops.slice(delta, self._var_offsets[i], v.shape.as_list())) + array_ops.slice(delta, self._var_offsets[i], v.shape.as_list()) + ) return self def _decompose_indices(self, indices): @@ -560,7 +607,8 @@ def _decompose_indices(self, indices): if indices.shape.rank != 1: raise ValueError( 'ShardedVariable: indices must be 1D Tensor for sparse operations. ' - f'Received shape: {indices.shape}') + f'Received shape: {indices.shape}' + ) base = self._shape[0] // len(self._variables) extra = self._shape[0] % len(self._variables) @@ -573,7 +621,8 @@ def _decompose_indices(self, indices): if expect_first_dim != actual_first_dim: raise NotImplementedError( 'scater_xxx ops are not supported in ShardedVariale that does not ' - 'conform to "div" sharding') + 'conform to "div" sharding' + ) # For index that falls into the partition that has extra 1, assignment is # `index // (base + 1)` (no less than `(indices - extra) // base`) @@ -585,30 +634,35 @@ def _decompose_indices(self, indices): # base = 10, extra = 2, partitions: [0, 11), [11, 22), [22, 32) # index = 10 -> partition_assigment = 0 # index = 22 -> partition_assiment = 2 - partition_assignments = math_ops.maximum(indices // (base + 1), - (indices - extra) // base) - local_indices = array_ops.where(partition_assignments < extra, - indices % (base + 1), - (indices - extra) % base) + partition_assignments = math_ops.maximum( + indices // (base + 1), (indices - extra) // base + ) + local_indices = array_ops.where( + partition_assignments < extra, + indices % (base + 1), + (indices - extra) % base, + ) # For whatever reason `dynamic_partition` only supports int32 partition_assignments = math_ops.cast(partition_assignments, dtypes.int32) - per_var_indices = data_flow_ops.dynamic_partition(local_indices, - partition_assignments, - len(self._variables)) + per_var_indices = data_flow_ops.dynamic_partition( + local_indices, partition_assignments, len(self._variables) + ) return per_var_indices, partition_assignments def _decompose_indexed_slices(self, indexed_slices): """Decompose a global `IndexedSlices` into a list of per-variable ones.""" per_var_indices, partition_assignments = self._decompose_indices( - indexed_slices.indices) - per_var_values = data_flow_ops.dynamic_partition(indexed_slices.values, - partition_assignments, - len(self._variables)) + indexed_slices.indices + ) + per_var_values = data_flow_ops.dynamic_partition( + indexed_slices.values, partition_assignments, len(self._variables) + ) return [ indexed_slices_lib.IndexedSlices( - values=per_var_values[i], indices=per_var_indices[i]) + values=per_var_values[i], indices=per_var_indices[i] + ) for i in range(len(self._variables)) ] @@ -720,24 +774,32 @@ def _saveable_factory(name=self.name): full_name=self.name, full_shape=self.shape.as_list(), var_offset=copy.copy(var_offset), - var_shape=v.shape.as_list()) + var_shape=v.shape.as_list(), + ) saveables.append( saveable_object_util.ResourceVariableSaveable( - v, save_slice_info.spec, name)) + v, save_slice_info.spec, name + ) + ) var_offset[0] += int(v.shape[0]) return saveables return {trackable.VARIABLE_VALUE_KEY: _saveable_factory} - def _export_to_saved_model_graph(self, object_map, tensor_map, - options, **kwargs): + def _export_to_saved_model_graph( + self, object_map, tensor_map, options, **kwargs + ): """For implementing `Trackable`.""" resource_list = [] for v in self._variables + [self._saving_variable]: - resource_list.extend(v._export_to_saved_model_graph( # pylint:disable=protected-access - object_map, tensor_map, options, **kwargs)) - object_map[self] = ShardedVariable([object_map[self._saving_variable]], - name=self.name) + resource_list.extend( + v._export_to_saved_model_graph( # pylint:disable=protected-access + object_map, tensor_map, options, **kwargs + ) + ) + object_map[self] = ShardedVariable( + [object_map[self._saving_variable]], name=self.name + ) return resource_list @property @@ -828,7 +890,7 @@ def _type_spec(self): @classmethod def _overload_all_operators(cls): """Register overloads for all operators.""" - for operator in ops.Tensor.OVERLOADABLE_OPERATORS: + for operator in tensor_lib.Tensor.OVERLOADABLE_OPERATORS: if operator == '__getitem__': continue @@ -836,16 +898,17 @@ def _overload_all_operators(cls): @classmethod def _overload_operator(cls, operator): - """Delegate an operator overload to `ops.Tensor`.""" - tensor_operator = getattr(ops.Tensor, operator) + """Delegate an operator overload to `tensor_lib.Tensor`.""" + tensor_operator = getattr(tensor_lib.Tensor, operator) def _operator(v, *args, **kwargs): return tensor_operator(_var_to_tensor(v), *args, **kwargs) setattr(cls, operator, _operator) - def __tf_experimental_restore_capture__(self, concrete_function, - internal_capture): + def __tf_experimental_restore_capture__( + self, concrete_function, internal_capture + ): # Avoid restoring captures for functions that use ShardedVariable - the # layer will be recreated during Keras model loading # TODO(jmullenbach): support loading models with ShardedVariables using @@ -858,7 +921,8 @@ def _should_act_as_resource_variable(self): def _write_object_proto(self, proto, options): resource_variable_ops.write_object_proto_for_resource_variable( - self._saving_variable, proto, options, enforce_naming=False) + self._saving_variable, proto, options, enforce_naming=False + ) def _var_to_tensor(var, dtype=None, name=None, as_ref=False): @@ -867,10 +931,12 @@ def _var_to_tensor(var, dtype=None, name=None, as_ref=False): if dtype is not None and not dtype.is_compatible_with(var.dtype): raise ValueError( 'Incompatible type conversion requested to type {!r} for variable ' - 'of type {!r}'.format(dtype.name, var.dtype.name)) + 'of type {!r}'.format(dtype.name, var.dtype.name) + ) if as_ref: raise NotImplementedError( - "ShardedVariable doesn't support being used as a reference.") + "ShardedVariable doesn't support being used as a reference." + ) # We use op dispatch mechanism to override embedding_lookup ops when called # with ShardedVariable. This requires embedding_lookup ops to raise TypeError # when called with ShardedVariable. However since ShardedVariable can be @@ -885,32 +951,42 @@ def _var_to_tensor(var, dtype=None, name=None, as_ref=False): # TODO(chenkai): Find a more robust way to do this, which should not rely # on namescope. if 'embedding_lookup' in ops.get_name_scope(): - raise TypeError('Converting ShardedVariable to tensor in embedding lookup' - ' ops is disallowed.') + raise TypeError( + 'Converting ShardedVariable to tensor in embedding lookup' + ' ops is disallowed.' + ) return array_ops.concat(var.variables, axis=0) # Register a conversion function which reads the value of the variable, # allowing instances of the class to be used as tensors. tensor_conversion_registry.register_tensor_conversion_function( - ShardedVariable, _var_to_tensor) + ShardedVariable, _var_to_tensor +) ShardedVariable._overload_all_operators() # pylint: disable=protected-access # Override the behavior of embedding_lookup(sharded_variable, ...) @dispatch.dispatch_for_types(embedding_ops.embedding_lookup, ShardedVariable) -def embedding_lookup(params, - ids, - partition_strategy='mod', - name=None, - validate_indices=True, - max_norm=None): +def embedding_lookup( + params, + ids, + partition_strategy='mod', + name=None, + validate_indices=True, + max_norm=None, +): if isinstance(params, list): params = params[0] - return embedding_ops.embedding_lookup(params.variables, ids, - partition_strategy, name, - validate_indices, max_norm) + return embedding_ops.embedding_lookup( + params.variables, + ids, + partition_strategy, + name, + validate_indices, + max_norm, + ) # Separately override safe_embedding_lookup_sparse, to avoid conversion of @@ -937,4 +1013,5 @@ def safe_embedding_lookup_sparse( name=name, partition_strategy=partition_strategy, max_norm=max_norm, - allow_fast_lookup=allow_fast_lookup) + allow_fast_lookup=allow_fast_lookup, + ) diff --git a/tensorflow/python/distribute/summary_op_util.py b/tensorflow/python/distribute/summary_op_util.py index 59e619a871e388..7ccb6a181bd206 100644 --- a/tensorflow/python/distribute/summary_op_util.py +++ b/tensorflow/python/distribute/summary_op_util.py @@ -16,7 +16,7 @@ from tensorflow.python.distribute import distribute_lib -from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_util @@ -39,6 +39,6 @@ def skip_summary(): # TODO(b/118385803): when replica_id of _TPUReplicaContext is properly # initialized, remember to change here as well. replica_id = replica_context.replica_id_in_sync_group - if isinstance(replica_id, ops.Tensor): + if isinstance(replica_id, tensor.Tensor): replica_id = tensor_util.constant_value(replica_id) return replica_id and replica_id > 0 diff --git a/tensorflow/python/distribute/test_util.py b/tensorflow/python/distribute/test_util.py index 866d86a4dd3f59..8dbf6c020c1242 100644 --- a/tensorflow/python/distribute/test_util.py +++ b/tensorflow/python/distribute/test_util.py @@ -32,6 +32,7 @@ from tensorflow.python.eager import context from tensorflow.python.framework import config from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops_stack from tensorflow.python.util import nest @@ -142,7 +143,7 @@ def _op_dependencies(op): """Returns the data and control dependencies of a tf.Operation combined.""" deps = [] for node in itertools.chain(op.inputs, op.control_inputs): - if isinstance(node, ops.Tensor): + if isinstance(node, tensor.Tensor): node = node.op assert isinstance(node, ops.Operation) deps.append(node) diff --git a/tensorflow/python/distribute/values.py b/tensorflow/python/distribute/values.py index 5bf5e4ec52aca2..fa7456b19d4ea5 100644 --- a/tensorflow/python/distribute/values.py +++ b/tensorflow/python/distribute/values.py @@ -29,6 +29,7 @@ from tensorflow.python.framework import composite_tensor from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_conversion_registry from tensorflow.python.framework import tensor_util from tensorflow.python.framework import type_spec @@ -1040,7 +1041,7 @@ def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False): def __tf_tensor__(self, dtype: Optional[dtypes.DType] = None, - name: Optional[str] = None) -> ops.Tensor: + name: Optional[str] = None) -> tensor_lib.Tensor: return self._dense_var_to_tensor(dtype, name) def _export_to_saved_model_graph(self, diff --git a/tensorflow/python/distribute/values_test.py b/tensorflow/python/distribute/values_test.py index dd0add62d863a8..70cd0fb6a608a5 100644 --- a/tensorflow/python/distribute/values_test.py +++ b/tensorflow/python/distribute/values_test.py @@ -36,6 +36,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -498,12 +499,12 @@ def testTensorConversion(self, distribution): _, replica_local = _make_replica_local( variable_scope.VariableAggregation.SUM, distribution) converted = ops.convert_to_tensor(replica_local, as_ref=False) - self.assertIsInstance(converted, ops.Tensor) + self.assertIsInstance(converted, tensor.Tensor) self.assertEqual(converted.dtype, replica_local.dtype) converted = ops.convert_to_tensor(replica_local, as_ref=True) # Resources variable are converted to tensors as well when as_ref is True. - self.assertIsInstance(converted, ops.Tensor) + self.assertIsInstance(converted, tensor.Tensor) self.assertEqual(converted.dtype, replica_local.dtype) @combinations.generate(combinations.combine( @@ -517,7 +518,7 @@ def testValueInCrossReplicaContext(self, distribution): value_list, replica_local = _make_replica_local( variable_scope.VariableAggregation.ONLY_FIRST_REPLICA, distribution) - self.assertIsInstance(replica_local.value(), ops.Tensor) + self.assertIsInstance(replica_local.value(), tensor.Tensor) self.assertEqual(self.evaluate(replica_local.value()), self.evaluate(value_list[0].value())) diff --git a/tensorflow/python/distribute/values_v2_test.py b/tensorflow/python/distribute/values_v2_test.py index e7dcd958fe3240..0e4c3298654b5e 100644 --- a/tensorflow/python/distribute/values_v2_test.py +++ b/tensorflow/python/distribute/values_v2_test.py @@ -25,6 +25,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variables as variables_lib @@ -329,7 +330,7 @@ def testSlice(self): # ==== Begin ResourceVariable interface === def testHandle(self): v = self.create_variable() - self.assertIsInstance(v.handle, ops.Tensor) + self.assertIsInstance(v.handle, tensor.Tensor) self.assertEqual(v.handle.dtype, dtypes.resource) def testInGraphMode(self): From eaaee50ea226fcda85b70f357f860064247983d1 Mon Sep 17 00:00:00 2001 From: Scott Zhu Date: Mon, 10 Jul 2023 16:08:40 -0700 Subject: [PATCH 094/376] Change the logging level for the warning message of tf.data.AutoShardPolicy.AUTO. The original message is not actionable, and is also on by default, which lead to user confusion. eg https://github.com/keras-team/keras/pull/16604. Change this to VLOG(2) to reduce the noise level. PiperOrigin-RevId: 547010445 --- tensorflow/core/grappler/optimizers/data/auto_shard.cc | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/data/auto_shard.cc b/tensorflow/core/grappler/optimizers/data/auto_shard.cc index 9c757fa333a53a..bcea37d586529f 100644 --- a/tensorflow/core/grappler/optimizers/data/auto_shard.cc +++ b/tensorflow/core/grappler/optimizers/data/auto_shard.cc @@ -784,10 +784,12 @@ Status ApplyAutoShard(const NodeDef& sink_node, int64_t num_workers, default: Status s = ShardByFile(sink_node, num_workers, index, &flib, graph); if (absl::IsNotFound(s)) { - LOG(WARNING) << "AUTO sharding policy will apply DATA sharding policy " - "as it failed to apply FILE sharding policy because of " - "the following reason: " - << s.message(); + if (VLOG_IS_ON(2)) { + VLOG(2) << "AUTO sharding policy will apply DATA sharding policy " + "as it failed to apply FILE sharding policy because of " + "the following reason: " + << s.message(); + } *policy_applied = AutoShardPolicy::DATA; return ShardByData(sink_node, num_workers, index, num_replicas, graph); } From 2b942bfcef15799c99a777862aa848f77dfeaf95 Mon Sep 17 00:00:00 2001 From: "ag.ramesh" Date: Mon, 10 Jul 2023 16:24:02 -0700 Subject: [PATCH 095/376] Refactoring to move oneDNN threadpool wrapper to tsl and clean up of the code for clarity, and renaming mkl to oneDNN. --- .../core/kernels/mkl/mkl_avgpooling_op.cc | 16 +++- .../core/kernels/mkl/mkl_batch_matmul_op.cc | 11 ++- tensorflow/core/kernels/mkl/mkl_concat_op.cc | 9 +- .../kernels/mkl/mkl_conv_grad_filter_ops.cc | 9 +- .../kernels/mkl/mkl_conv_grad_input_ops.cc | 9 +- tensorflow/core/kernels/mkl/mkl_conv_ops.cc | 9 +- .../core/kernels/mkl/mkl_dequantize_op.cc | 9 +- tensorflow/core/kernels/mkl/mkl_einsum_op.cc | 9 +- .../kernels/mkl/mkl_fused_batch_norm_op.cc | 18 +++- .../kernels/mkl/mkl_fused_instance_norm_op.cc | 11 ++- .../core/kernels/mkl/mkl_layer_norm_op.cc | 11 ++- tensorflow/core/kernels/mkl/mkl_matmul_op.cc | 9 +- .../core/kernels/mkl/mkl_matmul_op_fused.cc | 9 +- .../core/kernels/mkl/mkl_matmul_ops_common.h | 11 ++- .../core/kernels/mkl/mkl_maxpooling_op.cc | 16 +++- tensorflow/core/kernels/mkl/mkl_qmatmul_op.cc | 10 ++- .../core/kernels/mkl/mkl_quantize_op.cc | 9 +- tensorflow/core/kernels/mkl/mkl_relu_op.cc | 18 +++- ...mkl_requantization_range_per_channel_op.cc | 3 +- .../mkl/mkl_requantize_per_channel_op.cc | 11 ++- tensorflow/core/kernels/mkl/mkl_softmax_op.cc | 9 +- .../core/kernels/mkl/mkl_transpose_op.cc | 9 +- tensorflow/core/util/BUILD | 3 +- tensorflow/core/util/mkl_util.h | 42 ++++++--- tensorflow/tsl/util/BUILD | 12 +++ .../util/onednn_threadpool.h} | 89 ++++++++++--------- 26 files changed, 291 insertions(+), 90 deletions(-) rename tensorflow/{core/util/mkl_threadpool.h => tsl/util/onednn_threadpool.h} (72%) diff --git a/tensorflow/core/kernels/mkl/mkl_avgpooling_op.cc b/tensorflow/core/kernels/mkl/mkl_avgpooling_op.cc index b1c7081e3dbe8a..42c5ed61bd888d 100644 --- a/tensorflow/core/kernels/mkl/mkl_avgpooling_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_avgpooling_op.cc @@ -123,7 +123,14 @@ class MklAvgPoolingOp : public MklPoolingForwardOpBase { dnnl::algorithm::pooling_avg_exclude_padding, pooling_prop_kind, static_cast(this->data_format_mkldnn_), input_md, this->native_format_); - MklDnnThreadPool eigen_tp(context); + // Create the oneDNN wrapper over eigen threapool and set max threads + // in oneDNN. + Eigen::ThreadPoolInterface* eigen_interface = + context->device() + ->tensorflow_cpu_worker_threads() + ->workers->AsEigenThreadPool(); + tsl::OneDnnThreadPool eigen_tp(eigen_interface, + ThreadPoolUseCallerThread()); pooling_fwd = MklPoolingFwdPrimitiveFactory::Get(fwdParams); // Allocate output tensor. @@ -340,7 +347,12 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase { prop_kind::forward_training, static_cast(this->data_format_mkldnn_), src_md, this->native_format_); - MklDnnThreadPool eigen_tp(context); + Eigen::ThreadPoolInterface* eigen_interface = + context->device() + ->tensorflow_cpu_worker_threads() + ->workers->AsEigenThreadPool(); + tsl::OneDnnThreadPool eigen_tp(eigen_interface, + ThreadPoolUseCallerThread()); MklPoolingBwdPrimitive* pooling_bwd = MklPoolingBwdPrimitiveFactory::Get(bwdParams); diff --git a/tensorflow/core/kernels/mkl/mkl_batch_matmul_op.cc b/tensorflow/core/kernels/mkl/mkl_batch_matmul_op.cc index 3562e2c83da56e..62e1de035821e7 100644 --- a/tensorflow/core/kernels/mkl/mkl_batch_matmul_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_batch_matmul_op.cc @@ -25,7 +25,6 @@ limitations under the License. #if defined(INTEL_MKL) -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -37,6 +36,7 @@ limitations under the License. #include "tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/matmul_bcast.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" namespace tensorflow { @@ -160,7 +160,14 @@ class BatchMatMulMkl : public OpKernel { out_shape, adj_x_, adj_y_); this->ExtendMklMatMulParams(ctx, *params); - MklDnnThreadPool eigen_tp(ctx); + // Create the oneDNN wrapper over eigen threapool and set max threads + // in oneDNN. + Eigen::ThreadPoolInterface* eigen_interface = + ctx->device() + ->tensorflow_cpu_worker_threads() + ->workers->AsEigenThreadPool(); + tsl::OneDnnThreadPool eigen_tp(eigen_interface, + ThreadPoolUseCallerThread()); // Create or retrieve matmul primitive from cache. MklMatMulPrimitive* matmul_prim = MklMatMulPrimitiveFactory::Get( diff --git a/tensorflow/core/kernels/mkl/mkl_concat_op.cc b/tensorflow/core/kernels/mkl/mkl_concat_op.cc index 07e758a5e77d61..319f21201d470f 100644 --- a/tensorflow/core/kernels/mkl/mkl_concat_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_concat_op.cc @@ -760,7 +760,14 @@ class MklConcatOp : public OpKernel { // then since MklDnn order is NCHW, concat_dim needs to be 1. if (are_all_mkl_inputs) concat_dim = mkl_input_shapes[0].TfDimIdx(concat_dim); - MklDnnThreadPool eigen_tp(context); + // Create the oneDNN wrapper over eigen threapool and set max threads + // in oneDNN. + Eigen::ThreadPoolInterface* eigen_interface = + context->device() + ->tensorflow_cpu_worker_threads() + ->workers->AsEigenThreadPool(); + tsl::OneDnnThreadPool eigen_tp(eigen_interface, + ThreadPoolUseCallerThread()); if (!inputs.empty()) { if (are_all_mkl_inputs) { auto concat_pd = diff --git a/tensorflow/core/kernels/mkl/mkl_conv_grad_filter_ops.cc b/tensorflow/core/kernels/mkl/mkl_conv_grad_filter_ops.cc index 0b5519fc46b06f..7728c9aaf53cee 100644 --- a/tensorflow/core/kernels/mkl/mkl_conv_grad_filter_ops.cc +++ b/tensorflow/core/kernels/mkl/mkl_conv_grad_filter_ops.cc @@ -517,7 +517,14 @@ class MklConvCustomBackpropFilterOp // variable TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE is set to true. bool do_not_cache = MklPrimitiveFactory::IsPrimitiveMemOptEnabled(); - MklDnnThreadPool eigen_tp(context); + // Create the oneDNN wrapper over eigen threadpool and set max threads + // in oneDNN. + Eigen::ThreadPoolInterface* eigen_interface = + context->device() + ->tensorflow_cpu_worker_threads() + ->workers->AsEigenThreadPool(); + tsl::OneDnnThreadPool eigen_tp(eigen_interface, + ThreadPoolUseCallerThread()); MklConvBwdFilterPrimitive* conv_bwd_filter = MklConvBwdFilterPrimitiveFactory::Get(convBwdFilterDims, do_not_cache); diff --git a/tensorflow/core/kernels/mkl/mkl_conv_grad_input_ops.cc b/tensorflow/core/kernels/mkl/mkl_conv_grad_input_ops.cc index 3eab40d24ee07f..c376b4e4ec6531 100644 --- a/tensorflow/core/kernels/mkl/mkl_conv_grad_input_ops.cc +++ b/tensorflow/core/kernels/mkl/mkl_conv_grad_input_ops.cc @@ -470,7 +470,14 @@ class MklConvCustomBackpropInputOp (MklPrimitiveFactory::IsLegacyPlatform() || IsConv1x1StrideNot1(fwd_filter_dims, strides)); - MklDnnThreadPool eigen_tp(context); + // Create the oneDNN wrapper over eigen threadpool and set max threads + // in oneDNN. + Eigen::ThreadPoolInterface* eigen_interface = + context->device() + ->tensorflow_cpu_worker_threads() + ->workers->AsEigenThreadPool(); + tsl::OneDnnThreadPool eigen_tp(eigen_interface, + ThreadPoolUseCallerThread()); MklConvBwdInputPrimitive* conv_bwd_input = MklConvBwdInputPrimitiveFactory::Get(convBwdInputDims, do_not_cache); diff --git a/tensorflow/core/kernels/mkl/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl/mkl_conv_ops.cc index 8dae3705e0a811..027c3df8a20434 100644 --- a/tensorflow/core/kernels/mkl/mkl_conv_ops.cc +++ b/tensorflow/core/kernels/mkl/mkl_conv_ops.cc @@ -876,7 +876,14 @@ class MklConvOp : public OpKernel { // TODO(intel-tf): Extend the basic parameters for data types and fusions this->ExtendConvFwdParams(context, convFwdDims); - MklDnnThreadPool eigen_tp(context); + // Create the oneDNN wrapper over eigen threadpool and set max threads + // in oneDNN. + Eigen::ThreadPoolInterface* eigen_interface = + context->device() + ->tensorflow_cpu_worker_threads() + ->workers->AsEigenThreadPool(); + tsl::OneDnnThreadPool eigen_tp(eigen_interface, + ThreadPoolUseCallerThread()); conv_fwd = MklConvFwdPrimitiveFactory::Get( convFwdDims, do_not_cache); diff --git a/tensorflow/core/kernels/mkl/mkl_dequantize_op.cc b/tensorflow/core/kernels/mkl/mkl_dequantize_op.cc index ce293200bb3ea2..2f07569fc36261 100644 --- a/tensorflow/core/kernels/mkl/mkl_dequantize_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_dequantize_op.cc @@ -72,7 +72,14 @@ class MklDequantizeOp : public OpKernel { MklDnnData dst(&cpu_engine); std::shared_ptr reorder_stream; - MklDnnThreadPool eigen_tp(ctx); + // Create the oneDNN wrapper over eigen threadpool and set max threads + // in oneDNN. + Eigen::ThreadPoolInterface* eigen_interface = + ctx->device() + ->tensorflow_cpu_worker_threads() + ->workers->AsEigenThreadPool(); + tsl::OneDnnThreadPool eigen_tp(eigen_interface, + ThreadPoolUseCallerThread()); reorder_stream.reset(CreateStream(&eigen_tp, cpu_engine)); memory::format_tag dst_layout_type; diff --git a/tensorflow/core/kernels/mkl/mkl_einsum_op.cc b/tensorflow/core/kernels/mkl/mkl_einsum_op.cc index 48b495a461b309..c7ba2b46cd2bee 100644 --- a/tensorflow/core/kernels/mkl/mkl_einsum_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_einsum_op.cc @@ -112,8 +112,15 @@ struct MklEinsumHelper { auto params = bmm.CreateMatMulParams(prefix, lhs.shape(), rhs.shape(), out_shape, trans_x, trans_y); + // Create the oneDNN wrapper over eigen threadpool and set max threads + // in oneDNN. + Eigen::ThreadPoolInterface* eigen_interface = + ctx->device() + ->tensorflow_cpu_worker_threads() + ->workers->AsEigenThreadPool(); + tsl::OneDnnThreadPool eigen_tp(eigen_interface, + ThreadPoolUseCallerThread()); // Create or retrieve matmul primitive from cache. - MklDnnThreadPool eigen_tp(ctx); MklMatMulPrimitive* matmul_prim = MklMatMulPrimitiveFactory::Get( *params, false /* value for do_not_cache */); diff --git a/tensorflow/core/kernels/mkl/mkl_fused_batch_norm_op.cc b/tensorflow/core/kernels/mkl/mkl_fused_batch_norm_op.cc index 611c3709878e11..83a3e525609370 100644 --- a/tensorflow/core/kernels/mkl/mkl_fused_batch_norm_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_fused_batch_norm_op.cc @@ -14,7 +14,6 @@ limitations under the License. ==============================================================================*/ #if defined(INTEL_MKL) && !defined(ENABLE_ONEDNN_V3) -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "dnnl.hpp" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -24,6 +23,7 @@ limitations under the License. #include "tensorflow/core/kernels/no_op.h" #include "tensorflow/core/util/mkl_util.h" #include "tensorflow/core/util/tensor_format.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #ifdef DNNL_AARCH64_USE_ACL #include "tensorflow/core/platform/mutex.h" #endif @@ -837,8 +837,15 @@ class MklFusedBatchNormOp : public OpKernel { MklBatchNormFwdParams fwdParams(src_dims, depth_, epsilon_, is_training_, tensor_format_, src_md, activation_mode_); + // Create the oneDNN wrapper over eigen threadpool and set max threads + // in oneDNN. + Eigen::ThreadPoolInterface* eigen_interface = + context->device() + ->tensorflow_cpu_worker_threads() + ->workers->AsEigenThreadPool(); + tsl::OneDnnThreadPool eigen_tp(eigen_interface, + ThreadPoolUseCallerThread()); // Get forward batch-normalization op from the primitive caching pool. - MklDnnThreadPool eigen_tp(context); MklFusedBatchNormFwdPrimitive* bn_fwd = MklFusedBatchNormFwdPrimitiveFactory::Get(fwdParams); @@ -1312,7 +1319,12 @@ class MklFusedBatchNormGradOp : public OpKernel { MklBatchNormBwdParams bwdParams(src_dims, diff_dst_dims, depth_, epsilon_, is_training_, tensor_format_, src_md, diff_dst_md); - MklDnnThreadPool eigen_tp(context); + Eigen::ThreadPoolInterface* eigen_interface = + context->device() + ->tensorflow_cpu_worker_threads() + ->workers->AsEigenThreadPool(); + tsl::OneDnnThreadPool eigen_tp(eigen_interface, + ThreadPoolUseCallerThread()); MklFusedBatchNormBwdPrimitive* bn_bwd = MklFusedBatchNormBwdPrimitiveFactory::Get(bwdParams); diff --git a/tensorflow/core/kernels/mkl/mkl_fused_instance_norm_op.cc b/tensorflow/core/kernels/mkl/mkl_fused_instance_norm_op.cc index 6373bf09539fe4..c98fff64627e27 100644 --- a/tensorflow/core/kernels/mkl/mkl_fused_instance_norm_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_fused_instance_norm_op.cc @@ -15,7 +15,6 @@ limitations under the License. #ifdef INTEL_MKL -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "dnnl.hpp" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -23,6 +22,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/util/mkl_util.h" #include "tensorflow/core/util/tensor_format.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" using namespace dnnl; using dnnl::batch_normalization_forward; @@ -71,7 +71,14 @@ class MklFusedInstanceNormOp : public OpKernel { OP_REQUIRES(ctx, FormatFromString(data_format_, &tensor_format), errors::InvalidArgument("Invalid data format")); - MklDnnThreadPool eigen_tp(ctx); + // Create the oneDNN wrapper over eigen threadpool and set max threads + // in oneDNN. + Eigen::ThreadPoolInterface* eigen_interface = + ctx->device() + ->tensorflow_cpu_worker_threads() + ->workers->AsEigenThreadPool(); + tsl::OneDnnThreadPool eigen_tp(eigen_interface, + ThreadPoolUseCallerThread()); std::shared_ptr engine_stream_ptr; engine_stream_ptr.reset(CreateStream(&eigen_tp, cpu_engine_)); diff --git a/tensorflow/core/kernels/mkl/mkl_layer_norm_op.cc b/tensorflow/core/kernels/mkl/mkl_layer_norm_op.cc index 297e95c1cc6f20..ad78096ec11547 100644 --- a/tensorflow/core/kernels/mkl/mkl_layer_norm_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_layer_norm_op.cc @@ -15,7 +15,6 @@ limitations under the License. #ifdef INTEL_MKL -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "dnnl.hpp" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -23,6 +22,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/util/mkl_util.h" #include "tensorflow/core/util/tensor_format.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" using CPUDevice = Eigen::ThreadPoolDevice; using dnnl::layer_normalization_forward; @@ -61,7 +61,14 @@ class MklLayerNormOp : public OpKernel { "tensors are not same.")); auto cpu_engine = engine(engine::kind::cpu, 0); - MklDnnThreadPool eigen_tp(ctx); + // Create the oneDNN wrapper over eigen threadpool and set max threads + // in oneDNN. + Eigen::ThreadPoolInterface* eigen_interface = + ctx->device() + ->tensorflow_cpu_worker_threads() + ->workers->AsEigenThreadPool(); + tsl::OneDnnThreadPool eigen_tp(eigen_interface, + ThreadPoolUseCallerThread()); auto cpu_stream = std::unique_ptr(CreateStream(&eigen_tp, cpu_engine)); diff --git a/tensorflow/core/kernels/mkl/mkl_matmul_op.cc b/tensorflow/core/kernels/mkl/mkl_matmul_op.cc index dc37a7023b42ca..b7d157a58bedd2 100644 --- a/tensorflow/core/kernels/mkl/mkl_matmul_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_matmul_op.cc @@ -162,7 +162,14 @@ class MklMatMulOp : public OpKernel { char char_transb = transb ? 'T' : 'N'; VLOG(2) << "MKL DNN SGEMM called"; #ifndef ENABLE_ONEDNN_OPENMP - MklDnnThreadPool eigen_tp(ctx); + // Create the oneDNN wrapper over eigen threadpool and set max threads + // in oneDNN. + Eigen::ThreadPoolInterface* eigen_interface = + ctx->device() + ->tensorflow_cpu_worker_threads() + ->workers->AsEigenThreadPool(); + tsl::OneDnnThreadPool eigen_tp(eigen_interface, + ThreadPoolUseCallerThread()); // With threadpool , the runtime overhead is comparable to the kernel // execution for small kernel sizes. For such sizes, it may be better to run // the kernel single threaded. Here we are coming up with a cost model based diff --git a/tensorflow/core/kernels/mkl/mkl_matmul_op_fused.cc b/tensorflow/core/kernels/mkl/mkl_matmul_op_fused.cc index 363ca8bbff6c88..3d9d24e358bf86 100644 --- a/tensorflow/core/kernels/mkl/mkl_matmul_op_fused.cc +++ b/tensorflow/core/kernels/mkl/mkl_matmul_op_fused.cc @@ -135,7 +135,14 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase { // Extend the basic parameters for data types and fusions. ExtendMklDnnMatMulFwdParams(ctx, matmul_params); auto st = ExecuteSingleThreadedGemm(batch, channel, k, sizeof(T)); - MklDnnThreadPool eigen_tp(ctx, st ? 1 : -1); + // Create the oneDNN wrapper over eigen threadpool and set max threads + // in oneDNN. + Eigen::ThreadPoolInterface* eigen_interface = + ctx->device() + ->tensorflow_cpu_worker_threads() + ->workers->AsEigenThreadPool(); + tsl::OneDnnThreadPool eigen_tp(eigen_interface, ThreadPoolUseCallerThread(), + st ? 1 : -1); MklDnnMatMulFwdPrimitive* matmul_prim = MklDnnMatMulFwdPrimitiveFactory::Get(matmul_params, 0); diff --git a/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h b/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h index 33e2e8a646d192..e74ad037944bb1 100644 --- a/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h +++ b/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h @@ -21,12 +21,12 @@ limitations under the License. #include #include -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "dnnl.hpp" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/util/mkl_util.h" #include "tensorflow/core/util/onednn_env_vars.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #ifdef DNNL_AARCH64_USE_ACL #include "tensorflow/core/platform/mutex.h" #endif @@ -1016,7 +1016,14 @@ void dnnl_gemm(char transa, char transb, int64_t m, int64_t n, int64_t k, MklMatMulParams params("dnnl_gemm", a_dims, b_dims, c_dims, a_strides, b_strides, c_strides); auto st = ExecuteSingleThreadedGemm(m, n, k, sizeof(T)); - MklDnnThreadPool eigen_tp(ctx, st ? 1 : -1); + // Create the oneDNN wrapper over eigen threadpool and set max threads + // in oneDNN. + Eigen::ThreadPoolInterface* eigen_interface = + ctx->device() + ->tensorflow_cpu_worker_threads() + ->workers->AsEigenThreadPool(); + tsl::OneDnnThreadPool eigen_tp(eigen_interface, ThreadPoolUseCallerThread(), + st ? 1 : -1); MklMatMulPrimitive* matmul_prim = MklMatMulPrimitiveFactory::Get(params, 0); diff --git a/tensorflow/core/kernels/mkl/mkl_maxpooling_op.cc b/tensorflow/core/kernels/mkl/mkl_maxpooling_op.cc index d27b36c54f2870..050ca1190380cc 100644 --- a/tensorflow/core/kernels/mkl/mkl_maxpooling_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_maxpooling_op.cc @@ -143,7 +143,14 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase { pooling_prop_kind, static_cast(this->data_format_mkldnn_), input_md, this->native_format_); - MklDnnThreadPool eigen_tp(context); + // Create the oneDNN wrapper over eigen threadpool and set max threads + // in oneDNN. + Eigen::ThreadPoolInterface* eigen_interface = + context->device() + ->tensorflow_cpu_worker_threads() + ->workers->AsEigenThreadPool(); + tsl::OneDnnThreadPool eigen_tp(eigen_interface, + ThreadPoolUseCallerThread()); pooling_fwd = MklPoolingFwdPrimitiveFactory::Get(fwdParams); // Allocate output tensor. this->AllocateOutputTensor(context, *(pooling_fwd->GetPoolingFwdPd()), @@ -337,7 +344,12 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase { prop_kind::forward_training, static_cast(this->data_format_mkldnn_), src_md, this->native_format_); - MklDnnThreadPool eigen_tp(context); + Eigen::ThreadPoolInterface* eigen_interface = + context->device() + ->tensorflow_cpu_worker_threads() + ->workers->AsEigenThreadPool(); + tsl::OneDnnThreadPool eigen_tp(eigen_interface, + ThreadPoolUseCallerThread()); MklPoolingBwdPrimitive* pooling_bwd = MklPoolingBwdPrimitiveFactory::Get(bwdParams); diff --git a/tensorflow/core/kernels/mkl/mkl_qmatmul_op.cc b/tensorflow/core/kernels/mkl/mkl_qmatmul_op.cc index 155f9fb4207563..28f27dab0c2a34 100644 --- a/tensorflow/core/kernels/mkl/mkl_qmatmul_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_qmatmul_op.cc @@ -104,7 +104,6 @@ limitations under the License. #include "tensorflow/core/kernels/mkl/mkl_quantized_conv_ops.h" #include "tensorflow/core/kernels/no_op.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/util/mkl_threadpool.h" #include "tensorflow/core/util/mkl_util.h" #include "tensorflow/core/util/work_sharder.h" @@ -246,8 +245,15 @@ class MklDnnQuantizedMatMulOp // Extend the basic parameters for data types and fusions. this->ExtendMklDnnMatMulFwdParams(context, matmul_fwd_dims); + // Create the oneDNN wrapper over eigen threadpool and set max threads + // in oneDNN. + Eigen::ThreadPoolInterface* eigen_interface = + context->device() + ->tensorflow_cpu_worker_threads() + ->workers->AsEigenThreadPool(); + tsl::OneDnnThreadPool eigen_tp(eigen_interface, + ThreadPoolUseCallerThread()); // Get a MatMul fwd from primitive pool. - MklDnnThreadPool eigen_tp(context); matmul_fwd = MklDnnMatMulFwdPrimitiveFactory::Get(matmul_fwd_dims, 0); diff --git a/tensorflow/core/kernels/mkl/mkl_quantize_op.cc b/tensorflow/core/kernels/mkl/mkl_quantize_op.cc index f2f2234f5fa5ae..ff6ed33df0674a 100644 --- a/tensorflow/core/kernels/mkl/mkl_quantize_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_quantize_op.cc @@ -560,7 +560,14 @@ class MklQuantizeV2Op : public OpKernel { fwdParams.post_op_params.param.push_back(scale_factor); #endif // ENABLE_ONEDNN_V3 - MklDnnThreadPool eigen_tp(ctx); + // Create the oneDNN wrapper over eigen threadpool and set max threads + // in oneDNN. + Eigen::ThreadPoolInterface* eigen_interface = + ctx->device() + ->tensorflow_cpu_worker_threads() + ->workers->AsEigenThreadPool(); + tsl::OneDnnThreadPool eigen_tp(eigen_interface, + ThreadPoolUseCallerThread()); MklReorderWithScalePrimitive* reorder_prim = MklReorderWithScalePrimitiveFactory::Get(src.GetUsrMem(), dst.GetUsrMem(), fwdParams); diff --git a/tensorflow/core/kernels/mkl/mkl_relu_op.cc b/tensorflow/core/kernels/mkl/mkl_relu_op.cc index 24a0ae60fc02b6..1d07848cc15cd8 100644 --- a/tensorflow/core/kernels/mkl/mkl_relu_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_relu_op.cc @@ -18,7 +18,6 @@ limitations under the License. #include -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "dnnl.hpp" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" @@ -26,6 +25,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/util/mkl_util.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #ifdef DNNL_AARCH64_USE_ACL #include "tensorflow/core/platform/mutex.h" #endif @@ -476,7 +476,14 @@ class MklReluOpBase : public OpKernel { // Try to get an eltwise forward primitive from caching pool MklEltwiseFwdParams fwdParams(src_dims, src_md, alg_kind, alpha_, beta_); - MklDnnThreadPool eigen_tp(context); + // Create the oneDNN wrapper over eigen threadpool and set max threads + // in oneDNN. + Eigen::ThreadPoolInterface* eigen_interface = + context->device() + ->tensorflow_cpu_worker_threads() + ->workers->AsEigenThreadPool(); + tsl::OneDnnThreadPool eigen_tp(eigen_interface, + ThreadPoolUseCallerThread()); MklEltwiseFwdPrimitive* eltwise_fwd = MklEltwiseFwdPrimitiveFactory::Get(fwdParams); auto eltwise_fwd_pd = eltwise_fwd->GetEltwiseFwdPd(); @@ -683,7 +690,12 @@ class MklReluGradOpBase : public OpKernel { MklEltwiseBwdParams bwdParams(src_dims, common_md, alg_kind, alpha_, beta_, GetTypeOfInputTensorFromFwdOp()); - MklDnnThreadPool eigen_tp(context); + Eigen::ThreadPoolInterface* eigen_interface = + context->device() + ->tensorflow_cpu_worker_threads() + ->workers->AsEigenThreadPool(); + tsl::OneDnnThreadPool eigen_tp(eigen_interface, + ThreadPoolUseCallerThread()); MklEltwiseBwdPrimitive* eltwise_bwd = MklEltwiseBwdPrimitiveFactory::Get(bwdParams); diff --git a/tensorflow/core/kernels/mkl/mkl_requantization_range_per_channel_op.cc b/tensorflow/core/kernels/mkl/mkl_requantization_range_per_channel_op.cc index 6e1daf9ff5babe..3390eb8d6898a4 100644 --- a/tensorflow/core/kernels/mkl/mkl_requantization_range_per_channel_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_requantization_range_per_channel_op.cc @@ -21,7 +21,6 @@ limitations under the License. #include -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/type_traits.h" @@ -29,8 +28,8 @@ limitations under the License. #include "tensorflow/core/kernels/meta_support.h" #include "tensorflow/core/kernels/no_op.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/util/mkl_threadpool.h" #include "tensorflow/core/util/mkl_util.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" namespace tensorflow { diff --git a/tensorflow/core/kernels/mkl/mkl_requantize_per_channel_op.cc b/tensorflow/core/kernels/mkl/mkl_requantize_per_channel_op.cc index 62ac3674f2e048..5d2da88a5b3313 100644 --- a/tensorflow/core/kernels/mkl/mkl_requantize_per_channel_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_requantize_per_channel_op.cc @@ -20,7 +20,6 @@ limitations under the License. #include -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "dnnl.hpp" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" @@ -30,6 +29,7 @@ limitations under the License. #include "tensorflow/core/kernels/no_op.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/util/mkl_util.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" namespace tensorflow { @@ -115,7 +115,14 @@ class MklRequantizePerChannelOp : public OpKernel { cpu_engine_, scales.data()); #endif // !ENABLE_ONEDNN_V3 - MklDnnThreadPool eigen_tp(ctx); + // Create the oneDNN wrapper over eigen threadpool and set max threads + // in oneDNN. + Eigen::ThreadPoolInterface* eigen_interface = + ctx->device() + ->tensorflow_cpu_worker_threads() + ->workers->AsEigenThreadPool(); + tsl::OneDnnThreadPool eigen_tp(eigen_interface, + ThreadPoolUseCallerThread()); memory::dims dims_mkl_order = TFShapeToMklDnnDimsInNCHW(input.shape(), FORMAT_NHWC); memory::desc input_md = memory::desc(dims_mkl_order, MklDnnType(), diff --git a/tensorflow/core/kernels/mkl/mkl_softmax_op.cc b/tensorflow/core/kernels/mkl/mkl_softmax_op.cc index 9291d2c099165a..2fd9e16f1e25a9 100644 --- a/tensorflow/core/kernels/mkl/mkl_softmax_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_softmax_op.cc @@ -266,7 +266,14 @@ class MklSoftmaxOp : public OpKernel { fwdParams.aarch64_counter = MklSoftmaxPrimitiveFactory::IncrementCounter(); #endif - MklDnnThreadPool eigen_tp(context); + // Create the oneDNN wrapper over eigen threadpool and set max threads + // in oneDNN. + Eigen::ThreadPoolInterface* eigen_interface = + context->device() + ->tensorflow_cpu_worker_threads() + ->workers->AsEigenThreadPool(); + tsl::OneDnnThreadPool eigen_tp(eigen_interface, + ThreadPoolUseCallerThread()); MklSoftmaxPrimitive* softmax_fwd = MklSoftmaxPrimitiveFactory::Get(fwdParams); diff --git a/tensorflow/core/kernels/mkl/mkl_transpose_op.cc b/tensorflow/core/kernels/mkl/mkl_transpose_op.cc index 7ad7e517edc813..c7ee23e508221f 100644 --- a/tensorflow/core/kernels/mkl/mkl_transpose_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_transpose_op.cc @@ -83,7 +83,14 @@ Status MKLTransposeND(OpKernelContext* context, const Tensor& in_tensor, out.SetUsrMem(in_dims, out_strides, out_tensor); std::vector net; - MklDnnThreadPool eigen_tp(context); + // Create the oneDNN wrapper over eigen threadpool and set max threads + // in oneDNN. + Eigen::ThreadPoolInterface* eigen_interface = + context->device() + ->tensorflow_cpu_worker_threads() + ->workers->AsEigenThreadPool(); + tsl::OneDnnThreadPool eigen_tp(eigen_interface, + ThreadPoolUseCallerThread()); auto* prim = FindOrCreateReorder(in.GetUsrMem(), out.GetUsrMem()); transpose_stream.reset(CreateStream(&eigen_tp, prim->GetEngine())); in.SetUsrMemDataHandle(&in_tensor, transpose_stream); diff --git a/tensorflow/core/util/BUILD b/tensorflow/core/util/BUILD index 7076340d5a5df8..f7436f0a35c7b1 100644 --- a/tensorflow/core/util/BUILD +++ b/tensorflow/core/util/BUILD @@ -163,7 +163,6 @@ filegroup( "matmul_autotune.h", "matmul_bcast.h", "mirror_pad_mode.h", - "mkl_threadpool.h", "mkl_util.h", "onednn_env_vars.h", "overflow.h", @@ -296,9 +295,9 @@ filegroup( filegroup( name = "mkl_util_hdrs", srcs = [ - "mkl_threadpool.h", "mkl_util.h", "onednn_env_vars.h", + "//tensorflow/tsl/util:onednn_util_hdrs", ], visibility = ["//tensorflow/core:__pkg__"], ) diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h index 322991376f9924..7dffa8e347e0ac 100644 --- a/tensorflow/core/util/mkl_util.h +++ b/tensorflow/core/util/mkl_util.h @@ -36,13 +36,13 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/util/env_var.h" -#include "tensorflow/core/util/mkl_threadpool.h" #include "tensorflow/core/util/onednn_env_vars.h" #include "tensorflow/core/util/padding.h" #include "tensorflow/core/util/tensor_format.h" #ifdef DNNL_AARCH64_USE_ACL #include "tensorflow/core/platform/mutex.h" #endif +#include "tensorflow/tsl/util/onednn_threadpool.h" using dnnl::engine; using dnnl::memory; @@ -274,7 +274,7 @@ inline bool array_cmp(const T* a1, const T* a2, size_t size) { return true; } -inline dnnl::stream* CreateStream(MklDnnThreadPool* eigen_tp, +inline dnnl::stream* CreateStream(tsl::OneDnnThreadPool* eigen_tp, const engine& engine) { #ifndef ENABLE_ONEDNN_OPENMP if (eigen_tp != nullptr) { @@ -663,9 +663,16 @@ inline void ExecutePrimitive(const std::vector& net, DCHECK(net_args); DCHECK_EQ(net.size(), net_args->size()); std::unique_ptr cpu_stream; - MklDnnThreadPool eigen_tp; + // Create the oneDNN wrapper over eigen threadpool and set max threads + // in oneDNN. + tsl::OneDnnThreadPool eigen_tp; if (context != nullptr) { - eigen_tp = MklDnnThreadPool(context); + Eigen::ThreadPoolInterface* eigen_interface = + context->device() + ->tensorflow_cpu_worker_threads() + ->workers->AsEigenThreadPool(); + eigen_tp = + tsl::OneDnnThreadPool(eigen_interface, ThreadPoolUseCallerThread()); cpu_stream.reset(CreateStream(&eigen_tp, cpu_engine)); } else { cpu_stream.reset(CreateStream(nullptr, cpu_engine)); @@ -1596,9 +1603,14 @@ class MklDnnData { reorder_memory_ = new memory(op_md, engine); auto* prim = FindOrCreateReorder(user_memory_, reorder_memory_); std::shared_ptr cpu_stream; - MklDnnThreadPool eigen_tp; + tsl::OneDnnThreadPool eigen_tp; if (context != nullptr) { - eigen_tp = MklDnnThreadPool(context); + Eigen::ThreadPoolInterface* eigen_interface = + context->device() + ->tensorflow_cpu_worker_threads() + ->workers->AsEigenThreadPool(); + eigen_tp = + tsl::OneDnnThreadPool(eigen_interface, ThreadPoolUseCallerThread()); cpu_stream.reset(CreateStream(&eigen_tp, prim->GetEngine())); } else { cpu_stream.reset(CreateStream(nullptr, prim->GetEngine())); @@ -1663,9 +1675,14 @@ class MklDnnData { reorder_memory_ = new memory(op_md, engine, reorder_data_handle); auto* prim = FindOrCreateReorder(user_memory_, reorder_memory_); std::shared_ptr cpu_stream; - MklDnnThreadPool eigen_tp; + tsl::OneDnnThreadPool eigen_tp; if (context != nullptr) { - eigen_tp = MklDnnThreadPool(context); + Eigen::ThreadPoolInterface* eigen_interface = + context->device() + ->tensorflow_cpu_worker_threads() + ->workers->AsEigenThreadPool(); + eigen_tp = + tsl::OneDnnThreadPool(eigen_interface, ThreadPoolUseCallerThread()); cpu_stream.reset(CreateStream(&eigen_tp, prim->GetEngine())); } else { cpu_stream.reset(CreateStream(nullptr, prim->GetEngine())); @@ -1774,9 +1791,14 @@ class MklDnnData { net_args.push_back( {{DNNL_ARG_FROM, *reorder_memory_}, {DNNL_ARG_TO, *user_memory_}}); std::shared_ptr cpu_stream; - MklDnnThreadPool eigen_tp; + tsl::OneDnnThreadPool eigen_tp; if (ctx != nullptr) { - eigen_tp = MklDnnThreadPool(ctx); + Eigen::ThreadPoolInterface* eigen_interface = + ctx->device() + ->tensorflow_cpu_worker_threads() + ->workers->AsEigenThreadPool(); + eigen_tp = + tsl::OneDnnThreadPool(eigen_interface, ThreadPoolUseCallerThread()); cpu_stream.reset(CreateStream(&eigen_tp, prim->GetEngine())); } else { cpu_stream.reset(CreateStream(nullptr, prim->GetEngine())); diff --git a/tensorflow/tsl/util/BUILD b/tensorflow/tsl/util/BUILD index 0e624d68a91e22..e7d376ad04f369 100644 --- a/tensorflow/tsl/util/BUILD +++ b/tensorflow/tsl/util/BUILD @@ -303,6 +303,18 @@ filegroup( visibility = set_external_visibility(["//tensorflow/core/util:__pkg__"]), ) +filegroup( + name = "onednn_util_hdrs", + srcs = [ + "onednn_threadpool.h", + ], + visibility = set_external_visibility([ + "//tensorflow/compiler/xla:__pkg__", + "//tensorflow/core:__pkg__", + "//tensorflow/core/framework:__pkg__", + ]), +) + filegroup( name = "android_test_hdrs", testonly = 1, diff --git a/tensorflow/core/util/mkl_threadpool.h b/tensorflow/tsl/util/onednn_threadpool.h similarity index 72% rename from tensorflow/core/util/mkl_threadpool.h rename to tensorflow/tsl/util/onednn_threadpool.h index e160c75661265a..ed9989bd3c2511 100644 --- a/tensorflow/core/util/mkl_threadpool.h +++ b/tensorflow/tsl/util/onednn_threadpool.h @@ -14,8 +14,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CORE_UTIL_MKL_THREADPOOL_H_ -#define TENSORFLOW_CORE_UTIL_MKL_THREADPOOL_H_ +#ifndef TENSORFLOW_TSL_UTIL_ONEDNN_THREADPOOL_H_ +#define TENSORFLOW_TSL_UTIL_ONEDNN_THREADPOOL_H_ #ifdef INTEL_MKL #include @@ -25,17 +25,15 @@ limitations under the License. #include #include -#include "dnnl_threadpool.hpp" -#include "dnnl.hpp" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/platform/blocking_counter.h" -#include "tensorflow/core/platform/cpu_info.h" -#include "tensorflow/core/platform/threadpool.h" -#include "tensorflow/core/util/onednn_env_vars.h" - #define EIGEN_USE_THREADS -namespace tensorflow { +#include "dnnl.hpp" +#include "dnnl_threadpool.hpp" +#include "tensorflow/tsl/platform/blocking_counter.h" +#include "tensorflow/tsl/platform/cpu_info.h" +#include "tensorflow/tsl/platform/threadpool.h" + +namespace tsl { #ifndef ENABLE_ONEDNN_OPENMP using dnnl::threadpool_interop::threadpool_iface; @@ -75,28 +73,20 @@ inline void run_jobs(bool balance, int i, int n, int njobs, } } -struct MklDnnThreadPool : public threadpool_iface { - MklDnnThreadPool() = default; +class OneDnnThreadPool : public threadpool_iface { + public: + OneDnnThreadPool() = default; - MklDnnThreadPool(OpKernelContext* ctx, int num_threads = -1) { - eigen_interface_ = ctx->device() - ->tensorflow_cpu_worker_threads() - ->workers->AsEigenThreadPool(); -#if DNNL_VERSION_MAJOR >= 3 || \ - (DNNL_VERSION_MAJOR == 2 && DNNL_VERSION_MINOR >= 7) - if (num_threads == -1) { - dnnl_threadpool_interop_set_max_concurrency( - eigen_interface_->NumThreads()); - num_threads_ = eigen_interface_->NumThreads(); - } else { - dnnl_threadpool_interop_set_max_concurrency(num_threads); - num_threads_ = num_threads; - } -#else - num_threads_ = - num_threads == -1 ? eigen_interface_->NumThreads() : num_threads; -#endif // DNNL_VERSION_MAJOR >= 3 || - // (DNNL_VERSION_MAJOR == 2 && DNNL_VERSION_MINOR >= 7) + OneDnnThreadPool(Eigen::ThreadPoolInterface* eigen_interface, + int num_threads = -1) + : eigen_interface_(eigen_interface) { + set_num_and_max_threads(num_threads); + } + OneDnnThreadPool(Eigen::ThreadPoolInterface* eigen_interface, + bool can_use_caller_thread, int num_threads = -1) + : eigen_interface_(eigen_interface), + can_use_caller_thread_(can_use_caller_thread) { + set_num_and_max_threads(num_threads); } virtual int get_num_threads() const override { return num_threads_; } virtual bool get_in_parallel() const override { @@ -121,10 +111,10 @@ struct MklDnnThreadPool : public threadpool_iface { // If use_caller_thread, schedule njobs-1 jobs to thread pool and run last // job directly. const bool use_caller_thread = - ThreadPoolUseCallerThread() && nthr == port::NumSchedulableCPUs(); + can_use_caller_thread_ && nthr == port::NumSchedulableCPUs(); const int njobs_to_schedule = use_caller_thread ? njobs - 1 : njobs; - BlockingCounter counter(njobs_to_schedule); + tsl::BlockingCounter counter(njobs_to_schedule); std::function handle_range = [=, &handle_range, &counter]( int first, int last) { while (last - first > 1) { @@ -152,25 +142,38 @@ struct MklDnnThreadPool : public threadpool_iface { counter.Wait(); } - ~MklDnnThreadPool() {} + ~OneDnnThreadPool() {} private: Eigen::ThreadPoolInterface* eigen_interface_ = nullptr; - int num_threads_ = 1; // Execute in caller thread. + int num_threads_ = 1; // Execute in caller thread. + bool can_use_caller_thread_ = false; // true if the user set the env variable + // to use caller thread also. + inline void set_num_and_max_threads(int num_threads) { + num_threads_ = + num_threads == -1 ? eigen_interface_->NumThreads() : num_threads; +#if DNNL_VERSION_MAJOR >= 3 || \ + (DNNL_VERSION_MAJOR == 2 && DNNL_VERSION_MINOR >= 7) + dnnl_threadpool_interop_set_max_concurrency(num_threads_); +#endif // DNNL_VERSION_MAJOR >= 3 || + // (DNNL_VERSION_MAJOR == 2 && DNNL_VERSION_MINOR >= 7) + } }; #else -// This struct was just added to enable successful OMP-based build. -struct MklDnnThreadPool { - MklDnnThreadPool() = default; - MklDnnThreadPool(OpKernelContext* ctx) {} - MklDnnThreadPool(OpKernelContext* ctx, int num_threads) {} +// This class was just added to enable successful OMP-based build. +class OneDnnThreadPool { + public: + OneDnnThreadPool() = default; + OneDnnThreadPool(Eigen::ThreadPoolInterface* eigen_interface) {} + OneDnnThreadPool(Eigen::ThreadPoolInterface* eigen_interface, + bool can_use_caller_thread, int num_threads = -1) {} }; #endif // !ENABLE_ONEDNN_OPENMP -} // namespace tensorflow +} // namespace tsl #endif // INTEL_MKL -#endif // TENSORFLOW_CORE_UTIL_MKL_THREADPOOL_H_ +#endif // TENSORFLOW_TSL_UTIL_ONEDNN_THREADPOOL_H_ From c4ab857287e4f06da6b10a2298d724356bf3567f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 10 Jul 2023 16:19:39 -0700 Subject: [PATCH 096/376] Integrate LLVM at llvm/llvm-project@86943d863ef6 Updates LLVM usage to match [86943d863ef6](https://github.com/llvm/llvm-project/commit/86943d863ef6) PiperOrigin-RevId: 547013315 --- third_party/llvm/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index a439c3f924583c..7a772fb5657237 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "bfd94882f2648e2a5ed651bca6cfeb4fb7788b86" - LLVM_SHA256 = "9c23082138fa8706ebb4c4e5e2f1873d954202b9b76ef4ee52542ab00262f5dd" + LLVM_COMMIT = "86943d863ef66d68bf79d3e2f0ec2c205814b235" + LLVM_SHA256 = "b37024a8d88985b69b240e4222932379f794906f602464c4c31c516580508a93" tf_http_archive( name = name, From 0a404f577aa50ca8316f02757b204afa25de6c4b Mon Sep 17 00:00:00 2001 From: Fiona Lang Date: Mon, 10 Jul 2023 16:20:17 -0700 Subject: [PATCH 097/376] Update ops.Tensor references to //third_party/tensorflow/python/framework/tensor.py. PiperOrigin-RevId: 547013448 --- tensorflow/python/keras/BUILD | 8 ++++--- tensorflow/python/keras/backend.py | 24 +++++++++---------- tensorflow/python/keras/callbacks.py | 5 ++-- tensorflow/python/keras/engine/BUILD | 9 +++++-- tensorflow/python/keras/engine/base_layer.py | 10 ++++---- .../python/keras/engine/base_layer_v1.py | 12 +++++----- .../keras/engine/base_preprocessing_layer.py | 3 ++- .../python/keras/engine/data_adapter.py | 10 ++++---- tensorflow/python/keras/engine/functional.py | 3 ++- .../python/keras/engine/keras_tensor.py | 15 ++++++------ tensorflow/python/keras/engine/node.py | 5 ++-- tensorflow/python/keras/engine/training.py | 4 +++- .../python/keras/layers/legacy_rnn/BUILD | 1 + .../keras/layers/legacy_rnn/rnn_cell_impl.py | 5 ++-- tensorflow/python/keras/optimizer_v2/BUILD | 5 +++- .../keras/optimizer_v2/gradient_descent.py | 5 ++-- .../python/keras/optimizer_v2/optimizer_v2.py | 9 +++---- .../python/keras/optimizer_v2/rmsprop.py | 4 +++- tensorflow/python/keras/saving/utils_v1/BUILD | 1 + .../keras/saving/utils_v1/export_output.py | 13 +++++----- tensorflow/python/keras/utils/BUILD | 7 ++++-- .../python/keras/utils/control_flow_util.py | 4 ++-- tensorflow/python/keras/utils/data_utils.py | 4 ++-- .../python/keras/utils/metrics_utils.py | 3 ++- tensorflow/python/keras/utils/tf_utils.py | 12 +++++----- 25 files changed, 104 insertions(+), 77 deletions(-) diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD index 0bf6466b963526..c271a5ef77a784 100755 --- a/tensorflow/python/keras/BUILD +++ b/tensorflow/python/keras/BUILD @@ -55,14 +55,14 @@ py_library( "//tensorflow/core:protos_all_py", "//tensorflow/python/client", "//tensorflow/python/client:session", - "//tensorflow/python/distribute:distribute_coordinator", "//tensorflow/python/distribute:distribute_lib", - "//tensorflow/python/distribute:multi_worker_util", - "//tensorflow/python/framework", + "//tensorflow/python/framework:composite_tensor", + "//tensorflow/python/framework:config", "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_conversion", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/keras/distribute:distribute_coordinator_utils", @@ -159,6 +159,8 @@ py_library( deps = [ ":backend", "//tensorflow/python/framework:errors", + "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/keras/distribute:distributed_file_utils", "//tensorflow/python/keras/distribute:worker_training_state", "//tensorflow/python/keras/protobuf:projector_config_proto_py", diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py index bf012e88e4313f..18ebf20ab0eff3 100644 --- a/tensorflow/python/keras/backend.py +++ b/tensorflow/python/keras/backend.py @@ -43,9 +43,9 @@ from tensorflow.python.framework import func_graph from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_conversion from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_util from tensorflow.python.keras import backend_config from tensorflow.python.keras.distribute import distribute_coordinator_utils as dc @@ -213,7 +213,7 @@ def cast_to_floatx(x): dtype('float32') """ - if isinstance(x, (ops.Tensor, + if isinstance(x, (tensor_lib.Tensor, variables_module.Variable, sparse_tensor.SparseTensor)): return math_ops.cast(x, dtype=floatx()) @@ -672,9 +672,9 @@ def _current_graph(op_input_list, graph=None): # TODO(josh11b): Note that we exclude subclasses of Tensor. Need to clean this # up. if (isinstance(op_input, ( - ops.Operation, ops.Tensor, composite_tensor.CompositeTensor)) and - ((not isinstance(op_input, ops.Tensor)) - or type(op_input) == ops.Tensor)): # pylint: disable=unidiomatic-typecheck + ops.Operation, tensor_lib.Tensor, composite_tensor.CompositeTensor)) and + ((not isinstance(op_input, tensor_lib.Tensor)) + or type(op_input) == tensor_lib.Tensor)): # pylint: disable=unidiomatic-typecheck graph_element = op_input else: graph_element = _as_graph_element(op_input) @@ -1266,7 +1266,7 @@ def is_keras_tensor(x): """ if not isinstance(x, - (ops.Tensor, variables_module.Variable, + (tensor_lib.Tensor, variables_module.Variable, sparse_tensor.SparseTensor, ragged_tensor.RaggedTensor, keras_tensor.KerasTensor)): raise ValueError('Unexpectedly found an instance of type `' + str(type(x)) + @@ -1339,7 +1339,7 @@ def placeholder(shape=None, spec = ragged_tensor.RaggedTensorSpec( shape=shape, dtype=dtype, ragged_rank=ragged_rank) else: - spec = tensor_spec.TensorSpec( + spec = tensor_lib.TensorSpec( shape=shape, dtype=dtype, name=name) x = keras_tensor.keras_tensor_from_type_spec(spec, name=name) else: @@ -3859,7 +3859,7 @@ def print_tensor(x, message='', summarize=3): Returns: The same tensor `x`, unchanged. """ - if isinstance(x, ops.Tensor) and hasattr(x, 'graph'): + if isinstance(x, tensor_lib.Tensor) and hasattr(x, 'graph'): with get_graph().as_default(): op = logging_ops.print_v2( message, x, output_stream=sys.stdout, summarize=summarize) @@ -4423,7 +4423,7 @@ def compute_masked_output(mask_t, flat_out, flat_mask): return tuple( array_ops.where_v2(m, o, fm) for m, o, fm in zip(tiled_mask_t, flat_out, flat_mask)) - elif isinstance(input_length, ops.Tensor): + elif isinstance(input_length, tensor_lib.Tensor): if go_backwards: max_len = math_ops.reduce_max(input_length, axis=0) rev_input_length = math_ops.subtract(max_len - 1, input_length) @@ -4476,7 +4476,7 @@ def _step(time, output_ta_t, prev_output, *states): flat_state = nest.flatten(states) flat_new_state = nest.flatten(new_states) for state, new_state in zip(flat_state, flat_new_state): - if isinstance(new_state, ops.Tensor): + if isinstance(new_state, tensor_lib.Tensor): new_state.set_shape(state.shape) flat_final_state = compute_masked_output(mask_t, flat_new_state, flat_state) @@ -4513,7 +4513,7 @@ def _step(time, output_ta_t, *states): flat_state = nest.flatten(states) flat_new_state = nest.flatten(new_states) for state, new_state in zip(flat_state, flat_new_state): - if isinstance(new_state, ops.Tensor): + if isinstance(new_state, tensor_lib.Tensor): new_state.set_shape(state.shape) flat_output = nest.flatten(output) @@ -4536,7 +4536,7 @@ def _step(time, output_ta_t, *states): # static shape inference def set_shape(output_): - if isinstance(output_, ops.Tensor): + if isinstance(output_, tensor_lib.Tensor): shape = output_.shape.as_list() shape[0] = time_steps shape[1] = batch diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py index 9d414ceb1488c6..c5cbd1873c3058 100644 --- a/tensorflow/python/keras/callbacks.py +++ b/tensorflow/python/keras/callbacks.py @@ -41,6 +41,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.keras import backend from tensorflow.python.keras.distribute import distributed_file_utils from tensorflow.python.keras.distribute import worker_training_state @@ -1965,10 +1966,10 @@ def on_epoch_begin(self, epoch, logs=None): lr = self.schedule(epoch, lr) except TypeError: # Support for old API for backward compatibility lr = self.schedule(epoch) - if not isinstance(lr, (ops.Tensor, float, np.float32, np.float64)): + if not isinstance(lr, (tensor_lib.Tensor, float, np.float32, np.float64)): raise ValueError('The output of the "schedule" function ' 'should be float.') - if isinstance(lr, ops.Tensor) and not lr.dtype.is_floating: + if isinstance(lr, tensor_lib.Tensor) and not lr.dtype.is_floating: raise ValueError('The dtype of Tensor should be float') backend.set_value(self.model.optimizer.lr, backend.get_value(lr)) if self.verbose > 0: diff --git a/tensorflow/python/keras/engine/BUILD b/tensorflow/python/keras/engine/BUILD index 51e77dc218925e..2098b1650bc920 100644 --- a/tensorflow/python/keras/engine/BUILD +++ b/tensorflow/python/keras/engine/BUILD @@ -54,6 +54,8 @@ py_library( "//tensorflow/python/distribute:reduce_util", "//tensorflow/python/distribute/coordinator:cluster_coordinator", "//tensorflow/python/eager:monitoring", + "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_conversion", "//tensorflow/python/keras:activations", "//tensorflow/python/keras:backend", @@ -130,6 +132,7 @@ py_library( ":input_spec", ":node", "//third_party/py/numpy", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_conversion", "//tensorflow/python/platform:tf_logging", "//tensorflow/python/util:compat", @@ -188,7 +191,7 @@ py_library( srcs_version = "PY3", deps = [ "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_conversion", "//tensorflow/python/keras/utils:dataset_creator", "//tensorflow/python/keras/utils:engine_utils", @@ -221,7 +224,7 @@ py_library( srcs_version = "PY3", deps = [ "//tensorflow/python/framework:dtypes", - "//tensorflow/python/framework:tensor_spec", + "//tensorflow/python/framework:tensor", "//tensorflow/python/keras/utils:object_identity", "//tensorflow/python/lib/io:lib", "//tensorflow/python/util:nest", @@ -238,6 +241,7 @@ py_library( ":base_layer", "//tensorflow/python/data", "//tensorflow/python/eager:monitoring", + "//tensorflow/python/framework:tensor", "//tensorflow/python/keras:backend", "//tensorflow/python/module", ], @@ -250,6 +254,7 @@ py_library( deps = [ ":base_layer_utils", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/keras:backend", "//tensorflow/python/keras/utils:tf_utils", diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py index 48e0862fe1b05a..6b97d9ce9904ed 100644 --- a/tensorflow/python/keras/engine/base_layer.py +++ b/tensorflow/python/keras/engine/base_layer.py @@ -39,8 +39,8 @@ from tensorflow.python.framework import func_graph from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_conversion -from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_util from tensorflow.python.keras import backend from tensorflow.python.keras import constraints @@ -91,7 +91,7 @@ # TODO(mdan): Should we have a single generic type for types that can be passed # to tf.cast? -_AUTOCAST_TYPES = (ops.Tensor, sparse_tensor.SparseTensor, +_AUTOCAST_TYPES = (tensor_lib.Tensor, sparse_tensor.SparseTensor, ragged_tensor.RaggedTensor) @@ -822,7 +822,7 @@ def compute_output_signature(self, input_signature): TypeError: If input_signature contains a non-TensorSpec object. """ def check_type_return_shape(s): - if not isinstance(s, tensor_spec.TensorSpec): + if not isinstance(s, tensor_lib.TensorSpec): raise TypeError('Only TensorSpec signature types are supported, ' 'but saw signature entry: {}.'.format(s)) return s.shape @@ -835,7 +835,7 @@ def check_type_return_shape(s): # dtype. dtype = input_dtypes[0] return nest.map_structure( - lambda s: tensor_spec.TensorSpec(dtype=dtype, shape=s), + lambda s: tensor_lib.TensorSpec(dtype=dtype, shape=s), output_shape) def _keras_tensor_symbolic_call(self, inputs, input_masks, args, kwargs): @@ -847,7 +847,7 @@ def _keras_tensor_symbolic_call(self, inputs, input_masks, args, kwargs): # TODO(fchollet): consider py_func as an alternative, which # would enable us to run the underlying graph if needed. input_signature = nest.map_structure( - lambda x: tensor_spec.TensorSpec(shape=x.shape, dtype=x.dtype), + lambda x: tensor_lib.TensorSpec(shape=x.shape, dtype=x.dtype), inputs) output_signature = self.compute_output_signature(input_signature) return nest.map_structure(keras_tensor.KerasTensor, output_signature) diff --git a/tensorflow/python/keras/engine/base_layer_v1.py b/tensorflow/python/keras/engine/base_layer_v1.py index 6836bdfd9eeecd..3cb10b362125b4 100644 --- a/tensorflow/python/keras/engine/base_layer_v1.py +++ b/tensorflow/python/keras/engine/base_layer_v1.py @@ -32,8 +32,8 @@ from tensorflow.python.framework import func_graph from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_conversion -from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_util from tensorflow.python.keras import backend from tensorflow.python.keras import constraints @@ -601,7 +601,7 @@ def compute_output_signature(self, input_signature): TypeError: If input_signature contains a non-TensorSpec object. """ def check_type_return_shape(s): - if not isinstance(s, tensor_spec.TensorSpec): + if not isinstance(s, tensor.TensorSpec): raise TypeError('Only TensorSpec signature types are supported, ' 'but saw signature entry: {}.'.format(s)) return s.shape @@ -614,7 +614,7 @@ def check_type_return_shape(s): # dtype. dtype = input_dtypes[0] return nest.map_structure( - lambda s: tensor_spec.TensorSpec(dtype=dtype, shape=s), + lambda s: tensor.TensorSpec(dtype=dtype, shape=s), output_shape) @generic_utils.default @@ -1815,15 +1815,15 @@ def _maybe_cast_inputs(self, inputs): dtypes.as_dtype(compute_dtype).is_floating): def f(x): """Cast a single Tensor or TensorSpec to the compute dtype.""" - cast_types = (ops.Tensor, sparse_tensor.SparseTensor, + cast_types = (tensor.Tensor, sparse_tensor.SparseTensor, ragged_tensor.RaggedTensor) if (isinstance(x, cast_types) and x.dtype.is_floating and x.dtype.base_dtype.name != compute_dtype): return math_ops.cast(x, compute_dtype) - elif isinstance(x, tensor_spec.TensorSpec) and x.dtype.is_floating: + elif isinstance(x, tensor.TensorSpec) and x.dtype.is_floating: # Inputs may be TensorSpecs when this function is called from # model._set_inputs. - return tensor_spec.TensorSpec(x.shape, compute_dtype, x.name) + return tensor.TensorSpec(x.shape, compute_dtype, x.name) else: return x return nest.map_structure(f, inputs) diff --git a/tensorflow/python/keras/engine/base_preprocessing_layer.py b/tensorflow/python/keras/engine/base_preprocessing_layer.py index 4aef248622d49c..0d29c21c83dcbb 100644 --- a/tensorflow/python/keras/engine/base_preprocessing_layer.py +++ b/tensorflow/python/keras/engine/base_preprocessing_layer.py @@ -24,6 +24,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor from tensorflow.python.keras import backend from tensorflow.python.keras.engine import data_adapter from tensorflow.python.keras.engine.base_layer import Layer @@ -468,7 +469,7 @@ def convert_to_list(values, sparse_default_value=None): values, default_value=sparse_default_value) values = backend.get_value(dense_tensor) - if isinstance(values, ops.Tensor): + if isinstance(values, tensor.Tensor): values = backend.get_value(values) # We may get passed a ndarray or the code above may give us a ndarray. diff --git a/tensorflow/python/keras/engine/data_adapter.py b/tensorflow/python/keras/engine/data_adapter.py index 3119c98c29ee94..50a58757df34bc 100644 --- a/tensorflow/python/keras/engine/data_adapter.py +++ b/tensorflow/python/keras/engine/data_adapter.py @@ -32,9 +32,9 @@ from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors -from tensorflow.python.framework import ops from tensorflow.python.framework import smart_cond from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_conversion from tensorflow.python.framework import tensor_shape from tensorflow.python.keras import backend @@ -566,7 +566,7 @@ def _is_composite(v): return _is_scipy_sparse(v) def _is_tensor_or_composite(v): - if isinstance(v, (ops.Tensor, np.ndarray)): + if isinstance(v, (tensor.Tensor, np.ndarray)): return True return _is_composite(v) @@ -1460,7 +1460,7 @@ def expand_1d(data): def _expand_single_1d_tensor(t): # Leaves `CompositeTensor`s as-is. - if (isinstance(t, ops.Tensor) and + if (isinstance(t, tensor.Tensor) and isinstance(t.shape, tensor_shape.TensorShape) and t.shape.rank == 1): return array_ops.expand_dims_v2(t, axis=-1) return t @@ -1669,9 +1669,9 @@ def _get_tensor_types(): try: import pandas as pd # pylint: disable=g-import-not-at-top - return (ops.Tensor, np.ndarray, pd.Series, pd.DataFrame) + return (tensor.Tensor, np.ndarray, pd.Series, pd.DataFrame) except ImportError: - return (ops.Tensor, np.ndarray) + return (tensor.Tensor, np.ndarray) def _is_scipy_sparse(x): diff --git a/tensorflow/python/keras/engine/functional.py b/tensorflow/python/keras/engine/functional.py index 4181094edf7540..920f934eb43dfd 100644 --- a/tensorflow/python/keras/engine/functional.py +++ b/tensorflow/python/keras/engine/functional.py @@ -23,6 +23,7 @@ from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.keras import backend from tensorflow.python.keras.engine import base_layer from tensorflow.python.keras.engine import base_layer_utils @@ -602,7 +603,7 @@ def _flatten_to_reference_inputs(self, tensors): def _conform_to_reference_input(self, tensor, ref_input): """Set shape and dtype based on `keras.Input`s.""" - if isinstance(tensor, ops.Tensor): + if isinstance(tensor, tensor_lib.Tensor): # Allow (None,) and (None, 1) Tensors to be passed interchangeably. Use # the shape specified by the `keras.Input`. t_shape = tensor.shape diff --git a/tensorflow/python/keras/engine/keras_tensor.py b/tensorflow/python/keras/engine/keras_tensor.py index a8878466367bcb..03936288362def 100644 --- a/tensorflow/python/keras/engine/keras_tensor.py +++ b/tensorflow/python/keras/engine/keras_tensor.py @@ -16,10 +16,9 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import type_spec as type_spec_module from tensorflow.python.keras.utils import object_identity from tensorflow.python.ops import array_ops @@ -143,7 +142,7 @@ def shape(self): @classmethod def from_tensor(cls, tensor): """Convert a traced (composite)tensor to a representative KerasTensor.""" - if isinstance(tensor, ops.Tensor): + if isinstance(tensor, tensor_lib.Tensor): name = getattr(tensor, 'name', None) type_spec = type_spec_module.type_spec_from_value(tensor) inferred_value = None @@ -304,7 +303,7 @@ def __str__(self): def __repr__(self): symbolic_description = '' inferred_value_string = '' - if isinstance(self.type_spec, tensor_spec.TensorSpec): + if isinstance(self.type_spec, tensor_lib.TensorSpec): type_spec_string = 'shape=%s dtype=%s' % (self.shape, self.dtype.name) else: type_spec_string = 'type_spec=%s' % self.type_spec @@ -361,7 +360,7 @@ def name(self): @classmethod def _overload_all_operators(cls, tensor_class): # pylint: disable=invalid-name """Register overloads for all operators.""" - for operator in ops.Tensor.OVERLOADABLE_OPERATORS: + for operator in tensor_lib.Tensor.OVERLOADABLE_OPERATORS: cls._overload_operator(tensor_class, operator) # We include `experimental_ref` for versions of TensorFlow that @@ -389,7 +388,7 @@ def _overload_operator(cls, tensor_class, operator): # pylint: disable=invalid- setattr(cls, operator, tensor_oper) -KerasTensor._overload_all_operators(ops.Tensor) # pylint: disable=protected-access +KerasTensor._overload_all_operators(tensor_lib.Tensor) # pylint: disable=protected-access class SparseKerasTensor(KerasTensor): @@ -556,11 +555,11 @@ def __next__(self): # 1. we do a check w/ isinstance because a key lookup based on class # would miss subclasses # 2. a list allows us to control lookup ordering -# We include ops.Tensor -> KerasTensor in the first position as a fastpath, +# We include tensor.Tensor -> KerasTensor in the first position as a fastpath, # *and* include object -> KerasTensor at the end as a catch-all. # We can re-visit these choices in the future as needed. keras_tensor_classes = [ - (ops.Tensor, KerasTensor), + (tensor_lib.Tensor, KerasTensor), (sparse_tensor.SparseTensor, SparseKerasTensor), (ragged_tensor.RaggedTensor, RaggedKerasTensor), (object, KerasTensor) diff --git a/tensorflow/python/keras/engine/node.py b/tensorflow/python/keras/engine/node.py index 657d41840fe6e1..c4d409a74d7d96 100644 --- a/tensorflow/python/keras/engine/node.py +++ b/tensorflow/python/keras/engine/node.py @@ -22,6 +22,7 @@ import numpy as np from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_util from tensorflow.python.keras import backend from tensorflow.python.keras.engine import base_layer_utils @@ -80,7 +81,7 @@ def __init__(self, if not ops.executing_eagerly_outside_functions(): # Create TensorFlowOpLayers if needed (in TF1) for obj in self._flat_arguments: - if (isinstance(obj, ops.Tensor) and + if (isinstance(obj, tensor_lib.Tensor) and base_layer_utils.needs_keras_history( obj, ignore_call_context=True)): base_layer_utils.create_keras_history(obj) @@ -178,7 +179,7 @@ def _serialize_keras_tensor(t): if isinstance(t, np.ndarray): return t.tolist() - if isinstance(t, ops.Tensor): + if isinstance(t, tensor_lib.Tensor): return backend.get_value(t).tolist() return t diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index a8216ffe65c5ec..56fcbaaeb4e4bc 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -38,6 +38,7 @@ from tensorflow.python.framework import func_graph from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape from tensorflow.python.keras import backend from tensorflow.python.keras import callbacks as callbacks_module @@ -2908,7 +2909,8 @@ def _multi_worker_concat(v, strategy): def _is_scalar(x): - return isinstance(x, (ops.Tensor, variables.Variable)) and x.shape.rank == 0 + return isinstance( + x, (tensor_lib.Tensor, variables.Variable)) and x.shape.rank == 0 def write_scalar_summaries(logs, step): diff --git a/tensorflow/python/keras/layers/legacy_rnn/BUILD b/tensorflow/python/keras/layers/legacy_rnn/BUILD index 54cdae02fd8fd6..cc69a90723249d 100644 --- a/tensorflow/python/keras/layers/legacy_rnn/BUILD +++ b/tensorflow/python/keras/layers/legacy_rnn/BUILD @@ -26,6 +26,7 @@ py_library( "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_conversion", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:tensor_util", diff --git a/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_impl.py b/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_impl.py index ed41c9f2b196a4..b7bcf9483180a3 100644 --- a/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_impl.py +++ b/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_impl.py @@ -29,6 +29,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_conversion from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util @@ -126,7 +127,7 @@ def _concat(prefix, suffix, static=False): ValueError: if prefix or suffix was `None` and asked for dynamic Tensors out. """ - if isinstance(prefix, ops.Tensor): + if isinstance(prefix, tensor.Tensor): p = prefix p_static = tensor_util.constant_value(prefix) if p.shape.ndims == 0: @@ -140,7 +141,7 @@ def _concat(prefix, suffix, static=False): p = ( constant_op.constant(p.as_list(), dtype=dtypes.int32) if p.is_fully_defined() else None) - if isinstance(suffix, ops.Tensor): + if isinstance(suffix, tensor.Tensor): s = suffix s_static = tensor_util.constant_value(suffix) if s.shape.ndims == 0: diff --git a/tensorflow/python/keras/optimizer_v2/BUILD b/tensorflow/python/keras/optimizer_v2/BUILD index 0971812a4a1c86..070e32f7c500f2 100644 --- a/tensorflow/python/keras/optimizer_v2/BUILD +++ b/tensorflow/python/keras/optimizer_v2/BUILD @@ -44,7 +44,10 @@ py_library( "//tensorflow/python/distribute:parameter_server_strategy_v2", "//tensorflow/python/distribute:reduce_util", "//tensorflow/python/distribute:values", - "//tensorflow/python/framework", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:indexed_slices", + "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_conversion", "//tensorflow/python/keras:backend", "//tensorflow/python/keras:backend_config", diff --git a/tensorflow/python/keras/optimizer_v2/gradient_descent.py b/tensorflow/python/keras/optimizer_v2/gradient_descent.py index 74428c719bb547..87c3b543578973 100644 --- a/tensorflow/python/keras/optimizer_v2/gradient_descent.py +++ b/tensorflow/python/keras/optimizer_v2/gradient_descent.py @@ -15,7 +15,7 @@ """SGD optimizer implementation.""" # pylint: disable=g-classes-have-attributes -from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.keras.optimizer_v2 import optimizer_v2 from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_resource_variable_ops @@ -108,7 +108,8 @@ def __init__(self, self._set_hyper("decay", self._initial_decay) self._momentum = False - if isinstance(momentum, ops.Tensor) or callable(momentum) or momentum > 0: + if isinstance( + momentum, tensor.Tensor) or callable(momentum) or momentum > 0: self._momentum = True if isinstance(momentum, (int, float)) and (momentum < 0 or momentum > 1): raise ValueError("`momentum` must be between [0, 1].") diff --git a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py index b00af22388d534..755363545e21ae 100644 --- a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py +++ b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py @@ -30,6 +30,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_util from tensorflow.python.keras import backend from tensorflow.python.keras import initializers @@ -684,7 +685,7 @@ def _distributed_apply(self, distribution, grads_and_vars, name, apply_state): def apply_grad_to_update_var(var, grad): """Apply gradient to variable.""" - if isinstance(var, ops.Tensor): + if isinstance(var, tensor.Tensor): raise NotImplementedError("Trying to update a Tensor ", var) apply_kwargs = {} @@ -787,7 +788,7 @@ def _set_hyper(self, name, value): prev_value = self._hyper[name] if (callable(prev_value) or isinstance(prev_value, - (ops.Tensor, int, float, + (tensor.Tensor, int, float, learning_rate_schedule.LearningRateSchedule)) or isinstance(value, learning_rate_schedule.LearningRateSchedule)): self._hyper[name] = value @@ -965,8 +966,8 @@ def _create_hypers(self): with self._distribution_strategy_scope(): # Iterate hyper values deterministically. for name, value in sorted(self._hyper.items()): - if isinstance(value, - (ops.Tensor, tf_variables.Variable)) or callable(value): + if isinstance( + value, (tensor.Tensor, tf_variables.Variable)) or callable(value): # The check for `callable` covers the usage when `value` is a # `LearningRateSchedule`, in which case it does not need to create a # variable. diff --git a/tensorflow/python/keras/optimizer_v2/rmsprop.py b/tensorflow/python/keras/optimizer_v2/rmsprop.py index a0d9d07febe452..f752c41eeaf903 100644 --- a/tensorflow/python/keras/optimizer_v2/rmsprop.py +++ b/tensorflow/python/keras/optimizer_v2/rmsprop.py @@ -18,6 +18,7 @@ import numpy as np from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_conversion from tensorflow.python.keras import backend_config from tensorflow.python.keras.optimizer_v2 import optimizer_v2 @@ -139,7 +140,8 @@ def __init__(self, self._set_hyper("rho", rho) self._momentum = False - if isinstance(momentum, ops.Tensor) or callable(momentum) or momentum > 0: + if isinstance( + momentum, tensor.Tensor) or callable(momentum) or momentum > 0: self._momentum = True if isinstance(momentum, (int, float)) and (momentum < 0 or momentum > 1): raise ValueError("`momentum` must be between [0, 1].") diff --git a/tensorflow/python/keras/saving/utils_v1/BUILD b/tensorflow/python/keras/saving/utils_v1/BUILD index 411e567ae6cc74..b94009e93e52e1 100644 --- a/tensorflow/python/keras/saving/utils_v1/BUILD +++ b/tensorflow/python/keras/saving/utils_v1/BUILD @@ -39,6 +39,7 @@ py_library( "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/platform:gfile", "//tensorflow/python/platform:tf_logging", diff --git a/tensorflow/python/keras/saving/utils_v1/export_output.py b/tensorflow/python/keras/saving/utils_v1/export_output.py index e6a595bf5acaf3..4ad09a95a2cc9a 100644 --- a/tensorflow/python/keras/saving/utils_v1/export_output.py +++ b/tensorflow/python/keras/saving/utils_v1/export_output.py @@ -20,6 +20,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_util from tensorflow.python.keras.saving.utils_v1 import signature_def_utils as unexported_signature_utils from tensorflow.python.saved_model import signature_def_utils @@ -86,7 +87,7 @@ def _wrap_and_check_outputs( for key, value in outputs.items(): error_name = error_label or single_output_default_name key = self._check_output_key(key, error_name) - if not isinstance(value, ops.Tensor): + if not isinstance(value, tensor.Tensor): raise ValueError( '{} output value must be a Tensor; got {}.'.format( error_name, value)) @@ -128,12 +129,12 @@ def __init__(self, scores=None, classes=None): `Tensor` with the correct dtype. """ if (scores is not None - and not (isinstance(scores, ops.Tensor) + and not (isinstance(scores, tensor.Tensor) and scores.dtype.is_floating)): raise ValueError('Classification scores must be a float32 Tensor; ' 'got {}'.format(scores)) if (classes is not None - and not (isinstance(classes, ops.Tensor) + and not (isinstance(classes, tensor.Tensor) and dtypes.as_dtype(classes.dtype) == dtypes.string)): raise ValueError('Classification classes must be a string Tensor; ' 'got {}'.format(classes)) @@ -175,7 +176,7 @@ def __init__(self, value): Raises: ValueError: if the value is not a `Tensor` with dtype tf.float32. """ - if not (isinstance(value, ops.Tensor) and value.dtype.is_floating): + if not (isinstance(value, tensor.Tensor) and value.dtype.is_floating): raise ValueError('Regression output value must be a float32 Tensor; ' 'got {}'.format(value)) self._value = value @@ -334,7 +335,7 @@ def _wrap_and_check_metrics(self, metrics): val_name = key + self._SEPARATOR_CHAR + self.METRIC_VALUE_SUFFIX op_name = key + self._SEPARATOR_CHAR + self.METRIC_UPDATE_SUFFIX - if not isinstance(metric_val, ops.Tensor): + if not isinstance(metric_val, tensor.Tensor): raise ValueError( '{} output value must be a Tensor; got {}.'.format( key, metric_val)) @@ -347,7 +348,7 @@ def _wrap_and_check_metrics(self, metrics): # We must wrap any ops (or variables) in a Tensor before export, as the # SignatureDef proto expects tensors only. See b/109740581 metric_op_tensor = metric_op - if not isinstance(metric_op, ops.Tensor): + if not isinstance(metric_op, tensor.Tensor): with ops.control_dependencies([metric_op]): metric_op_tensor = constant_op.constant([], name='metric_op_wrapper') diff --git a/tensorflow/python/keras/utils/BUILD b/tensorflow/python/keras/utils/BUILD index af18b3fbecbb1e..3763a24e60657d 100644 --- a/tensorflow/python/keras/utils/BUILD +++ b/tensorflow/python/keras/utils/BUILD @@ -60,6 +60,7 @@ py_library( ":generic_utils", ":io_utils", ":tf_inspect", + "//tensorflow/python/framework:tensor", ], ) @@ -94,7 +95,7 @@ py_library( "//tensorflow/python/eager:context", "//tensorflow/python/framework:composite_tensor", "//tensorflow/python/framework:ops", - "//tensorflow/python/framework:smart_cond", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/ops:control_flow_ops", @@ -151,11 +152,13 @@ py_library( ], srcs_version = "PY3", deps = [ + ":engine_utils", ":generic_utils", ":tf_utils", "//tensorflow/python/distribute", - "//tensorflow/python/framework", "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_conversion", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:check_ops", diff --git a/tensorflow/python/keras/utils/control_flow_util.py b/tensorflow/python/keras/utils/control_flow_util.py index 0730cd6bc77978..067570eb6d64b0 100644 --- a/tensorflow/python/keras/utils/control_flow_util.py +++ b/tensorflow/python/keras/utils/control_flow_util.py @@ -17,8 +17,8 @@ This file is copied from tensorflow/python/ops/control_flow_util.py. """ -from tensorflow.python.framework import ops from tensorflow.python.framework import smart_cond as smart_module +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_util from tensorflow.python.ops import cond from tensorflow.python.ops import variables @@ -124,7 +124,7 @@ def constant_value(pred): # pylint: disable=invalid-name TypeError: If `pred` is not a Variable, Tensor or bool, or Python integer 1 or 0. """ - if isinstance(pred, ops.Tensor): + if isinstance(pred, tensor.Tensor): return tensor_util.constant_value(pred) if pred in {0, 1}: # Accept 1/0 as valid boolean values return bool(pred) diff --git a/tensorflow/python/keras/utils/data_utils.py b/tensorflow/python/keras/utils/data_utils.py index f0ba68db8c4365..c9c4be339e1e9b 100644 --- a/tensorflow/python/keras/utils/data_utils.py +++ b/tensorflow/python/keras/utils/data_utils.py @@ -36,7 +36,7 @@ import numpy as np -from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from six.moves.urllib.request import urlopen from tensorflow.python.keras.utils import tf_inspect from tensorflow.python.keras.utils.generic_utils import Progbar @@ -89,7 +89,7 @@ def chunk_read(response, chunk_size=8192, reporthook=None): def is_generator_or_sequence(x): """Check if `x` is a Keras generator type.""" builtin_iterators = (str, list, tuple, dict, set, frozenset) - if isinstance(x, (ops.Tensor, np.ndarray) + builtin_iterators): + if isinstance(x, (tensor.Tensor, np.ndarray) + builtin_iterators): return False return (tf_inspect.isgenerator(x) or isinstance(x, Sequence) or diff --git a/tensorflow/python/keras/utils/metrics_utils.py b/tensorflow/python/keras/utils/metrics_utils.py index cc1621b826aee4..fde05826279ecf 100644 --- a/tensorflow/python/keras/utils/metrics_utils.py +++ b/tensorflow/python/keras/utils/metrics_utils.py @@ -24,6 +24,7 @@ from tensorflow.python.distribute import distribute_lib from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_conversion from tensorflow.python.keras import backend from tensorflow.python.keras.utils import losses_utils @@ -148,7 +149,7 @@ def decorated(metric_obj, *args): # Results need to be wrapped in a `tf.identity` op to ensure # correct execution order. if isinstance(raw_result, - (ops.Tensor, variables_module.Variable, float, int)): + (tensor.Tensor, variables_module.Variable, float, int)): result_t = array_ops.identity(raw_result) elif isinstance(raw_result, dict): result_t = { diff --git a/tensorflow/python/keras/utils/tf_utils.py b/tensorflow/python/keras/utils/tf_utils.py index 34ec293514c66c..91c1aab5cdbada 100644 --- a/tensorflow/python/keras/utils/tf_utils.py +++ b/tensorflow/python/keras/utils/tf_utils.py @@ -24,8 +24,8 @@ from tensorflow.python.framework import composite_tensor from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_util from tensorflow.python.framework import type_spec from tensorflow.python.keras import backend as K @@ -42,7 +42,7 @@ def is_tensor_or_tensor_list(v): v = nest.flatten(v) - if v and isinstance(v[0], ops.Tensor): + if v and isinstance(v[0], tensor_lib.Tensor): return True else: return False @@ -314,7 +314,7 @@ def is_symbolic_tensor(tensor): Returns: True for symbolic tensors, False for eager tensors. """ - if isinstance(tensor, ops.Tensor): + if isinstance(tensor, tensor_lib.Tensor): return hasattr(tensor, 'graph') elif is_extension_type(tensor): component_tensors = nest.flatten(tensor, expand_composites=True) @@ -378,7 +378,7 @@ def type_spec_from_value(value): # Get a TensorSpec for array-like data without # converting the data to a Tensor if hasattr(value, 'shape') and hasattr(value, 'dtype'): - return tensor_spec.TensorSpec(value.shape, value.dtype) + return tensor_lib.TensorSpec(value.shape, value.dtype) else: return type_spec.type_spec_from_value(value) @@ -477,7 +477,7 @@ def get_tensor_spec(t, dynamic_batch=False, name=None): hasattr(t._keras_history[0], '_type_spec')): return t._keras_history[0]._type_spec elif hasattr(t, 'shape') and hasattr(t, 'dtype'): - spec = tensor_spec.TensorSpec(shape=t.shape, dtype=t.dtype, name=name) + spec = tensor_lib.TensorSpec(shape=t.shape, dtype=t.dtype, name=name) else: return None # Allow non-Tensors to pass through. @@ -521,7 +521,7 @@ def sync_to_numpy_or_python_type(tensors): return tensors.fetch() def _to_single_numpy_or_python_type(t): - if isinstance(t, ops.Tensor): + if isinstance(t, tensor_lib.Tensor): x = t.numpy() return x.item() if np.ndim(x) == 0 else x return t # Don't turn ragged or sparse tensors to NumPy. From 31ac289e17f24799959b697c3f162c1b7f2697ff Mon Sep 17 00:00:00 2001 From: Armando Ugalde Velasco Date: Mon, 10 Jul 2023 16:21:16 -0700 Subject: [PATCH 098/376] Remove target processing time collection in ClientHeartbeat PiperOrigin-RevId: 547013688 --- tensorflow/core/data/service/client/data_service_client.cc | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tensorflow/core/data/service/client/data_service_client.cc b/tensorflow/core/data/service/client/data_service_client.cc index c9fc81902b5362..13eebf9a32b27b 100644 --- a/tensorflow/core/data/service/client/data_service_client.cc +++ b/tensorflow/core/data/service/client/data_service_client.cc @@ -427,11 +427,6 @@ void DataServiceClient::Heartbeat() TF_LOCKS_EXCLUDED(mu_) { req.set_blocked_round(round_robin_round_limit_.value()); } } - { - mutex_lock l(mu_); - double target_processing_time_nsec = ctx_->GetTargetProcessingTimeNsec(); - req.set_target_processing_time_nsec(target_processing_time_nsec); - } ClientHeartbeatResponse resp; Status s = dispatcher_->ClientHeartbeat(req, resp); if (!s.ok()) { From 8d82dc2c8761074715c79bff54e3bc0be651d745 Mon Sep 17 00:00:00 2001 From: Laura Pak Date: Mon, 10 Jul 2023 16:25:19 -0700 Subject: [PATCH 099/376] Add HandleFromInput signature to use index input and return Status. PiperOrigin-RevId: 547014686 --- tensorflow/core/framework/resource_mgr.cc | 14 ++++- tensorflow/core/framework/resource_mgr.h | 6 ++ .../core/framework/resource_mgr_test.cc | 59 +++++++++++++++++++ 3 files changed, 77 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/framework/resource_mgr.cc b/tensorflow/core/framework/resource_mgr.cc index 513f49aee578c7..eacd62a62492cc 100644 --- a/tensorflow/core/framework/resource_mgr.cc +++ b/tensorflow/core/framework/resource_mgr.cc @@ -380,17 +380,27 @@ string ContainerInfo::DebugString() const { "]"); } -// TODO(b/228388547) users of this method should be migrated to the one below. +// TODO(b/228388547) users of this method should be migrated to the ones below. const ResourceHandle& HandleFromInput(OpKernelContext* ctx, int input) { return ctx->input(input).flat()(0); } +Status HandleFromInput(OpKernelContext* ctx, int input, + ResourceHandle* handle) { + TF_ASSIGN_OR_RETURN(const Tensor* tensor, ctx->get_input(input)); + if (tensor->NumElements() == 0) { + return absl::InvalidArgumentError("Empty resource handle"); + } + *handle = tensor->flat()(0); + return OkStatus(); +} + Status HandleFromInput(OpKernelContext* ctx, StringPiece input, ResourceHandle* handle) { const Tensor* tensor; TF_RETURN_IF_ERROR(ctx->input(input, &tensor)); if (tensor->NumElements() == 0) { - return errors::InvalidArgument("Empty resouce handle"); + return absl::InvalidArgumentError("Empty resource handle"); } *handle = tensor->flat()(0); return OkStatus(); diff --git a/tensorflow/core/framework/resource_mgr.h b/tensorflow/core/framework/resource_mgr.h index ffdcb9aebfe038..fc043dd84bad1d 100644 --- a/tensorflow/core/framework/resource_mgr.h +++ b/tensorflow/core/framework/resource_mgr.h @@ -366,6 +366,12 @@ Status MakeResourceHandleToOutput(OpKernelContext* context, int output_index, // Returns a resource handle from a numbered op input. const ResourceHandle& HandleFromInput(OpKernelContext* ctx, int input); + +// Safely returns a resource handle from a numbered op input. +// Prevents segfault by checking for empty resource handle. +Status HandleFromInput(OpKernelContext* ctx, int input, ResourceHandle* handle); +// Returns a resource handle by name, as defined in the OpDef. +// Also prevents segfault by checking for empty resource handle. Status HandleFromInput(OpKernelContext* ctx, StringPiece input, ResourceHandle* handle); diff --git a/tensorflow/core/framework/resource_mgr_test.cc b/tensorflow/core/framework/resource_mgr_test.cc index e5ad8f0b094c25..5c079cb2ac7318 100644 --- a/tensorflow/core/framework/resource_mgr_test.cc +++ b/tensorflow/core/framework/resource_mgr_test.cc @@ -16,10 +16,12 @@ limitations under the License. #include "tensorflow/core/framework/resource_mgr.h" #include +#include #include "tensorflow/core/framework/device_attributes.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/resource_handle.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -371,6 +373,63 @@ TEST(ResourceHandleTest, CRUD) { } } +TEST(ResourceHandleTest, ResourceFromValidIntInput) { + ResourceMgr resource_mgr(""); + OpKernelContext::Params params; + params.resource_manager = &resource_mgr; + StubDevice device("device_name"); + params.device = &device; + OpKernelContext ctx(¶ms, 1); + + ResourceHandleProto proto; + proto.set_device("cpu:0"); + proto.set_container("test_container"); + proto.set_name("test_var"); + auto handle = std::make_unique(proto); + auto expected_summary = + "ResourceHandle(name=\"test_var\", device=\"cpu:0\", " + "container=\"test_container\", type=\"\", dtype and shapes : \"[ ]\")"; + EXPECT_EQ(handle->SummarizeValue(), expected_summary); + + Tensor arg0(DT_RESOURCE, TensorShape({2})); + arg0.flat()(0) = *handle; + std::vector inputs{TensorValue(new Tensor(arg0))}; + params.inputs = inputs; + + ResourceHandle get_int_handle; + TF_ASSERT_OK(HandleFromInput(&ctx, 0, &get_int_handle)); + EXPECT_EQ(get_int_handle.SummarizeValue(), expected_summary); + delete inputs.at(0).tensor; +} + +TEST(ResourceHandleTest, ResourceFromInvalidIntInput) { + ResourceMgr resource_mgr(""); + OpKernelContext::Params params; + params.resource_manager = &resource_mgr; + StubDevice device("device_name"); + params.device = &device; + OpKernelContext ctx(¶ms, 0); + + ResourceHandle get_int_handle; + EXPECT_FALSE(HandleFromInput(&ctx, 0, &get_int_handle).ok()); +} + +TEST(ResourceHandleTest, ResourceFromIntInputWithoutResource) { + ResourceMgr resource_mgr(""); + OpKernelContext::Params params; + params.resource_manager = &resource_mgr; + StubDevice device("device_name"); + params.device = &device; + OpKernelContext ctx(¶ms, 1); + + std::vector inputs{TensorValue(new Tensor())}; + params.inputs = inputs; + + ResourceHandle get_int_handle; + EXPECT_FALSE(HandleFromInput(&ctx, 0, &get_int_handle).ok()); + delete inputs.at(0).tensor; +} + TEST(ResourceHandleTest, LookupDeleteGenericResource) { ResourceMgr resource_mgr(""); OpKernelContext::Params params; From ce10eb91b2460d790df27022f98c4bc9bc3347df Mon Sep 17 00:00:00 2001 From: Anlun Xu Date: Mon, 10 Jul 2023 16:41:52 -0700 Subject: [PATCH 100/376] [xla:gpu] Add min_graph_size flag to xla-gpu-outline-cuda-graphs pass So that the test outline_cuda_graphs.mlir will use the default value 2, while XLA will use the value in debug options. PiperOrigin-RevId: 547018532 --- .../gpu/transforms/outline_cuda_graphs.cc | 20 ++++++++++--------- .../mlir/backends/gpu/transforms/passes.cc | 3 ++- .../xla/mlir/backends/gpu/transforms/passes.h | 3 ++- .../mlir/backends/gpu/transforms/passes.td | 5 +++++ .../service/gpu/compile_module_to_llvm_ir.cc | 1 + 5 files changed, 21 insertions(+), 11 deletions(-) diff --git a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/outline_cuda_graphs.cc b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/outline_cuda_graphs.cc index 7c5f31c8bbb973..3a921bb4af7b5d 100644 --- a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/outline_cuda_graphs.cc +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/outline_cuda_graphs.cc @@ -57,8 +57,10 @@ class OutlineCudaGraphsPass : public impl::OutlineCudaGraphsPassBase { public: OutlineCudaGraphsPass() = default; - explicit OutlineCudaGraphsPass(int cuda_graph_level) - : cuda_graph_level_(cuda_graph_level) {} + explicit OutlineCudaGraphsPass(int cuda_graph_level, int min_graph_size) + : cuda_graph_level_(cuda_graph_level) { + this->min_graph_size_ = min_graph_size; + } void runOnOperation() override; @@ -326,16 +328,14 @@ static std::vector GetGraphCaptureFuncArgs(const CaptureSequence& seq) { // and replace them with an XLA Gpu runtime function call. static LogicalResult Outline(unsigned ordinal, CustomCallDeclarations& custom_calls, - CaptureSequence& seq) { + CaptureSequence& seq, int min_graph_size) { // Only operations that have to be moved into the graph capture function // represent Gpu computations. unsigned num_move_captures = llvm::count_if(seq, [](auto capture) { return capture.second == OpCapturePattern::Capture::kMove; }); DebugOptions debug_options = GetDebugOptionsFromFlags(); - int32_t graph_capture_threshold = - debug_options.xla_gpu_cuda_graph_min_graph_size(); - if (num_move_captures < graph_capture_threshold) return failure(); + if (num_move_captures < min_graph_size) return failure(); SymbolTable& sym_table = custom_calls.sym_table(); MLIRContext* ctx = sym_table.getOp()->getContext(); @@ -479,7 +479,8 @@ void OutlineCudaGraphsPass::runOnOperation() { unsigned ordinal = 1; // entry point will be exported with ordinal 0 for (auto& seq : CollectCaptureSequences(getAnalysis(), getOperation(), patterns)) { - if (succeeded(Outline(ordinal, custom_calls, seq))) ordinal++; + if (succeeded(Outline(ordinal, custom_calls, seq, min_graph_size_))) + ordinal++; } } @@ -488,8 +489,9 @@ std::unique_ptr> createOutlineCudaGraphsPass() { } std::unique_ptr> createOutlineCudaGraphsPass( - int cuda_graph_level) { - return std::make_unique(cuda_graph_level); + int cuda_graph_level, int min_graph_size) { + return std::make_unique(cuda_graph_level, + min_graph_size); } } // namespace gpu diff --git a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.cc b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.cc index 2912fb1df9f4a6..d0c4861dbf3cac 100644 --- a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.cc +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.cc @@ -41,7 +41,8 @@ void populateXlaGpuRuntimePasses(mlir::OpPassManager& pm, pm.addPass(createSymbolDCEPass()); // Clean up unused global constants. // Outline CUDA-Graph-compatible operations into graph capture functions. - pm.addPass(createOutlineCudaGraphsPass(opts.cuda_graph_level)); + pm.addPass( + createOutlineCudaGraphsPass(opts.cuda_graph_level, opts.min_graph_size)); if (opts.enable_concurrent_region) { pm.addPass(createAddConcurrentRegionsPass()); } diff --git a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.h b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.h index ebf4058365ef10..3917c73cec634c 100644 --- a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.h +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.h @@ -42,6 +42,7 @@ struct GpuPipelineOpts { // CUDA Graphs, which allows us to amortize the cost of launching multiple // device kernels. int32_t cuda_graph_level = 0; + int32_t min_graph_size = 0; bool enable_concurrent_region = false; }; @@ -101,7 +102,7 @@ std::unique_ptr> createOutlineCudaGraphsPass(); std::unique_ptr> -createOutlineCudaGraphsPass(int32_t cuda_graph_level); +createOutlineCudaGraphsPass(int32_t cuda_graph_level, int32_t min_graph_size); //===----------------------------------------------------------------------===// // Passes for marking concurrent region in CUDA graph capture function. diff --git a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.td b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.td index 1c7dbb55bc9f40..c2daab9f2a6ef4 100644 --- a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.td +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.td @@ -185,6 +185,11 @@ def OutlineCudaGraphsPass : }]; let constructor = "createOutlineCudaGraphsPass()"; + + let options = [ + Option<"min_graph_size_", "min_graph_size", "int64_t", /*default=*/"2", + "The minimum size of the outlined CUDA graph function.">, + ]; } //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/xla/service/gpu/compile_module_to_llvm_ir.cc b/tensorflow/compiler/xla/service/gpu/compile_module_to_llvm_ir.cc index f950b5850697ea..c91d9cce339c53 100644 --- a/tensorflow/compiler/xla/service/gpu/compile_module_to_llvm_ir.cc +++ b/tensorflow/compiler/xla/service/gpu/compile_module_to_llvm_ir.cc @@ -117,6 +117,7 @@ static Status LowerToXlaGpuRuntime(mlir::ModuleOp module, GpuPipelineOpts opts; opts.cuda_graph_level = debug_options.xla_gpu_cuda_graph_level(); + opts.min_graph_size = debug_options.xla_gpu_cuda_graph_min_graph_size(); opts.enable_concurrent_region = debug_options.xla_gpu_cuda_graph_enable_concurrent_region(); populateXlaGpuRuntimePasses(pm, thunk_sequence, opts); From 9787698e589dac891caa7687ccb63fd0f6969d5d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 10 Jul 2023 16:54:42 -0700 Subject: [PATCH 101/376] Update TFRT dependency to use revision http://github.com/tensorflow/runtime/commit/3e5fa08c9a184710601dbf8a1c7b52eaa306124d. PiperOrigin-RevId: 547021376 --- third_party/tf_runtime/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/tf_runtime/workspace.bzl b/third_party/tf_runtime/workspace.bzl index 03069a8bb25376..8a0084f1896fb4 100644 --- a/third_party/tf_runtime/workspace.bzl +++ b/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "46664e9fe40d2a204bbc8b629a778cb8032fd8c1" - TFRT_SHA256 = "2d676f58d4e803a0f4e4a9de951a5e2663dad51d535147ae748aec806522e6bf" + TFRT_COMMIT = "3e5fa08c9a184710601dbf8a1c7b52eaa306124d" + TFRT_SHA256 = "56d5a34fa884ec6eee7a602d90ee8387099c488bf5c3dc21a45ae8e19e2e27ad" tf_http_archive( name = "tf_runtime", From c1e3844df08fdefbbd3546b096a346d6166b1e36 Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Mon, 10 Jul 2023 17:10:47 -0700 Subject: [PATCH 102/376] [NFC] Change uses of get_compatible_with_cloud to get_compatible_with_portable. PiperOrigin-RevId: 547024987 --- tensorflow/python/profiler/internal/BUILD | 4 ++-- tensorflow/tools/graph_transforms/BUILD | 4 ++-- tensorflow/tools/optimization/BUILD | 4 ++-- tensorflow/tsl/distributed_runtime/preemption/BUILD | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/tensorflow/python/profiler/internal/BUILD b/tensorflow/python/profiler/internal/BUILD index 3ba487d65d4c68..a3a6e8d5bd78cd 100644 --- a/tensorflow/python/profiler/internal/BUILD +++ b/tensorflow/python/profiler/internal/BUILD @@ -1,5 +1,5 @@ load("//tensorflow:strict.default.bzl", "py_strict_library", "py_strict_test") -load("//tensorflow:tensorflow.default.bzl", "cuda_py_strict_test", "get_compatible_with_cloud", "tf_python_pybind_extension") +load("//tensorflow:tensorflow.default.bzl", "cuda_py_strict_test", "get_compatible_with_portable", "tf_python_pybind_extension") load("//tensorflow/core/profiler/builds:build_config.bzl", "tf_profiler_copts") package( @@ -136,7 +136,7 @@ tf_python_pybind_extension( cc_library( name = "python_hooks", hdrs = ["python_hooks.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), copts = tf_profiler_copts() + ["-fexceptions"], features = ["-use_header_modules"], # Incompatible with -fexceptions. visibility = [ diff --git a/tensorflow/tools/graph_transforms/BUILD b/tensorflow/tools/graph_transforms/BUILD index effcbfcf85a386..37394c6eb9a010 100644 --- a/tensorflow/tools/graph_transforms/BUILD +++ b/tensorflow/tools/graph_transforms/BUILD @@ -9,7 +9,7 @@ load( "tf_cc_test", "tf_copts", ) -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud", "tf_py_strict_test") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable", "tf_py_strict_test") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -25,7 +25,7 @@ cc_library( hdrs = [ "transform_utils.h", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), copts = tf_copts(), visibility = ["//visibility:public"], deps = [ diff --git a/tensorflow/tools/optimization/BUILD b/tensorflow/tools/optimization/BUILD index f6ab1fb0d2a64a..fc12208932e00b 100644 --- a/tensorflow/tools/optimization/BUILD +++ b/tensorflow/tools/optimization/BUILD @@ -6,7 +6,7 @@ load( "tf_cc_binary", "tf_cuda_library", ) -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -18,7 +18,7 @@ tf_cuda_library( name = "optimization_pass_runner_lib", srcs = ["optimization_pass_runner.cc"], hdrs = ["optimization_pass_runner.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_base", diff --git a/tensorflow/tsl/distributed_runtime/preemption/BUILD b/tensorflow/tsl/distributed_runtime/preemption/BUILD index 97572b3285543a..6eeb7e8ae47bff 100644 --- a/tensorflow/tsl/distributed_runtime/preemption/BUILD +++ b/tensorflow/tsl/distributed_runtime/preemption/BUILD @@ -1,6 +1,6 @@ load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") load("//tensorflow/tsl/platform:build_config.bzl", "tsl_cc_test") -load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_cloud", "tsl_grpc_cc_dependencies") +load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_portable", "tsl_grpc_cc_dependencies") load("//tensorflow/tsl:tsl.bzl", "set_external_visibility") package( @@ -15,7 +15,7 @@ cc_library( name = "preemption_notifier", srcs = ["preemption_notifier.cc"], hdrs = ["preemption_notifier.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/tsl/platform:env", "//tensorflow/tsl/platform:errors", From d45191017d0619155f1ab0fc6eee4d17a20010c3 Mon Sep 17 00:00:00 2001 From: Fiona Lang Date: Mon, 10 Jul 2023 17:13:18 -0700 Subject: [PATCH 103/376] Update ops.Tensor references to //third_party/tensorflow/python/framework/tensor.py. PiperOrigin-RevId: 547025579 --- tensorflow/python/eager/BUILD | 6 ++- tensorflow/python/eager/backprop.py | 9 ++-- tensorflow/python/eager/backprop_util.py | 5 ++- tensorflow/python/eager/lift_to_graph.py | 3 +- tensorflow/python/eager/ops_test.py | 5 ++- tensorflow/python/eager/wrap_function.py | 9 ++-- tensorflow/python/ops/numpy_ops/BUILD | 5 ++- .../python/ops/numpy_ops/np_array_ops.py | 37 +++++++++++----- tensorflow/python/ops/numpy_ops/np_arrays.py | 4 +- .../python/ops/numpy_ops/np_arrays_test.py | 7 +-- .../python/ops/numpy_ops/np_math_ops.py | 33 +++++++------- tensorflow/python/ops/parallel_for/BUILD | 4 +- .../ops/parallel_for/control_flow_ops.py | 3 +- .../python/ops/parallel_for/gradients.py | 3 +- tensorflow/python/ops/parallel_for/pfor.py | 43 ++++++++++--------- tensorflow/python/ops/structured/BUILD | 9 ++-- .../ops/structured/structured_array_ops.py | 5 ++- .../structured/structured_array_ops_test.py | 6 +-- .../ops/structured/structured_tensor.py | 42 ++++++++++-------- .../structured/structured_tensor_dynamic.py | 4 +- .../ops/structured/structured_tensor_test.py | 30 ++++++------- 21 files changed, 156 insertions(+), 116 deletions(-) diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index 9ad4d8c0d58c9a..b7bc8350e135f8 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -688,6 +688,7 @@ py_strict_library( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:indexed_slices", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/framework:type_spec", @@ -720,6 +721,7 @@ py_strict_library( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:indexed_slices", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:handle_data_util", @@ -917,6 +919,7 @@ cuda_py_strict_test( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:errors", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:test_lib", "//tensorflow/python/layers", @@ -997,6 +1000,7 @@ py_strict_library( deps = [ "//tensorflow/python/framework:func_graph", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:op_selector", "//tensorflow/python/ops:resource_variable_ops", @@ -1041,8 +1045,8 @@ py_strict_library( "//tensorflow/python/framework:func_graph", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", - "//tensorflow/python/framework:tensor_spec", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/ops:resource_variable_ops", "//tensorflow/python/ops:variable_scope", diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py index 9c488d1b133526..16fc829ee6ca20 100644 --- a/tensorflow/python/eager/backprop.py +++ b/tensorflow/python/eager/backprop.py @@ -32,6 +32,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.framework import type_spec @@ -405,7 +406,7 @@ def decorated(*args, **kwds): def _ensure_unique_tensor_objects(parameter_positions, args): - """Make each of the parameter_positions in args a unique ops.Tensor object. + """Make each of the parameter_positions in args a unique tensor_lib.Tensor object. Ensure that each parameter is treated independently. For example: @@ -594,18 +595,18 @@ def _aggregate_grads(gradients): if len(gradients) == 1: return gradients[0] - if all(isinstance(g, ops.Tensor) for g in gradients): + if all(isinstance(g, tensor_lib.Tensor) for g in gradients): return gen_math_ops.add_n(gradients) else: assert all( - isinstance(g, (ops.Tensor, indexed_slices.IndexedSlices)) + isinstance(g, (tensor_lib.Tensor, indexed_slices.IndexedSlices)) for g in gradients) return backprop_util.AggregateIndexedSlicesGradients(gradients) def _num_elements(grad): """The number of elements in the `grad` tensor.""" - if isinstance(grad, ops.Tensor): + if isinstance(grad, tensor_lib.Tensor): shape_tuple = grad._shape_tuple() # pylint: disable=protected-access elif isinstance(grad, indexed_slices.IndexedSlices): shape_tuple = grad.values._shape_tuple() # pylint: disable=protected-access diff --git a/tensorflow/python/eager/backprop_util.py b/tensorflow/python/eager/backprop_util.py index b6509a48307a94..c4fe1158dc3c8e 100644 --- a/tensorflow/python/eager/backprop_util.py +++ b/tensorflow/python/eager/backprop_util.py @@ -19,6 +19,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import handle_data_util @@ -67,7 +68,7 @@ def IsTrainable(tensor_or_dtype): def FlattenNestedIndexedSlices(grad): assert isinstance(grad, indexed_slices.IndexedSlices) - if isinstance(grad.values, ops.Tensor): + if isinstance(grad.values, tensor_lib.Tensor): return grad else: assert isinstance(grad.values, indexed_slices.IndexedSlices) @@ -85,7 +86,7 @@ def AggregateIndexedSlicesGradients(grads): grads = [g for g in grads if g is not None] # If any gradient is a `Tensor`, sum them up and return a dense tensor # object. - if any(isinstance(g, ops.Tensor) for g in grads): + if any(isinstance(g, tensor_lib.Tensor) for g in grads): return math_ops.add_n(grads) # The following `_as_indexed_slices_list` casts ids of IndexedSlices into diff --git a/tensorflow/python/eager/lift_to_graph.py b/tensorflow/python/eager/lift_to_graph.py index 5e8be440cb4af6..7d7ac1b8e0dff2 100644 --- a/tensorflow/python/eager/lift_to_graph.py +++ b/tensorflow/python/eager/lift_to_graph.py @@ -19,6 +19,7 @@ from tensorflow.python.framework import func_graph from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.ops import array_ops from tensorflow.python.ops import op_selector from tensorflow.python.ops import resource_variable_ops @@ -31,7 +32,7 @@ def _as_operation(op_or_tensor): - if isinstance(op_or_tensor, ops.Tensor): + if isinstance(op_or_tensor, tensor_lib.Tensor): return op_or_tensor.op return op_or_tensor diff --git a/tensorflow/python/eager/ops_test.py b/tensorflow/python/eager/ops_test.py index 1792b9fb659b8b..d1006b4ece3ef1 100644 --- a/tensorflow/python/eager/ops_test.py +++ b/tensorflow/python/eager/ops_test.py @@ -29,6 +29,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util from tensorflow.python.layers import core @@ -191,8 +192,8 @@ def ops_test(v1, v2): self.assertAllEqual((a >= b), np.greater_equal(v1, v2)) # TODO(b/120678848): Remove the else branch once we enable - # ops.Tensor._USE_EQUALITY by default. - if ops.Tensor._USE_EQUALITY: + # tensor.Tensor._USE_EQUALITY by default. + if tensor.Tensor._USE_EQUALITY: self.assertAllEqual((a == b), np.equal(v1, v2)) self.assertAllEqual((a != b), np.not_equal(v1, v2)) else: diff --git a/tensorflow/python/eager/wrap_function.py b/tensorflow/python/eager/wrap_function.py index 6ffe1caf5c15e1..23935e6a126646 100644 --- a/tensorflow/python/eager/wrap_function.py +++ b/tensorflow/python/eager/wrap_function.py @@ -29,8 +29,8 @@ from tensorflow.python.framework import importer from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_util from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope @@ -246,7 +246,7 @@ def _call_impl(self, args, kwargs): if self._signature is not None: args = list(args) for i, arg in enumerate(args): - if isinstance(self._signature[i], tensor_spec.DenseSpec): + if isinstance(self._signature[i], tensor_lib.DenseSpec): args[i] = ops.convert_to_tensor(arg, self._signature[i].dtype) return self._call_flat(args, self.captured_inputs) else: @@ -281,7 +281,7 @@ def prune(self, feeds, fetches, name=None, input_signature=None): flat_feeds = nest.flatten(feeds, expand_composites=True) flat_feeds = [self.graph.as_graph_element(t) for t in flat_feeds] for f in flat_feeds: - if not isinstance(f, ops.Tensor): + if not isinstance(f, tensor_lib.Tensor): raise ValueError("All memebers of argument `feeds` must be tensors. " f"Got {f} with type {type(f)}.") @@ -319,7 +319,8 @@ def _fetch_preprocessing_callback(fetch): else: operation_fetches.append(decoded) return decoded - elif isinstance(fetch, (ops.Tensor, composite_tensor.CompositeTensor)): + elif isinstance( + fetch, (tensor_lib.Tensor, composite_tensor.CompositeTensor)): tensor_fetches.append(fetch) return fetch else: diff --git a/tensorflow/python/ops/numpy_ops/BUILD b/tensorflow/python/ops/numpy_ops/BUILD index 73b52041d31d9a..ad6c320f0ce6be 100644 --- a/tensorflow/python/ops/numpy_ops/BUILD +++ b/tensorflow/python/ops/numpy_ops/BUILD @@ -57,6 +57,7 @@ py_strict_library( "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:array_ops_stack", @@ -126,6 +127,7 @@ py_strict_library( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:errors", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:array_ops_stack", "//tensorflow/python/ops:bitwise_ops", @@ -148,7 +150,7 @@ py_strict_library( deps = [ ":np_dtypes", "//tensorflow/python/framework:dtypes", - "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_conversion", ], ) @@ -175,6 +177,7 @@ cuda_py_strict_test( "//tensorflow/python/eager:def_function", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:array_ops", "//tensorflow/python/platform:client_testlib", "//tensorflow/python/util:nest", diff --git a/tensorflow/python/ops/numpy_ops/np_array_ops.py b/tensorflow/python/ops/numpy_ops/np_array_ops.py index 10b676e1d3f075..cd795bae931c0e 100644 --- a/tensorflow/python/ops/numpy_ops/np_array_ops.py +++ b/tensorflow/python/ops/numpy_ops/np_array_ops.py @@ -26,6 +26,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops_stack @@ -154,7 +155,7 @@ def _array_internal(val, dtype=None, copy=True, ndmin=0): # pylint: disable=red """Main implementation of np.array().""" result_t = val - if not isinstance(result_t, ops.Tensor): + if not isinstance(result_t, tensor_lib.Tensor): dtype = np_utils.result_type_unary(result_t, dtype) # We can't call `convert_to_tensor(result_t, dtype=dtype)` here because # convert_to_tensor doesn't allow incompatible arguments such as (5.5, int) @@ -548,7 +549,7 @@ def _reduce(tf_fn, elif promote_int == _TO_FLOAT: a = math_ops.cast(a, np_dtypes.default_float_type()) - if isinstance(axis, ops.Tensor) and axis.dtype not in ( + if isinstance(axis, tensor_lib.Tensor) and axis.dtype not in ( dtypes.int32, dtypes.int64): axis = math_ops.cast(axis, dtypes.int64) @@ -1096,7 +1097,7 @@ def broadcast_to(array, shape): # pylint: disable=redefined-outer-name @np_utils.np_doc('stack') def stack(arrays, axis=0): # pylint: disable=missing-function-docstring - if isinstance(arrays, (np_arrays.ndarray, ops.Tensor)): + if isinstance(arrays, (np_arrays.ndarray, tensor_lib.Tensor)): arrays = asarray(arrays) if axis == 0: return arrays @@ -1880,10 +1881,17 @@ def _as_spec_tuple(slice_spec): def _getitem(self, slice_spec): """Implementation of ndarray.__getitem__.""" - if (isinstance(slice_spec, bool) or (isinstance(slice_spec, ops.Tensor) and - slice_spec.dtype == dtypes.bool) or - (isinstance(slice_spec, (np.ndarray, np_arrays.ndarray)) and - slice_spec.dtype == np.bool_)): + if ( + isinstance(slice_spec, bool) + or ( + isinstance(slice_spec, tensor_lib.Tensor) + and slice_spec.dtype == dtypes.bool + ) + or ( + isinstance(slice_spec, (np.ndarray, np_arrays.ndarray)) + and slice_spec.dtype == np.bool_ + ) + ): return array_ops.boolean_mask(tensor=self, mask=slice_spec) if not isinstance(slice_spec, tuple): @@ -1895,10 +1903,17 @@ def _getitem(self, slice_spec): def _with_index_update_helper(update_method, a, slice_spec, updates): """Implementation of ndarray._with_index_*.""" - if (isinstance(slice_spec, bool) or (isinstance(slice_spec, ops.Tensor) and - slice_spec.dtype == dtypes.bool) or - (isinstance(slice_spec, (np.ndarray, np_arrays.ndarray)) and - slice_spec.dtype == np.bool_)): + if ( + isinstance(slice_spec, bool) + or ( + isinstance(slice_spec, tensor_lib.Tensor) + and slice_spec.dtype == dtypes.bool + ) + or ( + isinstance(slice_spec, (np.ndarray, np_arrays.ndarray)) + and slice_spec.dtype == np.bool_ + ) + ): slice_spec = nonzero(slice_spec) if not isinstance(slice_spec, tuple): diff --git a/tensorflow/python/ops/numpy_ops/np_arrays.py b/tensorflow/python/ops/numpy_ops/np_arrays.py index 987f7738c17073..78257ae37ec66b 100644 --- a/tensorflow/python/ops/numpy_ops/np_arrays.py +++ b/tensorflow/python/ops/numpy_ops/np_arrays.py @@ -17,7 +17,7 @@ # pylint: disable=g-direct-tensorflow-import from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_conversion from tensorflow.python.ops.numpy_ops import np_dtypes @@ -47,4 +47,4 @@ def convert_to_tensor(value, dtype=None, dtype_hint=None): value, dtype=dtype, dtype_hint=dtype_hint) -ndarray = ops.Tensor +ndarray = tensor.Tensor diff --git a/tensorflow/python/ops/numpy_ops/np_arrays_test.py b/tensorflow/python/ops/numpy_ops/np_arrays_test.py index 9985c6ce9d909e..6bba3cdbafce11 100644 --- a/tensorflow/python/ops/numpy_ops/np_arrays_test.py +++ b/tensorflow/python/ops/numpy_ops/np_arrays_test.py @@ -19,6 +19,7 @@ from tensorflow.python.eager import def_function from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops.numpy_ops import np_arrays # Required for operator overloads @@ -193,8 +194,8 @@ def testFromToCompositeTensor(self): # Each ndarray contains only one tensor, so the flattened output should be # just 2 tensors in a list. self.assertLen(flattened, 2) - self.assertIsInstance(flattened[0], ops.Tensor) - self.assertIsInstance(flattened[1], ops.Tensor) + self.assertIsInstance(flattened[0], tensor.Tensor) + self.assertIsInstance(flattened[1], tensor.Tensor) repacked = nest.pack_sequence_as(tensors, flattened, expand_composites=True) self.assertLen(repacked, 2) @@ -208,7 +209,7 @@ def testFromToCompositeTensor(self): # TODO(wangpeng): Test in graph mode as well. Also test in V2 (the requirement # for setting _USE_EQUALITY points to V2 behavior not being on). ops.enable_eager_execution() - ops.Tensor._USE_EQUALITY = True + tensor.Tensor._USE_EQUALITY = True ops.set_dtype_conversion_mode('legacy') np_math_ops.enable_numpy_methods_on_tensor() test.main() diff --git a/tensorflow/python/ops/numpy_ops/np_math_ops.py b/tensorflow/python/ops/numpy_ops/np_math_ops.py index 701c8f1eb1a92d..8d8dae2a69fbab 100644 --- a/tensorflow/python/ops/numpy_ops/np_math_ops.py +++ b/tensorflow/python/ops/numpy_ops/np_math_ops.py @@ -24,6 +24,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops_stack from tensorflow.python.ops import bitwise_ops @@ -1465,30 +1466,30 @@ def _tensor_tolist(self): def enable_numpy_methods_on_tensor(): """Adds additional NumPy methods on tf.Tensor class.""" t = property(_tensor_t) - setattr(ops.Tensor, 'T', t) + setattr(tensor.Tensor, 'T', t) ndim = property(_tensor_ndim) - setattr(ops.Tensor, 'ndim', ndim) + setattr(tensor.Tensor, 'ndim', ndim) size = property(_tensor_size) - setattr(ops.Tensor, 'size', size) + setattr(tensor.Tensor, 'size', size) - setattr(ops.Tensor, '__pos__', _tensor_pos) - setattr(ops.Tensor, 'tolist', _tensor_tolist) + setattr(tensor.Tensor, '__pos__', _tensor_pos) + setattr(tensor.Tensor, 'tolist', _tensor_tolist) # TODO(b/178540516): Make a custom `setattr` that changes the method's # docstring to the TF one. - setattr(ops.Tensor, 'transpose', np_array_ops.transpose) - setattr(ops.Tensor, 'flatten', np_array_ops.flatten) - setattr(ops.Tensor, 'reshape', np_array_ops._reshape_method_wrapper) # pylint: disable=protected-access - setattr(ops.Tensor, 'ravel', np_array_ops.ravel) - setattr(ops.Tensor, 'clip', clip) - setattr(ops.Tensor, 'astype', math_ops.cast) - setattr(ops.Tensor, '__round__', np_array_ops.around) - setattr(ops.Tensor, 'max', np_array_ops.amax) - setattr(ops.Tensor, 'mean', np_array_ops.mean) - setattr(ops.Tensor, 'min', np_array_ops.amin) + setattr(tensor.Tensor, 'transpose', np_array_ops.transpose) + setattr(tensor.Tensor, 'flatten', np_array_ops.flatten) + setattr(tensor.Tensor, 'reshape', np_array_ops._reshape_method_wrapper) # pylint: disable=protected-access + setattr(tensor.Tensor, 'ravel', np_array_ops.ravel) + setattr(tensor.Tensor, 'clip', clip) + setattr(tensor.Tensor, 'astype', math_ops.cast) + setattr(tensor.Tensor, '__round__', np_array_ops.around) + setattr(tensor.Tensor, 'max', np_array_ops.amax) + setattr(tensor.Tensor, 'mean', np_array_ops.mean) + setattr(tensor.Tensor, 'min', np_array_ops.amin) # TODO(wangpeng): Remove `data` when all uses of it are removed data = property(lambda self: self) - setattr(ops.Tensor, 'data', data) + setattr(tensor.Tensor, 'data', data) diff --git a/tensorflow/python/ops/parallel_for/BUILD b/tensorflow/python/ops/parallel_for/BUILD index 5deb99e8a1af21..a86b5d93a0cf70 100644 --- a/tensorflow/python/ops/parallel_for/BUILD +++ b/tensorflow/python/ops/parallel_for/BUILD @@ -37,8 +37,8 @@ py_strict_library( "//tensorflow/python/framework:ops", "//tensorflow/python/framework:smart_cond", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", - "//tensorflow/python/framework:tensor_spec", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:array_ops_gen", @@ -95,6 +95,7 @@ py_strict_library( "//tensorflow/python/framework:indexed_slices", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/framework:type_spec", @@ -278,6 +279,7 @@ py_strict_library( deps = [ ":control_flow_ops", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:check_ops", "//tensorflow/python/ops:gradients_impl", diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops.py b/tensorflow/python/ops/parallel_for/control_flow_ops.py index d4f102e09221d6..e65e4fdd1c1a2c 100644 --- a/tensorflow/python/ops/parallel_for/control_flow_ops.py +++ b/tensorflow/python/ops/parallel_for/control_flow_ops.py @@ -25,6 +25,7 @@ from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.framework import type_spec @@ -314,7 +315,7 @@ def _pfor_impl(loop_fn, for loop_fn_output in nest.flatten(loop_fn_output_tensors): if (loop_fn_output is not None and not isinstance( loop_fn_output, - (ops.Operation, ops.Tensor, sparse_tensor.SparseTensor))): + (ops.Operation, tensor.Tensor, sparse_tensor.SparseTensor))): if isinstance(loop_fn_output, indexed_slices.IndexedSlices): logging.warn("Converting %s to a dense representation may make it slow." " Alternatively, output the indices and values of the" diff --git a/tensorflow/python/ops/parallel_for/gradients.py b/tensorflow/python/ops/parallel_for/gradients.py index 3ef8e0cc58ce9e..da667a5e1bbde5 100644 --- a/tensorflow/python/ops/parallel_for/gradients.py +++ b/tensorflow/python/ops/parallel_for/gradients.py @@ -14,6 +14,7 @@ # ============================================================================== """Jacobian ops.""" from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import gradients_impl as gradient_ops @@ -66,7 +67,7 @@ def loop_fn(i): parallel_iterations=parallel_iterations) for i, out in enumerate(pfor_outputs): - if isinstance(out, ops.Tensor): + if isinstance(out, tensor.Tensor): new_shape = array_ops.concat( [output_shape, array_ops.shape(out)[1:]], axis=0) out = array_ops.reshape(out, new_shape) diff --git a/tensorflow/python/ops/parallel_for/pfor.py b/tensorflow/python/ops/parallel_for/pfor.py index 472c196c02ed71..8905f4efc8c32a 100644 --- a/tensorflow/python/ops/parallel_for/pfor.py +++ b/tensorflow/python/ops/parallel_for/pfor.py @@ -34,8 +34,8 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import smart_cond from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops_stack @@ -229,7 +229,7 @@ def __init__(self, exit_node, pfor_ops, fallback_to_while_loop, pfor_config): self._pfor_config = pfor_config self._pfor_ops = set(pfor_ops) self._pfor_op_ids = set(x._id for x in pfor_ops) - assert isinstance(exit_node, ops.Tensor) + assert isinstance(exit_node, tensor_lib.Tensor) self._while_context = exit_node.op._get_control_flow_context() assert isinstance(self._while_context, control_flow_ops.WhileContext) self._context_name = self._while_context.name @@ -260,13 +260,13 @@ def __init__(self, exit_node, pfor_ops, fallback_to_while_loop, pfor_config): # to different Operations/Tensors of a single cycle as illustrated above. # List of Switch ops (ops.Operation) that feed into an Exit Node. self._exit_switches = [] - # List of inputs (ops.Tensor) to NextIteration. + # List of inputs (tensor_lib.Tensor) to NextIteration. self._body_outputs = [] # List of list of control inputs of the NextIteration nodes. self._next_iter_control_inputs = [] # List of Merge ops (ops.Operation). self._enter_merges = [] - # List of output (ops.Tensor) of Exit nodes. + # List of output (tensor_lib.Tensor) of Exit nodes. self._outputs = [] # List of Enter Tensors. @@ -1071,7 +1071,7 @@ def wrap(tensor, is_stacked=True, is_sparse_stacked=False): """Helper to create a WrappedTensor object.""" assert isinstance(is_stacked, bool) assert isinstance(is_sparse_stacked, bool) - assert isinstance(tensor, ops.Tensor) + assert isinstance(tensor, tensor_lib.Tensor) assert not is_sparse_stacked or is_stacked, ("If the wrapped tensor is " "stacked via a sparse " "conversion, it must also be " @@ -1116,7 +1116,7 @@ def while_body(i, *ta_list): # TODO(agarwal): Add tf.debugging asserts to check that the shapes across # the different iterations are the same. for out, ta in zip(op_outputs, ta_list): - assert isinstance(out, ops.Tensor) + assert isinstance(out, tensor_lib.Tensor) outputs.append(ta.write(i, out)) return tuple([i + 1] + outputs) @@ -1143,7 +1143,7 @@ def _has_reductions(self): def _set_iters(self, iters): """Set number of pfor iterations.""" - if isinstance(iters, ops.Tensor): + if isinstance(iters, tensor_lib.Tensor): iters = tensor_util.constant_value(iters) self._maybe_iters = iters @@ -1170,12 +1170,12 @@ def reduce(self, fn, *args): # Creates a concrete function that will be used for reduction. tensor_specs = [] for arg in args: - if not isinstance(arg, ops.Tensor): + if not isinstance(arg, tensor_lib.Tensor): raise ValueError(f"Got a non-Tensor argument {arg} in reduce.") batched_shape = tensor_shape.TensorShape([self._maybe_iters ]).concatenate(arg.shape) tensor_specs.append( - tensor_spec.TensorSpec(shape=batched_shape, dtype=arg.dtype)) + tensor_lib.TensorSpec(shape=batched_shape, dtype=arg.dtype)) concrete_function = def_function.function(fn).get_concrete_function( *tensor_specs) @@ -1184,7 +1184,7 @@ def reduce(self, fn, *args): pl_outputs = [] with ops.control_dependencies(args): for output in concrete_function.outputs: - if not isinstance(output, ops.Tensor): + if not isinstance(output, tensor_lib.Tensor): raise ValueError(f"Got a non-Tensor output {output} while running " "reduce.") # Note that we use placeholder_with_default just to make XLA happy since @@ -1249,7 +1249,7 @@ def reduce_sum(self, x): def _lookup_reduction(self, t): """Lookups Tensor `t` in the reduction maps.""" - assert isinstance(t, ops.Tensor), t + assert isinstance(t, tensor_lib.Tensor), t return self._reduce_map.get(t.op) @@ -1298,7 +1298,7 @@ def __init__(self, """Creates an object to rewrite a parallel-for loop. Args: - loop_var: ops.Tensor output of a Placeholder operation. The value should + loop_var: Tensor output of a Placeholder operation. The value should be an int32 scalar representing the loop iteration number. loop_len: A scalar or scalar Tensor representing the number of iterations the loop is run for. @@ -1316,7 +1316,7 @@ def __init__(self, pfor_config: PForConfig object used while constructing the loop body. warn: Whether or not to warn on while loop conversions. """ - assert isinstance(loop_var, ops.Tensor) + assert isinstance(loop_var, tensor_lib.Tensor) assert loop_var.op.type == "PlaceholderWithDefault" self._loop_var = loop_var loop_len_value = tensor_util.constant_value(loop_len) @@ -1425,7 +1425,7 @@ def convert(self, y): """Returns the converted value corresponding to y. Args: - y: A ops.Tensor or a ops.Operation object. If latter, y should not have + y: A Tensor or a ops.Operation object. If latter, y should not have any outputs. Returns: @@ -1436,10 +1436,10 @@ def convert(self, y): return None if isinstance(y, sparse_tensor.SparseTensor): return self._convert_sparse(y) - assert isinstance(y, (ops.Tensor, ops.Operation)), y + assert isinstance(y, (tensor_lib.Tensor, ops.Operation)), y output = self._convert_helper(y) if isinstance(output, WrappedTensor): - assert isinstance(y, ops.Tensor) + assert isinstance(y, tensor_lib.Tensor) return self._unwrap_or_tile(output) else: assert isinstance(y, ops.Operation) @@ -1453,7 +1453,8 @@ def _was_converted(self, t): return converted_t.t is not t def _add_conversion(self, old_output, new_output): - assert isinstance(old_output, (ops.Tensor, ops.Operation)), old_output + assert isinstance( + old_output, (tensor_lib.Tensor, ops.Operation)), old_output assert isinstance(new_output, (WrappedTensor, ops.Operation)), new_output self._conversion_map[old_output] = new_output @@ -1467,7 +1468,7 @@ def _convert_reduction(self, y): (reduction_fn, reduction_args) = reduction batched_args = [] for reduction_arg in reduction_args: - assert isinstance(reduction_arg, ops.Tensor), reduction_arg + assert isinstance(reduction_arg, tensor_lib.Tensor), reduction_arg # Tensor being reduced should already be converted due to a control # dependency on the created placeholder. # Note that in cases where reduction_arg is in an outer context, one @@ -1499,7 +1500,7 @@ def _convert_helper(self, op_or_tensor): "Got %s", y) y_op = y else: - assert isinstance(y, ops.Tensor), y + assert isinstance(y, tensor_lib.Tensor), y y_op = y.op is_while_loop = y_op.type == "Exit" @@ -1891,7 +1892,7 @@ def _channel_flatten_input(x, data_format): We then merge the S and C dimension. Args: - x: ops.Tensor to transform. + x: tensor_lib.Tensor to transform. data_format: "NCHW" or "NHWC". Returns: @@ -2588,7 +2589,7 @@ def _convert_gather(pfor_input): if param_stacked: pfor_input.stack_inputs(stack_indices=[1]) indices = pfor_input.stacked_input(1) - if isinstance(axis, ops.Tensor): + if isinstance(axis, tensor_lib.Tensor): axis = array_ops.where(axis >= 0, axis + 1, axis) else: axis = axis + 1 if axis >= 0 else axis diff --git a/tensorflow/python/ops/structured/BUILD b/tensorflow/python/ops/structured/BUILD index 0b98081990e0f8..708242c253f540 100644 --- a/tensorflow/python/ops/structured/BUILD +++ b/tensorflow/python/ops/structured/BUILD @@ -43,8 +43,8 @@ py_strict_library( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:extension_type", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", - "//tensorflow/python/framework:tensor_spec", "//tensorflow/python/framework:type_spec", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:check_ops", @@ -72,7 +72,7 @@ py_strict_library( srcs_version = "PY3", deps = [ ":structured_tensor", - "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops/ragged:dynamic_ragged_shape", @@ -91,6 +91,7 @@ py_strict_library( "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:math_ops", "//tensorflow/python/ops:random_ops", @@ -116,8 +117,8 @@ py_strict_test( "//tensorflow/python/framework:extension_type", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", - "//tensorflow/python/framework:tensor_spec", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops/ragged:dynamic_ragged_shape", @@ -163,8 +164,8 @@ py_strict_test( ":structured_tensor", "//tensorflow/python/eager:def_function", "//tensorflow/python/framework:dtypes", - "//tensorflow/python/framework:ops", "//tensorflow/python/framework:random_seed", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:array_ops", diff --git a/tensorflow/python/ops/structured/structured_array_ops.py b/tensorflow/python/ops/structured/structured_array_ops.py index 3bee40bc9e1d1d..9805418517399b 100644 --- a/tensorflow/python/ops/structured/structured_array_ops.py +++ b/tensorflow/python/ops/structured/structured_array_ops.py @@ -20,6 +20,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops @@ -515,7 +516,7 @@ def _structured_tensor_from_row_partitions(shape, row_partitions): # pylint: disable=protected_access def _all_nested_row_partitions(rt): """Returns all nested row partitions in rt, including for dense dimensions.""" - if isinstance(rt, ops.Tensor): + if isinstance(rt, tensor_lib.Tensor): if rt.shape.rank <= 1: return () else: @@ -529,7 +530,7 @@ def _all_nested_row_partitions(rt): def _structured_tensor_like(t): """Create a StructuredTensor with the shape of a (composite) tensor.""" - if isinstance(t, ops.Tensor): + if isinstance(t, tensor_lib.Tensor): return _structured_tensor_from_dense_tensor(t) if ragged_tensor.is_ragged(t): return StructuredTensor.from_fields( diff --git a/tensorflow/python/ops/structured/structured_array_ops_test.py b/tensorflow/python/ops/structured/structured_array_ops_test.py index 09421202488416..04f21fb28880c1 100644 --- a/tensorflow/python/ops/structured/structured_array_ops_test.py +++ b/tensorflow/python/ops/structured/structured_array_ops_test.py @@ -17,8 +17,8 @@ from tensorflow.python.eager import def_function from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops @@ -57,11 +57,11 @@ def assertAllEqual(self, a, b, msg=None): self.assertIsNone(e, (msg + ": " if msg else "") + str(e)) a_tensors = [ x for x in nest.flatten(a, expand_composites=True) - if isinstance(x, ops.Tensor) + if isinstance(x, tensor.Tensor) ] b_tensors = [ x for x in nest.flatten(b, expand_composites=True) - if isinstance(x, ops.Tensor) + if isinstance(x, tensor.Tensor) ] self.assertLen(a_tensors, len(b_tensors)) a_arrays, b_arrays = self.evaluate((a_tensors, b_tensors)) diff --git a/tensorflow/python/ops/structured/structured_tensor.py b/tensorflow/python/ops/structured/structured_tensor.py index 696589fcb39c9d..752d29895fe1a3 100644 --- a/tensorflow/python/ops/structured/structured_tensor.py +++ b/tensorflow/python/ops/structured/structured_tensor.py @@ -23,8 +23,8 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import extension_type from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import type_spec from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops @@ -39,8 +39,12 @@ from tensorflow.python.util.tf_export import tf_export # Each field may contain one of the following types of Tensors. -_FieldValue = Union[ops.Tensor, ragged_tensor.RaggedTensor, 'StructuredTensor', - extension_type.ExtensionType] +_FieldValue = Union[ + tensor.Tensor, + ragged_tensor.RaggedTensor, + 'StructuredTensor', + extension_type.ExtensionType +] # Function that takes a FieldValue as input and returns the transformed # FieldValue. _FieldFn = Callable[[_FieldValue], _FieldValue] @@ -134,7 +138,7 @@ def _old_init(cls, fields, shape, nrows, row_partitions, internal=False): """ assert isinstance(fields, dict), fields assert isinstance(shape, tensor_shape.TensorShape), shape - assert nrows is None or isinstance(nrows, ops.Tensor), nrows + assert nrows is None or isinstance(nrows, tensor.Tensor), nrows assert row_partitions is None or isinstance(row_partitions, tuple), row_partitions return StructuredTensor( @@ -786,7 +790,7 @@ def _tensor_getitem(self, key): if not (k.start is None and k.stop is None and k.step is None): # TODO(edloper): Better static shape analysis here. result_shape[d] = None - elif isinstance(k, (int, ops.Tensor)): + elif isinstance(k, (int, tensor.Tensor)): result_shape[d] = -1 # mark for deletion elif k is None: raise ValueError('Slicing not supported for tf.newaxis') @@ -1008,7 +1012,7 @@ def _from_pylist_of_value(cls, pyval, typespec, path_so_far): return ragged_factory_ops.constant(pyval) except Exception as exc: raise ValueError('Error parsing path %r' % (path_so_far,)) from exc - elif isinstance(typespec, tensor_spec.TensorSpec): + elif isinstance(typespec, tensor.TensorSpec): try: result = constant_op.constant(pyval, typespec.dtype) except Exception as exc: @@ -1049,7 +1053,7 @@ def _from_pyscalar(cls, pyval, typespec, path_so_far): except Exception as exc: raise ValueError('Error parsing path %r' % (path_so_far,)) from exc else: - if not (isinstance(typespec, tensor_spec.TensorSpec) and + if not (isinstance(typespec, tensor.TensorSpec) and typespec.shape.rank == 0): raise ValueError('Value at %r does not match typespec: %r vs %r' % (path_so_far, typespec, pyval)) @@ -1200,7 +1204,7 @@ def rank(self): def _convert_to_structured_field_value(value): """Converts `value` to a Tensor, RaggedTensor, or StructuredTensor.""" if isinstance(value, - (ops.Tensor, ragged_tensor.RaggedTensor, StructuredTensor)): + (tensor.Tensor, ragged_tensor.RaggedTensor, StructuredTensor)): return value elif ragged_tensor.is_ragged(value): return ragged_tensor.convert_to_tensor_or_ragged_tensor(value) @@ -1215,7 +1219,7 @@ def _convert_to_structured_field_value(value): def _find_shape_dtype( - fields: Mapping[str, _FieldValue], nrows: Optional[ops.Tensor], + fields: Mapping[str, _FieldValue], nrows: Optional[tensor.Tensor], row_partitions: Optional[Sequence[RowPartition]]) -> dtypes.DType: """Return a consistent dtype for fields, nrows, & row_partitions. @@ -1232,7 +1236,7 @@ def _find_shape_dtype( If int32 is explicitly specified, return int32. Otherwise, return int64. """ field_dtypes = [_field_shape_dtype(v) for v in fields.values()] - nrows_dtypes = [nrows.dtype] if isinstance(nrows, ops.Tensor) else [] + nrows_dtypes = [nrows.dtype] if isinstance(nrows, tensor.Tensor) else [] rp_dtypes = [] if row_partitions is None else [ rp.dtype for rp in row_partitions ] @@ -1266,7 +1270,7 @@ def _merge_nrows(nrows, static_nrows, value, dtype, validate): A tuple `(nrows, static_nrows)`. """ static_value_nrows = tensor_shape.dimension_at_index(value.shape, 0) - if isinstance(value, ops.Tensor): + if isinstance(value, tensor.Tensor): value_nrows = array_ops.shape(value, out_type=dtype)[0] else: value_nrows = value.nrows() @@ -1287,7 +1291,7 @@ def _merge_nrows(nrows, static_nrows, value, dtype, validate): def _merge_row_partitions(row_partitions, value, rank, dtype, validate): """Merges `row_partitions` with `row_partitions(value)`.""" - if isinstance(value, ops.Tensor): + if isinstance(value, tensor.Tensor): value_row_partitions = _row_partitions_for_tensor(value, rank, dtype) elif isinstance(value, ragged_tensor.RaggedTensor): @@ -1486,7 +1490,7 @@ def _replace_row_partitions(value, new_partitions): A value that is equivalent to `value`, where outer row partitions have been replaced by `new_partitions`. """ - if isinstance(value, ops.Tensor) or not new_partitions: + if isinstance(value, tensor.Tensor) or not new_partitions: return value elif isinstance(value, ragged_tensor.RaggedTensor): @@ -1532,14 +1536,14 @@ def _partition_outer_dimension(value, row_partition): `result.rank = value.rank + 1`. """ is_ragged = row_partition.uniform_row_length() is None - if isinstance(value, ops.Tensor) and not is_ragged: + if isinstance(value, tensor.Tensor) and not is_ragged: new_shape = array_ops.concat( [[row_partition.nrows(), row_partition.uniform_row_length()], array_ops.shape(value, out_type=row_partition.dtype)[1:]], axis=0) return array_ops.reshape(value, new_shape) - elif isinstance(value, (ops.Tensor, ragged_tensor.RaggedTensor)): + elif isinstance(value, (tensor.Tensor, ragged_tensor.RaggedTensor)): return ragged_tensor.RaggedTensor._from_row_partition( # pylint: disable=protected-access value, row_partition) else: @@ -1558,7 +1562,7 @@ def _partition_outer_dimension(value, row_partition): def _merge_dims(value, outer_axis, inner_axis): """Merges `outer_axis...inner_axis` of `value` into a single dimension.""" assert outer_axis < inner_axis - if isinstance(value, (ops.Tensor, ragged_tensor.RaggedTensor)): + if isinstance(value, (tensor.Tensor, ragged_tensor.RaggedTensor)): return ragged_tensor.merge_dims(value, outer_axis, inner_axis) else: assert isinstance(value, StructuredTensor) @@ -1575,7 +1579,7 @@ def _merge_dims(value, outer_axis, inner_axis): def _dynamic_ragged_shape_spec_from_spec( spec: Union[dynamic_ragged_shape.DynamicRaggedShape.Spec, ragged_tensor.RaggedTensorSpec, StructuredTensor.Spec, - tensor_spec.TensorSpec] + tensor.TensorSpec] ) -> dynamic_ragged_shape.DynamicRaggedShape.Spec: if isinstance(spec, StructuredTensor.Spec): return spec._ragged_shape # pylint: disable=protected-access @@ -1630,7 +1634,7 @@ def _dynamic_ragged_shape_from_tensor( return field._ragged_shape # pylint: disable=protected-access shape = array_ops.shape_v2(field, out_type=dtype) - if isinstance(shape, ops.Tensor): + if isinstance(shape, tensor.Tensor): return dynamic_ragged_shape.DynamicRaggedShape( row_partitions=[], inner_shape=shape) elif isinstance(shape, dynamic_ragged_shape.DynamicRaggedShape): @@ -1697,7 +1701,7 @@ def _dynamic_ragged_shape_init(fields, shape, nrows, row_partitions): """Produce a DynamicRaggedShape for StructuredTensor.""" assert isinstance(fields, dict), fields assert isinstance(shape, tensor_shape.TensorShape), shape - assert nrows is None or isinstance(nrows, ops.Tensor) or isinstance( + assert nrows is None or isinstance(nrows, tensor.Tensor) or isinstance( nrows, int), nrows assert row_partitions is None or isinstance(row_partitions, tuple), row_partitions diff --git a/tensorflow/python/ops/structured/structured_tensor_dynamic.py b/tensorflow/python/ops/structured/structured_tensor_dynamic.py index 84944861af1830..8aa434831e616a 100644 --- a/tensorflow/python/ops/structured/structured_tensor_dynamic.py +++ b/tensorflow/python/ops/structured/structured_tensor_dynamic.py @@ -14,7 +14,7 @@ # ============================================================================== """Dynamic shape for structured Tensors.""" -from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops.ragged import dynamic_ragged_shape @@ -26,7 +26,7 @@ def _dynamic_ragged_shape_init(fields, shape, nrows, row_partitions): """Produce a DynamicRaggedShape for StructuredTensor.""" assert isinstance(fields, dict), fields assert isinstance(shape, tensor_shape.TensorShape), shape - assert nrows is None or isinstance(nrows, ops.Tensor), nrows + assert nrows is None or isinstance(nrows, tensor.Tensor), nrows assert isinstance(row_partitions, tuple), row_partitions rank = shape.rank diff --git a/tensorflow/python/ops/structured/structured_tensor_test.py b/tensorflow/python/ops/structured/structured_tensor_test.py index 3e84e21ae6f018..b182c530b1a393 100644 --- a/tensorflow/python/ops/structured/structured_tensor_test.py +++ b/tensorflow/python/ops/structured/structured_tensor_test.py @@ -25,8 +25,8 @@ from tensorflow.python.framework import extension_type from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops.ragged import ragged_factory_ops @@ -611,7 +611,7 @@ def testFromFields(self, for field, value in fields.items(): self.assertIsInstance( struct.field_value(field), - (ops.Tensor, structured_tensor.StructuredTensor, + (tensor.Tensor, structured_tensor.StructuredTensor, ragged_tensor.RaggedTensor)) self.assertAllEqual(struct.field_value(field), value) @@ -791,7 +791,7 @@ def testPartitionOuterDims(self): dtype=dtypes.int64), _fields={ "x": - tensor_spec.TensorSpec([2, 2], dtypes.int32), + tensor.TensorSpec([2, 2], dtypes.int32), "y": ragged_tensor.RaggedTensorSpec([2, 2, None], dtypes.int32) @@ -855,8 +855,8 @@ def testPartitionOuterDimsErrors(self): "pyval": {"a": 12, "b": [1, 2, 3], "c": [[1, 2], [3]]}, "type_spec": StructuredTensor.Spec._from_fields_and_rank( fields={ - "a": tensor_spec.TensorSpec([], dtypes.int32), - "b": tensor_spec.TensorSpec([None], dtypes.int32), + "a": tensor.TensorSpec([], dtypes.int32), + "b": tensor.TensorSpec([None], dtypes.int32), "c": ragged_tensor.RaggedTensorSpec([None, None], dtypes.int32)}, rank=0), @@ -889,7 +889,7 @@ def testPartitionOuterDimsErrors(self): "testcase_name": "EmptyListWithTypeSpecAndFields", "pyval": [], "type_spec": structured_tensor.StructuredTensor.Spec._from_fields_and_rank( - fields={"a": tensor_spec.TensorSpec([0], dtypes.int32)}, + fields={"a": tensor.TensorSpec([0], dtypes.int32)}, rank=1), "expected": lambda: StructuredTensor.from_fields(shape=[0], fields={ "a": []}) @@ -963,7 +963,7 @@ def testPartitionOuterDimsErrors(self): "pyval": [[{"a": 1}, {"a": 2}, {"a": 3},], [{"a": 4}, {"a": 5}, {"a": 6}]], "type_spec": structured_tensor.StructuredTensorSpec([2, 3], { - "a": tensor_spec.TensorSpec(None, dtypes.int32)}), + "a": tensor.TensorSpec(None, dtypes.int32)}), "expected": lambda: StructuredTensor.from_fields( shape=[2, 3], fields={"a": [[1, 2, 3], [4, 5, 6]]}) }, @@ -979,8 +979,8 @@ def testPyvalConversion(self, pyval, expected, type_spec=None): def testStructuredTensorSpecFactory(self): spec = StructuredTensor.Spec._from_fields_and_rank( fields={ - "a": tensor_spec.TensorSpec([], dtypes.int32), - "b": tensor_spec.TensorSpec([None], dtypes.int32), + "a": tensor.TensorSpec([], dtypes.int32), + "b": tensor.TensorSpec([None], dtypes.int32), "c": ragged_tensor.RaggedTensorSpec([None, None], dtypes.int32) }, rank=0) @@ -1042,19 +1042,19 @@ def testToPyval(self, st, expected): dict(testcase_name="TypeSpecMismatch_DictKey", pyval={"a": 1}, type_spec=StructuredTensor.Spec._from_fields_and_rank( - fields={"b": tensor_spec.TensorSpec([1], dtypes.int32)}, + fields={"b": tensor.TensorSpec([1], dtypes.int32)}, rank=1), msg=r"Value at \(\) does not match typespec"), dict(testcase_name="TypeSpecMismatch_ListDictKey", pyval=[{"a": 1}], type_spec=StructuredTensor.Spec._from_fields_and_rank( - fields={"b": tensor_spec.TensorSpec([1], dtypes.int32)}, + fields={"b": tensor.TensorSpec([1], dtypes.int32)}, rank=1), msg=r"Value at \(\) does not match typespec"), dict(testcase_name="TypeSpecMismatch_RankMismatch", pyval=[{"a": 1}], type_spec=StructuredTensor.Spec._from_fields_and_rank( - fields={"a": tensor_spec.TensorSpec([], dtypes.int32)}, + fields={"a": tensor.TensorSpec([], dtypes.int32)}, rank=0), msg=r"Value at \(\) does not match typespec \(rank mismatch\)"), dict(testcase_name="TypeSpecMismatch_Scalar", @@ -1068,14 +1068,14 @@ def testToPyval(self, st, expected): dict(testcase_name="TypeSpecMismatch_ListTensor", pyval={"a": [[1]]}, type_spec=StructuredTensor.Spec._from_fields_and_rank( - fields={"a": tensor_spec.TensorSpec([], dtypes.int32)}, + fields={"a": tensor.TensorSpec([], dtypes.int32)}, rank=0), msg=r"Value at \('a',\) does not match typespec"), dict(testcase_name="TypeSpecMismatch_ListTensorDeep", pyval={"a": {"b": [[1]]}}, type_spec=StructuredTensor.Spec._from_fields_and_rank( fields={"a": StructuredTensor.Spec._from_fields_and_rank( - fields={"b": tensor_spec.TensorSpec([], dtypes.int32)}, + fields={"b": tensor.TensorSpec([], dtypes.int32)}, rank=0 )}, rank=0), @@ -1095,7 +1095,7 @@ def testToPyval(self, st, expected): dict(testcase_name="TypeSpecMismatch_ListStruct", pyval=[[1]], type_spec=StructuredTensor.Spec._from_fields_and_rank( - fields={"a": tensor_spec.TensorSpec([1, 1], dtypes.int32)}, + fields={"a": tensor.TensorSpec([1, 1], dtypes.int32)}, rank=2), msg=r"Value at \(\) does not match typespec"), dict(testcase_name="InconsistentDictionaryDepth", From a622eb3a20bff7dbc96e85004109128e1598c85e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 10 Jul 2023 17:30:22 -0700 Subject: [PATCH 104/376] Implement pattern conversion for MHLO::Dot hybrid UQ type to int type. PiperOrigin-RevId: 547028867 --- .../bridge/convert_mhlo_quant_to_int.cc | 140 +++++++++++++++--- .../convert-mhlo-quant-to-int-no-chlo.mlir | 10 ++ .../bridge/convert-mhlo-quant-to-int.mlir | 70 ++++++--- 3 files changed, 174 insertions(+), 46 deletions(-) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_mhlo_quant_to_int.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_mhlo_quant_to_int.cc index 0424ad97cc61ca..43a5d568132894 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_mhlo_quant_to_int.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_mhlo_quant_to_int.cc @@ -14,8 +14,11 @@ limitations under the License. ==============================================================================*/ #include +#include #include #include +#include +#include #include #include "llvm/ADT/STLExtras.h" @@ -435,6 +438,73 @@ class ConvertUniformQuantizedAddOp : public OpConversionPattern { } }; +// A shared matchAndRewrite implementation for dot-like hybrid quantized +// operators. Hybrid ops are currently only interpreted as weight-only +// quantization ops, this might change in the future. +// +// All attrs of the original op are preserved after the conversion. +template +LogicalResult matchAndRewriteDotLikeHybridOp( + OpType &op, OpAdaptorType &adaptor, ConversionPatternRewriter &rewriter, + const quant::UniformQuantizedType &rhs_element_type) { + auto res_float32_tensor_type_or = GetSameShapeTensorType( + op, op.getResult().getType().template cast(), + rewriter.getF32Type(), rewriter); + if (failed(res_float32_tensor_type_or)) { + return failure(); + } + + Value lhs_float32_tensor = adaptor.getLhs(); + Value rhs = adaptor.getRhs(); + + // For dot like hybrid ops, lhs is float type, rhs is uniform + // quantized type and result is float type. + // For weight-only quantization: + // result = hybridOp(lhs, dequant(rhs)) + // + // Get scales and zero points for rhs. + Value rhs_zero_point = rewriter.create( + op->getLoc(), + rewriter.getF32FloatAttr((rhs_element_type.getZeroPoint()))); + Value rhs_scale_constant = rewriter.create( + op->getLoc(), rewriter.getF32FloatAttr( + static_cast(rhs_element_type.getScale()))); + + // Dequantize rhs_float32_tensor. + Value rhs_float32_tensor = rewriter.create( + op->getLoc(), *res_float32_tensor_type_or, rhs); + rhs_float32_tensor = rewriter.create( + op->getLoc(), *res_float32_tensor_type_or, rhs_float32_tensor, + rhs_zero_point, nullptr); + rhs_float32_tensor = rewriter.create( + op->getLoc(), *res_float32_tensor_type_or, rhs_float32_tensor, + rhs_scale_constant, nullptr); + + // Execute conversion target op. + SmallVector operands{lhs_float32_tensor, rhs_float32_tensor}; + Value res_float32 = rewriter.create( + op->getLoc(), *res_float32_tensor_type_or, operands, op->getAttrs()); + + Value half = rewriter.create( + op->getLoc(), rewriter.getF32FloatAttr(0.5f)); + res_float32 = rewriter.create( + op->getLoc(), *res_float32_tensor_type_or, res_float32, half, nullptr); + res_float32 = rewriter.create(op->getLoc(), res_float32); + + // Cast res_float_tensor_type to res_int_tensor_type. + auto res_int32_tensor_type_or = GetSameShapeTensorType( + op, op.getResult().getType().template cast(), + rewriter.getI32Type(), rewriter); + if (failed(res_int32_tensor_type_or)) { + return failure(); + } + + rewriter.replaceOpWithNewOp(op, *res_int32_tensor_type_or, + res_float32); + + return success(); +} + // A shared matchAndRewrite implementation for dot-like quantized operators. // // Dot-like operators refer to operators that generate a tensor where each @@ -446,24 +516,42 @@ class ConvertUniformQuantizedAddOp : public OpConversionPattern { template LogicalResult matchAndRewriteDotLikeOp(OpType &op, OpAdaptorType &adaptor, ConversionPatternRewriter &rewriter) { - auto lhs_element_type = op.getLhs() - .getType() - .getElementType() - .template dyn_cast(); - auto rhs_element_type = op.getRhs() - .getType() - .getElementType() - .template dyn_cast(); - auto res_element_type = op.getResult() - .getType() - .getElementType() - .template dyn_cast(); - - // Check if the operands and result are UniformQuantizedTypes. - if (!lhs_element_type || !rhs_element_type || !res_element_type) { + auto lhs_element_type = getElementTypeOrSelf(op.getLhs().getType()); + auto rhs_element_quant_type = + op.getRhs() + .getType() + .getElementType() + .template dyn_cast(); + auto res_element_type = getElementTypeOrSelf(op.getResult()); + + // Check if the right operand is UniformQuantizedTypes. + if (!rhs_element_quant_type) { return rewriter.notifyMatchFailure( op, "Legalization failed: supports only per-tensor quantization."); } + + if (lhs_element_type.template isa()) { + // If lhs is uniform quantized type, result should also be uniform + // quantized type, representing none-hybrid op. + if (!res_element_type.template isa()) { + op->emitError("Unsupported result element type for " + + op->getName().getStringRef().str()); + return failure(); + } + } else if (lhs_element_type.isF32()) { + // If lhs is float32 type, result should also be float32 type, + // representing hybrid op. + if (!res_element_type.isF32()) { + op->emitError("Unsupported result element type for " + + op->getName().getStringRef().str()); + return failure(); + } + return matchAndRewriteDotLikeHybridOp(op, adaptor, rewriter, + rhs_element_quant_type); + } else { + return rewriter.notifyMatchFailure(op, "Unsupported input element type."); + } + auto res_float32_tensor_type_or = GetSameShapeTensorType( op, op.getResult().getType().template cast(), rewriter.getF32Type(), rewriter); @@ -471,6 +559,10 @@ LogicalResult matchAndRewriteDotLikeOp(OpType &op, OpAdaptorType &adaptor, return failure(); } + auto lhs_element_quant_type = + lhs_element_type.template dyn_cast(); + auto res_element_quant_type = + res_element_type.template dyn_cast(); Value lhs = adaptor.getLhs(); Value rhs = adaptor.getRhs(); @@ -481,10 +573,10 @@ LogicalResult matchAndRewriteDotLikeOp(OpType &op, OpAdaptorType &adaptor, // Get scales and zero points for both operands. Value lhs_zero_point = rewriter.create( op->getLoc(), - rewriter.getF32FloatAttr((lhs_element_type.getZeroPoint()))); + rewriter.getF32FloatAttr((lhs_element_quant_type.getZeroPoint()))); Value rhs_zero_point = rewriter.create( op->getLoc(), - rewriter.getF32FloatAttr((rhs_element_type.getZeroPoint()))); + rewriter.getF32FloatAttr((rhs_element_quant_type.getZeroPoint()))); // Offset xxx_int32_tensor according to zero points. Value lhs_float32_tensor = rewriter.create( @@ -507,10 +599,10 @@ LogicalResult matchAndRewriteDotLikeOp(OpType &op, OpAdaptorType &adaptor, // scales. Value result_zero_point = rewriter.create( op->getLoc(), - rewriter.getF32FloatAttr((res_element_type.getZeroPoint()))); - const double effective_scale = lhs_element_type.getScale() * - rhs_element_type.getScale() / - res_element_type.getScale(); + rewriter.getF32FloatAttr((res_element_quant_type.getZeroPoint()))); + const double effective_scale = lhs_element_quant_type.getScale() * + rhs_element_quant_type.getScale() / + res_element_quant_type.getScale(); Value effective_scale_constant = rewriter.create( op->getLoc(), rewriter.getF32FloatAttr(static_cast(effective_scale))); @@ -543,10 +635,10 @@ LogicalResult matchAndRewriteDotLikeOp(OpType &op, OpAdaptorType &adaptor, // Clamp results by [quantization_min, quantization_max]. Value result_quantization_min = rewriter.create( op->getLoc(), rewriter.getI32IntegerAttr(static_cast( - res_element_type.getStorageTypeMin()))); + res_element_quant_type.getStorageTypeMin()))); Value result_quantization_max = rewriter.create( op->getLoc(), rewriter.getI32IntegerAttr(static_cast( - res_element_type.getStorageTypeMax()))); + res_element_quant_type.getStorageTypeMax()))); res_int32 = rewriter.create( op->getLoc(), *res_int32_tensor_type_or, result_quantization_min, res_int32, result_quantization_max); @@ -554,7 +646,7 @@ LogicalResult matchAndRewriteDotLikeOp(OpType &op, OpAdaptorType &adaptor, // Convert results back to int8. auto res_final_tensor_type_or = GetSameShapeTensorType( op, res_int32_tensor_type_or->template cast(), - res_element_type.getStorageType(), rewriter); + res_element_quant_type.getStorageType(), rewriter); rewriter.replaceOpWithNewOp(op, *res_final_tensor_type_or, res_int32); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-mhlo-quant-to-int-no-chlo.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-mhlo-quant-to-int-no-chlo.mlir index df3886645a3112..95a247ffdc19b8 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-mhlo-quant-to-int-no-chlo.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-mhlo-quant-to-int-no-chlo.mlir @@ -125,3 +125,13 @@ func.func @uniform_quantized_convolution(%arg0: tensor, %arg1: tens -> tensor> return } + +// ----- + +// CHECK-LABEL: func @uniform_quantize_dot_hybrid +func.func @uniform_quantize_dot_hybrid(%arg0: tensor, %arg1: tensor) { + // CHECK-NOT: chlo + %0 = mhlo.uniform_quantize %arg1 : (tensor) -> tensor> + %1 = "mhlo.dot" (%arg0, %0): (tensor, tensor>) -> tensor + return +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-mhlo-quant-to-int.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-mhlo-quant-to-int.mlir index 186a576c3731c3..4766956e56b00f 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-mhlo-quant-to-int.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-mhlo-quant-to-int.mlir @@ -84,13 +84,13 @@ func.func @uniform_quantize_and_dequantize_sparse_tensor_encoding(%arg0: tensor< // CHECK-LABEL: func @uniform_quantize_add func.func @uniform_quantize_add(%arg0: tensor, %arg1: tensor) -> () { - // CHECK-DAG: %[[VAL1:.*]] = mhlo.convert %[[VAL0:.*]] : (tensor) -> tensor - // CHECK-DAG: %[[VAL3:.*]] = mhlo.convert %[[VAL2:.*]] : (tensor) -> tensor + // CHECK: %[[VAL1:.*]] = mhlo.convert %[[VAL0:.*]] : (tensor) -> tensor + // CHECK: %[[VAL3:.*]] = mhlo.convert %[[VAL2:.*]] : (tensor) -> tensor // CHECK-DAG: %[[VAL5:.*]] = mhlo.constant dense<3> : tensor - // CHECK-DAG: %[[VAL4:.*]] = chlo.broadcast_add %[[VAL1:.*]], %[[VAL3:.*]] : (tensor, tensor) -> tensor - // CHECK: %[[VAL6:.*]] = chlo.broadcast_subtract %[[VAL4:.*]], %[[VAL5:.*]] : (tensor, tensor) -> tensor - // CHECK: %[[VAL9:.*]] = mhlo.clamp %[[VAL7:.*]], %[[VAL6:.*]], %[[VAL8:.*]] : (tensor, tensor, tensor) -> tensor - // CHECK: %[[VAL10:.*]] = mhlo.convert %[[VAL9:.*]] : (tensor) -> tensor + // CHECK: %[[VAL4:.*]] = chlo.broadcast_add %[[VAL1]], %[[VAL3]] : (tensor, tensor) -> tensor + // CHECK: %[[VAL6:.*]] = chlo.broadcast_subtract %[[VAL4]], %[[VAL5]] : (tensor, tensor) -> tensor + // CHECK: %[[VAL9:.*]] = mhlo.clamp %[[VAL7:.*]], %[[VAL6]], %[[VAL8:.*]] : (tensor, tensor, tensor) -> tensor + // CHECK: %[[VAL10:.*]] = mhlo.convert %[[VAL9]] : (tensor) -> tensor %0 = mhlo.uniform_quantize %arg0 : (tensor) -> tensor> %1 = mhlo.uniform_quantize %arg1 : (tensor) -> tensor> %2 = mhlo.add %0, %1: (tensor>, tensor>) -> tensor> @@ -101,13 +101,13 @@ func.func @uniform_quantize_add(%arg0: tensor, %arg1: tensor) // CHECK-LABEL: func @uniform_quantize_add_int4 func.func @uniform_quantize_add_int4(%arg0: tensor, %arg1: tensor) -> () { - // CHECK-DAG: %[[VAL1:.*]] = mhlo.convert %[[VAL0:.*]] : (tensor) -> tensor - // CHECK-DAG: %[[VAL3:.*]] = mhlo.convert %[[VAL2:.*]] : (tensor) -> tensor + // CHECK: %[[VAL1:.*]] = mhlo.convert %[[VAL0:.*]] : (tensor) -> tensor + // CHECK: %[[VAL3:.*]] = mhlo.convert %[[VAL2:.*]] : (tensor) -> tensor // CHECK-DAG: %[[VAL5:.*]] = mhlo.constant dense<3> : tensor - // CHECK-DAG: %[[VAL4:.*]] = chlo.broadcast_add %[[VAL1:.*]], %[[VAL3:.*]] : (tensor, tensor) -> tensor - // CHECK: %[[VAL6:.*]] = chlo.broadcast_subtract %[[VAL4:.*]], %[[VAL5:.*]] : (tensor, tensor) -> tensor - // CHECK: %[[VAL9:.*]] = mhlo.clamp %[[VAL7:.*]], %[[VAL6:.*]], %[[VAL8:.*]] : (tensor, tensor, tensor) -> tensor - // CHECK: %[[VAL10:.*]] = mhlo.convert %[[VAL9:.*]] : (tensor) -> tensor + // CHECK: %[[VAL4:.*]] = chlo.broadcast_add %[[VAL1]], %[[VAL3]] : (tensor, tensor) -> tensor + // CHECK: %[[VAL6:.*]] = chlo.broadcast_subtract %[[VAL4]], %[[VAL5]] : (tensor, tensor) -> tensor + // CHECK: %[[VAL9:.*]] = mhlo.clamp %[[VAL7:.*]], %[[VAL6]], %[[VAL8:.*]] : (tensor, tensor, tensor) -> tensor + // CHECK: %[[VAL10:.*]] = mhlo.convert %[[VAL9]] : (tensor) -> tensor %0 = mhlo.uniform_quantize %arg0 : (tensor) -> tensor> %1 = mhlo.uniform_quantize %arg1 : (tensor) -> tensor> %2 = mhlo.add %0, %1: (tensor>, tensor>) -> tensor> @@ -243,17 +243,17 @@ func.func @uniform_quantize_requantize_and_dequantize(%arg0: tensor) -> // CHECK-LABEL: func @uniform_quantize_dot_dequantize func.func @uniform_quantize_dot_dequantize(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK: %[[VAL1:.*]] = mhlo.convert %[[VAL0:.*]] : (tensor) -> tensor - // CHECK: %[[VAL3:.*]] = chlo.broadcast_subtract %[[VAL1:.*]], %[[VAL2:.*]] : (tensor, tensor) -> tensor + // CHECK: %[[VAL3:.*]] = chlo.broadcast_subtract %[[VAL1]], %[[VAL2:.*]] : (tensor, tensor) -> tensor // CHECK: %[[VAL5:.*]] = mhlo.convert %[[VAL4:.*]] : (tensor) -> tensor - // CHECK: %[[VAL7:.*]] = chlo.broadcast_subtract %[[VAL5:.*]], %[[VAL6:.*]] : (tensor, tensor) -> tensor - // CHECK: %[[VAL8:.*]] = "mhlo.dot"(%[[VAL3:.*]], %[[VAL7:.*]]) : (tensor, tensor) -> tensor - // CHECK: %[[VAL10:.*]] = chlo.broadcast_multiply %[[VAL8:.*]], %[[VAL9:.*]] : (tensor, tensor) -> tensor - // CHECK: %[[VAL12:.*]] = chlo.broadcast_add %[[VAL10:.*]], %[[VAL11:.*]] : (tensor, tensor) -> tensor - // CHECK: %[[VAL13:.*]] = mhlo.floor %[[VAL12:.*]] : tensor - // CHECK: %[[VAL15:.*]] = chlo.broadcast_add %[[VAL13:.*]], %[[VAL14:.*]] : (tensor, tensor) -> tensor - // CHECK: %[[VAL16:.*]] = mhlo.convert %[[VAL15:.*]] : (tensor) -> tensor - // CHECK: %[[VAL19:.*]] = mhlo.clamp %[[VAL17:.*]], %[[VAL16:.*]], %[[VAL18:.*]] : (tensor, tensor, tensor) -> tensor - // CHECK: %[[VAL20:.*]] = mhlo.convert %[[VAL19:.*]] : (tensor) -> tensor + // CHECK: %[[VAL7:.*]] = chlo.broadcast_subtract %[[VAL5]], %[[VAL6:.*]] : (tensor, tensor) -> tensor + // CHECK: %[[VAL8:.*]] = "mhlo.dot"(%[[VAL3]], %[[VAL7]]) : (tensor, tensor) -> tensor + // CHECK: %[[VAL10:.*]] = chlo.broadcast_multiply %[[VAL8]], %[[VAL9:.*]] : (tensor, tensor) -> tensor + // CHECK: %[[VAL12:.*]] = chlo.broadcast_add %[[VAL10]], %[[VAL11:.*]] : (tensor, tensor) -> tensor + // CHECK: %[[VAL13:.*]] = mhlo.floor %[[VAL12]] : tensor + // CHECK: %[[VAL15:.*]] = chlo.broadcast_add %[[VAL13]], %[[VAL14:.*]] : (tensor, tensor) -> tensor + // CHECK: %[[VAL16:.*]] = mhlo.convert %[[VAL15]] : (tensor) -> tensor + // CHECK: %[[VAL19:.*]] = mhlo.clamp %[[VAL17:.*]], %[[VAL16]], %[[VAL18:.*]] : (tensor, tensor, tensor) -> tensor + // CHECK: %[[VAL20:.*]] = mhlo.convert %[[VAL19]] : (tensor) -> tensor %0 = mhlo.uniform_quantize %arg0 : (tensor) -> tensor> %1 = mhlo.uniform_quantize %arg1 : (tensor) -> tensor> %2 = "mhlo.dot" (%0, %1) : (tensor>, tensor>) -> tensor> @@ -308,3 +308,29 @@ func.func @uniform_quantized_convolution(%arg0: tensor, %arg1: tens -> tensor> return } + +// ----- + +// CHECK-LABEL: func @uniform_quantize_dot_hybrid +func.func @uniform_quantize_dot_hybrid(%arg0: tensor, %arg1: tensor) { + // CHECK: %[[VAL1:.*]] = mhlo.convert %[[VAL0:.*]] : (tensor) -> tensor + // CHECK: %[[VAL3:.*]] = chlo.broadcast_subtract %[[VAL1]], %[[VAL2:.*]] : (tensor, tensor) -> tensor + // CHECK: %[[VAL5:.*]] = chlo.broadcast_multiply %[[VAL3]], %[[VAL4:.*]] : (tensor, tensor) -> tensor + // CHECK: %[[VAL7:.*]] = "mhlo.dot"(%[[VAL6:.*]], %[[VAL5]]) : (tensor, tensor) -> tensor + // CHECK: %[[VAL9:.*]] = chlo.broadcast_add %[[VAL7]], %[[VAL8:.*]] : (tensor, tensor) -> tensor + // CHECK: %[[VAL10:.*]] = mhlo.floor %[[VAL9]] : tensor + // CHECK: %[[VAL11:.*]] = mhlo.convert %[[VAL10]] : (tensor) -> tensor + %0 = mhlo.uniform_quantize %arg1 : (tensor) -> tensor> + %1 = "mhlo.dot" (%arg0, %0): (tensor, tensor>) -> tensor + return +} + +// ----- + +func.func @uniform_quantize_dot_hybrid_result_type_not_float(%arg0: tensor, %arg1: tensor) { + %0 = mhlo.uniform_quantize %arg1 : (tensor) -> tensor> + // expected-error@+2 {{Unsupported result element type for mhlo.dot}} + // expected-error@+1 {{failed to legalize operation 'mhlo.dot' that was explicitly marked illegal}} + %1 = "mhlo.dot" (%arg0, %0): (tensor, tensor>) -> tensor> + return +} From 36fe2f95cd48bfe9a781f08adf8bf9129b8403c3 Mon Sep 17 00:00:00 2001 From: Arian Arfaian Date: Mon, 10 Jul 2023 17:53:20 -0700 Subject: [PATCH 105/376] Automatically detect whether the input model contains StableHLO. Removes the need for special flags to handle models generated from JAX with `native_serialization=True`. Also removes the `_experimental_enable_hlo_to_tf_conversion` option from the Converter API. PiperOrigin-RevId: 547033526 --- tensorflow/compiler/mlir/lite/python/BUILD | 3 +-- .../mlir/lite/python/graphdef_to_tfl_flatbuffer.cc | 2 -- .../lite/python/saved_model_to_tfl_flatbuffer.cc | 2 -- .../mlir/lite/python/tf_tfl_flatbuffer_helpers.cc | 13 +++++++++++++ tensorflow/lite/python/convert.py | 6 ------ tensorflow/lite/python/lite.py | 4 ---- tensorflow/lite/toco/toco_flags.proto | 3 ++- 7 files changed, 16 insertions(+), 17 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/python/BUILD b/tensorflow/compiler/mlir/lite/python/BUILD index 51618d4826e6ad..9290272db90d28 100644 --- a/tensorflow/compiler/mlir/lite/python/BUILD +++ b/tensorflow/compiler/mlir/lite/python/BUILD @@ -27,11 +27,10 @@ cc_library( "//tensorflow/compiler/mlir/lite:tensorflow_lite", "//tensorflow/compiler/mlir/lite:tf_tfl_passes", "//tensorflow/compiler/mlir/lite:tf_to_tfl_flatbuffer", - "//tensorflow/compiler/mlir/lite/metrics:error_collector_inst", - "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", "//tensorflow/compiler/mlir/tensorflow:import_model", "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", "//tensorflow/core:core_cpu_base", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", diff --git a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc index efa633e736ae69..b683e3859afc44 100644 --- a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc @@ -117,8 +117,6 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags, pass_config.guarantee_all_funcs_one_use = toco_flags.guarantee_all_funcs_one_use(); pass_config.enable_stablehlo_conversion = toco_flags.convert_to_stablehlo(); - pass_config.enable_hlo_to_tf_conversion = - toco_flags.enable_hlo_to_tf_conversion(); return internal::ConvertMLIRToTFLiteFlatBuffer( model_flags, toco_flags, std::move(module), pass_config, diff --git a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc index 2e5819a0e2fc63..e955159990457e 100644 --- a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc @@ -201,8 +201,6 @@ Status ConvertSavedModelToTFLiteFlatBuffer(const toco::ModelFlags& model_flags, pass_config.guarantee_all_funcs_one_use = toco_flags.guarantee_all_funcs_one_use(); pass_config.enable_stablehlo_conversion = toco_flags.convert_to_stablehlo(); - pass_config.enable_hlo_to_tf_conversion = - toco_flags.enable_hlo_to_tf_conversion(); pass_config.legalize_custom_tensor_list_ops = toco_flags.legalize_custom_tensor_list_ops(); diff --git a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc index fb5efba769a066..c695d0a7f499df 100644 --- a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc +++ b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h" #include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h" #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" @@ -346,6 +347,18 @@ Status ConvertMLIRToTFLiteFlatBuffer( mlir::TFL::PassConfig pass_config_copy = pass_config; pass_config_copy.outline_tf_while = true; + + // Checks whether the model contains an `XlaCallModuleOp` operation which + // is a wrapper around StableHLO. + // This option is mutually exclusive to `enable_stablehlo_conversion`, the + // latter of which takes precedence. + // TODO(b/290109282): explore removing the enable_hlo_to_tf_conversion flag + // entirely, such that the added passes are no-ops in the non-shlo case. + module->walk([&](mlir::TF::XlaCallModuleOp xla_call_module_op) { + pass_config_copy.enable_hlo_to_tf_conversion = true; + mlir::WalkResult::interrupt(); + }); + auto status = ConvertTFExecutorToTFLOrFlatbuffer( module.get(), /*export_to_mlir=*/false, toco_flags, pass_config_copy, saved_model_tags, model_flags.saved_model_dir(), session, result); diff --git a/tensorflow/lite/python/convert.py b/tensorflow/lite/python/convert.py index 28d8f7629be940..0bfe04c903c8d2 100644 --- a/tensorflow/lite/python/convert.py +++ b/tensorflow/lite/python/convert.py @@ -577,7 +577,6 @@ def build_conversion_flags( enable_mlir_variable_quantization=False, disable_fuse_mul_and_fc=False, quantization_options: Optional[quant_opts_pb2.QuantizationOptions] = None, - enable_hlo_to_tf_conversion=False, mlir_dump_dir=None, mlir_dump_pass_regex=None, mlir_dump_func_regex=None, @@ -686,9 +685,6 @@ def build_conversion_flags( a custom method, and allows finer, modular control. This option will override any other existing quantization flags. We plan on gradually migrating all quantization-related specs into this option. - enable_hlo_to_tf_conversion: Enable HLO to TF conversion in the Converter. - Set this to False by default as this may increase the conversion time if - set otherwise. mlir_dump_dir: A string specifying the target directory to output MLIR dumps produced during conversion. If populated, enables MLIR dumps. mlir_dump_pass_regex: A string containing a regular expression for filtering @@ -797,8 +793,6 @@ def build_conversion_flags( if quantization_options: conversion_flags.quantization_options.CopyFrom(quantization_options) - conversion_flags.enable_hlo_to_tf_conversion = enable_hlo_to_tf_conversion - # Transfer debug options. Check for existence before populating in order to # leverage defaults specified in proto definition. if mlir_dump_dir is not None: diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py index f414485a6ecedd..8c31c7f973fcff 100644 --- a/tensorflow/lite/python/lite.py +++ b/tensorflow/lite/python/lite.py @@ -635,7 +635,6 @@ def __init__(self): self._experimental_enable_dynamic_update_slice = False self._experimental_preserve_assert_op = False self._experimental_guarantee_all_funcs_one_use = False - self._experimental_enable_hlo_to_tf_conversion = False # When the value is true, the MLIR quantantizer triggers dynamic range # quantization in MLIR instead of the old quantizer. Used only if @@ -790,9 +789,6 @@ def _get_base_converter_args(self): "allow_all_select_tf_ops": self._experimental_allow_all_select_tf_ops, "disable_fuse_mul_and_fc": self._experimental_disable_fuse_mul_and_fc, "quantization_options": self._experimental_quantization_options, - "enable_hlo_to_tf_conversion": ( - self._experimental_enable_hlo_to_tf_conversion - ), "mlir_dump_dir": self.mlir_dump_dir, "mlir_dump_pass_regex": self.mlir_dump_pass_regex, "mlir_dump_func_regex": self.mlir_dump_func_regex, diff --git a/tensorflow/lite/toco/toco_flags.proto b/tensorflow/lite/toco/toco_flags.proto index c8cad79d3bacfd..1421f614f82ccb 100644 --- a/tensorflow/lite/toco/toco_flags.proto +++ b/tensorflow/lite/toco/toco_flags.proto @@ -320,7 +320,8 @@ message TocoFlags { // Flag to enable hlo to tf conversion. // This is useful to exercise StableHLO -> HLO -> TF -> TFLite path. - optional bool enable_hlo_to_tf_conversion = 55 [default = false]; + optional bool enable_hlo_to_tf_conversion = 55 + [default = false, deprecated = true]; // Additional parameters for controlling debug facilities. optional tensorflow.converter.DebugOptions debug_options = 56; From 305a5a6754b725e70d6123019a694830923d84a5 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 10 Jul 2023 18:34:23 -0700 Subject: [PATCH 106/376] Do not run `cudnn_fused_conv_rewriter_test` in debug mode while we investigate an ongoing failure. PiperOrigin-RevId: 547040235 --- tensorflow/compiler/xla/service/gpu/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index b2c7fd02999d60..e1484457245e6f 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -3178,6 +3178,7 @@ xla_cc_test( "gpu", "no_oss", "noasan", + "nodebug", # TODO(b/290684889): Fails in debug mode. "nomsan", # This test runs some fusions that are only supported on Ampere+. "requires-gpu-sm80", From 1d829f0fcd84a396fd557f34e8082873dfe9d0f1 Mon Sep 17 00:00:00 2001 From: Fiona Lang Date: Mon, 10 Jul 2023 19:26:48 -0700 Subject: [PATCH 107/376] Update ops.Tensor references to //third_party/tensorflow/python/framework/tensor.py. PiperOrigin-RevId: 547047858 --- .../python/integration_test/quantize_model_test_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test_base.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test_base.py index d7dc023e1dc87c..5983c2b581953a 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test_base.py +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test_base.py @@ -907,7 +907,7 @@ def _save_tf1_model( inputs: Mapping[str, core.Tensor], outputs: Mapping[str, core.Tensor], init_op: Optional[ops.Operation] = None, - assets_collection: Optional[Sequence[ops.Tensor]] = None, + assets_collection: Optional[Sequence[core.Symbol]] = None, ) -> None: """Saves a TF1 model. From 27822155281c4edfda052b9bb431c5192c75bf0c Mon Sep 17 00:00:00 2001 From: Dateng Lin Date: Mon, 10 Jul 2023 19:44:51 -0700 Subject: [PATCH 108/376] Made a new API allowing a worker to create local context. Also put the helper functions in the test to the test utils. PiperOrigin-RevId: 547050239 --- tensorflow/c/eager/BUILD | 2 + tensorflow/c/eager/c_api_experimental.cc | 15 +++ tensorflow/c/eager/c_api_experimental.h | 6 ++ tensorflow/c/eager/c_api_experimental_test.cc | 62 ++++++++++++ tensorflow/c/eager/c_api_test.cc | 65 ------------- tensorflow/c/eager/c_api_test_util.cc | 65 +++++++++++++ tensorflow/c/eager/c_api_test_util.h | 12 +++ .../immediate_execution_distributed_manager.h | 6 ++ .../eager/context_distributed_manager.cc | 97 ++++++++++++++++--- .../eager/context_distributed_manager.h | 3 + 10 files changed, 253 insertions(+), 80 deletions(-) diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 748d49565f64a1..0b543f7dcbf9b7 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -930,10 +930,12 @@ tf_cuda_cc_test( ":c_api_experimental", ":c_api_test_util", "//tensorflow/c:c_test_util", + "//tensorflow/c/eager:c_api_internal", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core/distributed_runtime:server_lib", "//tensorflow/core/platform:status", "@com_google_absl//absl/strings", ], diff --git a/tensorflow/c/eager/c_api_experimental.cc b/tensorflow/c/eager/c_api_experimental.cc index 6fbcb7bb56a69e..db8b28437607f6 100644 --- a/tensorflow/c/eager/c_api_experimental.cc +++ b/tensorflow/c/eager/c_api_experimental.cc @@ -939,3 +939,18 @@ void TFE_WaitAtBarrier(TFE_Context* ctx, const char* barrier_id, status->status = coord_agent->WaitAtBarrier( barrier_id, absl::Milliseconds(barrier_timeout_in_ms), {}); } + +void TFE_InitializeLocalOnlyContext(TFE_Context* ctx, int keep_alive_secs, + const void* proto, size_t proto_len, + TF_Status* status) { + tensorflow::ServerDef server_def; + if (!server_def.ParseFromArray(proto, proto_len)) { + status->status = tensorflow::errors::InvalidArgument( + "Invalid tensorflow.ServerDef protocol buffer"); + return; + } + status->status = + tensorflow::unwrap(ctx) + ->GetDistributedManager() + ->InitializeLocalOnlyContext(server_def, keep_alive_secs); +} diff --git a/tensorflow/c/eager/c_api_experimental.h b/tensorflow/c/eager/c_api_experimental.h index fcbced2080a082..dc88de351f74fa 100644 --- a/tensorflow/c/eager/c_api_experimental.h +++ b/tensorflow/c/eager/c_api_experimental.h @@ -747,6 +747,12 @@ TF_CAPI_EXPORT extern void TFE_WaitAtBarrier(TFE_Context* ctx, int64_t barrier_timeout_in_ms, TF_Status* status); +TF_CAPI_EXPORT extern void TFE_InitializeLocalOnlyContext(TFE_Context* ctx, + int keep_alive_secs, + const void* proto, + size_t proto_len, + TF_Status* status); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/eager/c_api_experimental_test.cc b/tensorflow/c/eager/c_api_experimental_test.cc index 68dbafc4d2a1e6..51e56827114cb6 100644 --- a/tensorflow/c/eager/c_api_experimental_test.cc +++ b/tensorflow/c/eager/c_api_experimental_test.cc @@ -18,13 +18,17 @@ limitations under the License. #include #include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/eager/c_api_test_util.h" +#include "tensorflow/core/distributed_runtime/server_lib.h" #include "tensorflow/core/lib/monitoring/collection_registry.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/str_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/protobuf/cluster.pb.h" +#include "tensorflow/core/protobuf/config.pb.h" using tensorflow::string; @@ -522,5 +526,63 @@ TEST(CAPI, TensorHandleDefaults) { TFE_DeleteContext(ctx); } +TEST(CAPI, CreateLocalContextAsReset) { + tensorflow::ServerDef server_def = GetServerDef("worker", 2); + server_def.mutable_default_session_config()->set_isolate_session_state(false); + + ServerFactory* factory; + ASSERT_TRUE(ServerFactory::GetFactory(server_def, &factory).ok()); + server_def.set_job_name("worker"); + server_def.set_task_index(0); + std::unique_ptr w0; + ASSERT_TRUE( + factory->NewServer(server_def, ServerFactory::Options(), &w0).ok()); + ASSERT_TRUE(w0->Start().ok()); + server_def.set_task_index(1); + std::unique_ptr w1; + ASSERT_TRUE( + factory->NewServer(server_def, ServerFactory::Options(), &w1).ok()); + ASSERT_TRUE(w1->Start().ok()); + + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + opts->session_options.options.config.set_isolate_session_state(false); + TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT); + TFE_Context* ctx = TFE_NewContext(opts, status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + server_def.set_task_index(0); + auto cluster = server_def.mutable_cluster(); + auto client_job = cluster->add_job(); + client_job->set_name("localhost"); + int client_port = tensorflow::testing::PickUnusedPortOrDie(); + client_job->mutable_tasks()->insert( + {0, strings::StrCat("localhost:", client_port)}); + server_def.set_job_name("localhost"); + auto serialized = server_def.SerializeAsString(); + TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + server_def.set_job_name("worker"); + server_def.set_task_index(0); + tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster(); + tensorflow::JobDef* job_def = cluster_def->mutable_job(0); + int worker_port = tensorflow::testing::PickUnusedPortOrDie(); + job_def->mutable_tasks()->at(0) = + tensorflow::strings::StrCat("localhost:", worker_port); + serialized = server_def.SerializeAsString(); + TFE_InitializeLocalOnlyContext(ctx, 0, serialized.data(), serialized.size(), + status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + TFE_DeleteContextOptions(opts); + TFE_DeleteContext(ctx); + TF_DeleteStatus(status); + + w0.release(); + w1.release(); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 254648d9e09309..19c078cbc47e9d 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -1946,71 +1946,6 @@ tensorflow::ServerDef ReplaceTaskInServerDef( return server_def_copy; } -TFE_TensorHandle* CreateVarHandle(TFE_Context* ctx, - const tensorflow::string& device_name, - const tensorflow::string& variable_name) { - TF_Status* status = TF_NewStatus(); - // Create the variable handle. - TFE_Op* op = TFE_NewOp(ctx, "VarHandleOp", status); - if (TF_GetCode(status) != TF_OK) return nullptr; - TFE_OpSetAttrType(op, "dtype", TF_FLOAT); - TFE_OpSetAttrShape(op, "shape", {}, 0, status); - TFE_OpSetAttrString(op, "container", "localhost", 0); - TFE_OpSetAttrString(op, "shared_name", variable_name.data(), - variable_name.size()); - if (!device_name.empty()) { - TFE_OpSetDevice(op, device_name.c_str(), status); - } - if (TF_GetCode(status) != TF_OK) return nullptr; - TFE_TensorHandle* var_handle = nullptr; - int num_retvals = 1; - TFE_Execute(op, &var_handle, &num_retvals, status); - if (TF_GetCode(status) != TF_OK) return nullptr; - TFE_DeleteOp(op); - if (TF_GetCode(status) != TF_OK) return nullptr; - CHECK_EQ(1, num_retvals); - TF_DeleteStatus(status); - return var_handle; -} - -TFE_TensorHandle* CreateVariable(TFE_Context* ctx, float value, - const tensorflow::string& device_name, - const tensorflow::string& variable_name) { - TF_Status* status = TF_NewStatus(); - TFE_TensorHandle* var_handle = - CreateVarHandle(ctx, device_name, variable_name); - - // Assign 'value' to it. - TFE_Op* op = TFE_NewOp(ctx, "AssignVariableOp", status); - if (TF_GetCode(status) != TF_OK) return nullptr; - TFE_OpSetAttrType(op, "dtype", TF_FLOAT); - TFE_OpAddInput(op, var_handle, status); - if (!device_name.empty()) { - TFE_OpSetDevice(op, device_name.c_str(), status); - } - - // Convert 'value' to a TF_Tensor then a TFE_TensorHandle. - std::unique_ptr t( - TF_AllocateTensor(TF_FLOAT, nullptr, 0, sizeof(value)), TF_DeleteTensor); - memcpy(TF_TensorData(t.get()), &value, TF_TensorByteSize(t.get())); - - std::unique_ptr - value_handle(TFE_NewTensorHandle(t.get(), status), - TFE_DeleteTensorHandle); - if (TF_GetCode(status) != TF_OK) return nullptr; - - TFE_OpAddInput(op, value_handle.get(), status); - if (TF_GetCode(status) != TF_OK) return nullptr; - - int num_retvals = 0; - TFE_Execute(op, nullptr, &num_retvals, status); - TFE_DeleteOp(op); - if (TF_GetCode(status) != TF_OK) return nullptr; - CHECK_EQ(0, num_retvals); - TF_DeleteStatus(status); - return var_handle; -} - TFE_Context* CreateContext(const string& serialized_server_def, bool isolate_session_state) { TF_Status* status = TF_NewStatus(); diff --git a/tensorflow/c/eager/c_api_test_util.cc b/tensorflow/c/eager/c_api_test_util.cc index 1fb76748059a20..75450e5c7aa88b 100644 --- a/tensorflow/c/eager/c_api_test_util.cc +++ b/tensorflow/c/eager/c_api_test_util.cc @@ -485,3 +485,68 @@ tensorflow::ServerDef GetMultiClientServerDef(const std::string& job_name, } return server_def; } + +TFE_TensorHandle* CreateVarHandle(TFE_Context* ctx, + const tensorflow::string& device_name, + const tensorflow::string& variable_name) { + TF_Status* status = TF_NewStatus(); + // Create the variable handle. + TFE_Op* op = TFE_NewOp(ctx, "VarHandleOp", status); + if (TF_GetCode(status) != TF_OK) return nullptr; + TFE_OpSetAttrType(op, "dtype", TF_FLOAT); + TFE_OpSetAttrShape(op, "shape", {}, 0, status); + TFE_OpSetAttrString(op, "container", "localhost", 0); + TFE_OpSetAttrString(op, "shared_name", variable_name.data(), + variable_name.size()); + if (!device_name.empty()) { + TFE_OpSetDevice(op, device_name.c_str(), status); + } + if (TF_GetCode(status) != TF_OK) return nullptr; + TFE_TensorHandle* var_handle = nullptr; + int num_retvals = 1; + TFE_Execute(op, &var_handle, &num_retvals, status); + if (TF_GetCode(status) != TF_OK) return nullptr; + TFE_DeleteOp(op); + if (TF_GetCode(status) != TF_OK) return nullptr; + CHECK_EQ(1, num_retvals); + TF_DeleteStatus(status); + return var_handle; +} + +TFE_TensorHandle* CreateVariable(TFE_Context* ctx, float value, + const tensorflow::string& device_name, + const tensorflow::string& variable_name) { + TF_Status* status = TF_NewStatus(); + TFE_TensorHandle* var_handle = + CreateVarHandle(ctx, device_name, variable_name); + + // Assign 'value' to it. + TFE_Op* op = TFE_NewOp(ctx, "AssignVariableOp", status); + if (TF_GetCode(status) != TF_OK) return nullptr; + TFE_OpSetAttrType(op, "dtype", TF_FLOAT); + TFE_OpAddInput(op, var_handle, status); + if (!device_name.empty()) { + TFE_OpSetDevice(op, device_name.c_str(), status); + } + + // Convert 'value' to a TF_Tensor then a TFE_TensorHandle. + std::unique_ptr t( + TF_AllocateTensor(TF_FLOAT, nullptr, 0, sizeof(value)), TF_DeleteTensor); + memcpy(TF_TensorData(t.get()), &value, TF_TensorByteSize(t.get())); + + std::unique_ptr + value_handle(TFE_NewTensorHandle(t.get(), status), + TFE_DeleteTensorHandle); + if (TF_GetCode(status) != TF_OK) return nullptr; + + TFE_OpAddInput(op, value_handle.get(), status); + if (TF_GetCode(status) != TF_OK) return nullptr; + + int num_retvals = 0; + TFE_Execute(op, nullptr, &num_retvals, status); + TFE_DeleteOp(op); + if (TF_GetCode(status) != TF_OK) return nullptr; + CHECK_EQ(0, num_retvals); + TF_DeleteStatus(status); + return var_handle; +} diff --git a/tensorflow/c/eager/c_api_test_util.h b/tensorflow/c/eager/c_api_test_util.h index ce8546fb4f4186..3dad82723b6453 100644 --- a/tensorflow/c/eager/c_api_test_util.h +++ b/tensorflow/c/eager/c_api_test_util.h @@ -145,4 +145,16 @@ tensorflow::ServerDef GetMultiClientServerDef(const std::string& job_name, int num_tasks, int num_virtual_gpus = 0); +// Create a variable handle with name `variable_name` on a device with name +// `device_name`. +TFE_TensorHandle* CreateVarHandle(TFE_Context* ctx, + const tensorflow::string& device_name, + const tensorflow::string& variable_name); + +// Create a variable with value `value` and name `variable_name` on a device +// with name `device_name`. +TFE_TensorHandle* CreateVariable(TFE_Context* ctx, float value, + const tensorflow::string& device_name, + const tensorflow::string& variable_name); + #endif // TENSORFLOW_C_EAGER_C_API_TEST_UTIL_H_ diff --git a/tensorflow/c/eager/immediate_execution_distributed_manager.h b/tensorflow/c/eager/immediate_execution_distributed_manager.h index 4f96992e7393af..b0fcc49c0b8c36 100644 --- a/tensorflow/c/eager/immediate_execution_distributed_manager.h +++ b/tensorflow/c/eager/immediate_execution_distributed_manager.h @@ -44,6 +44,12 @@ class ImmediateExecutionDistributedManager { bool reset_context, int keep_alive_secs) = 0; + // Initializes context for the local worker and no contexts will be created + // for remote workers. Currently this only works for resetting context. + // TODO(b/289445025): Consider removing this when we find a proper fix. + virtual Status InitializeLocalOnlyContext(const ServerDef& server_def, + int keep_alive_secs) = 0; + // Set up a multi-client distributed execution environment. Must be called // on all tasks in the cluster. This call internally coordinates with other // tasks to initialize the eager context and TF server for multi-client diff --git a/tensorflow/core/common_runtime/eager/context_distributed_manager.cc b/tensorflow/core/common_runtime/eager/context_distributed_manager.cc index 27c77a6a9e4fb0..1287646bb1736f 100644 --- a/tensorflow/core/common_runtime/eager/context_distributed_manager.cc +++ b/tensorflow/core/common_runtime/eager/context_distributed_manager.cc @@ -63,6 +63,20 @@ limitations under the License. #endif // !IS_MOBILE_PLATFORM namespace tensorflow { + +// We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the +// server object (which currently CHECK-fails) and we miss the error, instead, +// we log the error, and then return to allow the user to see the error +// message. +#define LOG_AND_RETURN_IF_ERROR(...) \ + do { \ + const tensorflow::Status _status = (__VA_ARGS__); \ + if (TF_PREDICT_FALSE(!_status.ok())) { \ + LOG(ERROR) << _status.message(); \ + return _status; \ + } \ + } while (0); + #if !defined(IS_MOBILE_PLATFORM) namespace { @@ -641,7 +655,6 @@ Status UpdateContextWithServerDef(EagerContext* context, added_workers, removed_workers)); LOG_AND_RETURN_IF_ERROR(sg.as_summary_status()); } -#undef LOG_AND_RETURN_IF_ERROR return OkStatus(); } @@ -682,21 +695,76 @@ Status EagerContextDistributedManager::SetOrUpdateServerDef( return s; } +Status EagerContextDistributedManager::InitializeLocalOnlyContext( + const ServerDef& server_def, int keep_alive_secs) { + string worker_name = + strings::StrCat("/job:", server_def.job_name(), + "/replica:0/task:", server_def.task_index()); + // New server created for new server_def. Unused if updating server_def. + std::unique_ptr new_server; + ServerInterface* server; + DeviceMgr* device_mgr = AreLocalDevicesCompatible(context_, server_def) + ? context_->local_device_mgr() + : nullptr; + LOG_AND_RETURN_IF_ERROR( + NewServerWithOptions(server_def, {device_mgr}, &new_server)); + server = new_server.get(); + uint64 context_id = EagerContext::NewContextId(); + // Make master eager context accessible by local eager service, which might + // receive send tensor requests from remote workers. + LOG_AND_RETURN_IF_ERROR( + server->AddMasterEagerContextToEagerService(context_id, context_)); + + std::vector local_device_attributes; + server->worker_env()->device_mgr->ListDeviceAttributes( + &local_device_attributes); + + auto session_name = strings::StrCat("eager_", context_id); + auto* session_mgr = server->worker_env()->session_mgr; + tsl::core::RefCountPtr r = + server->worker_env()->rendezvous_mgr->Find(context_id); + std::shared_ptr worker_session; + protobuf::RepeatedPtrField device_attributes( + local_device_attributes.begin(), local_device_attributes.end()); + LOG_AND_RETURN_IF_ERROR(session_mgr->CreateSession( + session_name, server_def, device_attributes, + context_->session_options().config.isolate_session_state())); + LOG_AND_RETURN_IF_ERROR(server->SetCoordinationServiceAgentInstance( + session_mgr->GetCoordinationServiceAgent())); + LOG_AND_RETURN_IF_ERROR( + session_mgr->WorkerSessionForSession(session_name, &worker_session)); + + // Initialize remote tensor communication based on worker session. + LOG_AND_RETURN_IF_ERROR(r->Initialize(worker_session.get())); + + DistributedFunctionLibraryRuntime* cluster_flr = + eager::CreateClusterFLR(context_id, context_, worker_session.get()); + auto remote_mgr = std::make_unique( + /*is_master=*/true, context_); + + // The remote workers and device manager are ignored since this initialization + // is local only. + LOG_AND_RETURN_IF_ERROR(context_->InitializeRemoteMaster( + std::move(new_server), server->worker_env(), worker_session, + /*remote_eager_workers=*/nullptr, /*remote_device_manager=*/nullptr, + /*remote_contexts=*/{}, context_id, std::move(r), + server->worker_env()->device_mgr, keep_alive_secs, cluster_flr, + std::move(remote_mgr))); + + // NOTE: We start the server after all other initialization, because the + // GrpcServer cannot be destroyed after it is started. + LOG_AND_RETURN_IF_ERROR(server->Start()); + + // If context is reset, make sure pointer is set to the new agent. + coordination_service_agent_ = + context_->GetServer() + ->worker_env() + ->session_mgr->GetCoordinationServiceAgent(); + return OkStatus(); +} + Status EagerContextDistributedManager::EnableCollectiveOps( const ServerDef& server_def) { - // We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the - // server object (which currently CHECK-fails) and we miss the error, instead, - // we log the error, and then return to allow the user to see the error - // message. -#define LOG_AND_RETURN_IF_ERROR(...) \ - do { \ - const tensorflow::Status _status = (__VA_ARGS__); \ - if (TF_PREDICT_FALSE(!_status.ok())) { \ - LOG(ERROR) << _status.message(); \ - return _status; \ - } \ - } while (0); - ServerInterface* server = context_->GetServer(); if (server == nullptr) { std::unique_ptr new_server; @@ -789,7 +857,6 @@ Status EagerContextDistributedManager::EnableCollectiveOps( /*new_server=*/nullptr, server->worker_env()->device_mgr, server->worker_env()->collective_executor_mgr.get())); } -#undef LOG_AND_RETURN_IF_ERROR return OkStatus(); } diff --git a/tensorflow/core/common_runtime/eager/context_distributed_manager.h b/tensorflow/core/common_runtime/eager/context_distributed_manager.h index 279a792a87b522..f13c01c842ca1a 100644 --- a/tensorflow/core/common_runtime/eager/context_distributed_manager.h +++ b/tensorflow/core/common_runtime/eager/context_distributed_manager.h @@ -44,6 +44,9 @@ class EagerContextDistributedManager Status SetOrUpdateServerDef(const ServerDef& server_def, bool reset_context, int keep_alive_secs) override; + Status InitializeLocalOnlyContext(const ServerDef& server_def, + int keep_alive_secs) override; + Status EnableCollectiveOps(const ServerDef& server_def) override; Status CheckRemoteAlive(const std::string& remote_task_name, From 8a3ee99f6f7050ca42fb07acbbf74af45cd854d0 Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Mon, 10 Jul 2023 20:11:18 -0700 Subject: [PATCH 109/376] [NFC] Change uses of get_compatible_with_cloud to get_compatible_with_portable. PiperOrigin-RevId: 547054827 --- tensorflow/lite/delegates/flex/BUILD | 22 +++++++++++----------- tensorflow/lite/tools/versioning/BUILD | 8 ++++---- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/tensorflow/lite/delegates/flex/BUILD b/tensorflow/lite/delegates/flex/BUILD index 0dafc00ba0c70d..92ef4a610c6dd1 100644 --- a/tensorflow/lite/delegates/flex/BUILD +++ b/tensorflow/lite/delegates/flex/BUILD @@ -10,7 +10,7 @@ load( load("//tensorflow/lite:build_def.bzl", "tflite_copts") load("//tensorflow/lite:special_rules.bzl", "internal_visibility_allowlist") load("//tensorflow/lite/delegates/flex:build_def.bzl", "tflite_flex_cc_library", "tflite_flex_shared_library") -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") default_visibility = [ "//tensorflow/compiler/mlir/lite:__subpackages__", @@ -37,7 +37,7 @@ cc_library( name = "buffer_map", srcs = ["buffer_map.cc"], hdrs = ["buffer_map.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), copts = tf_opts_nortti_if_lite_protos(), features = tf_features_nolayering_check_if_ios(), deps = [ @@ -58,7 +58,7 @@ cc_library( name = "buffer_map_util", srcs = ["buffer_map_util.cc"], hdrs = ["buffer_map_util.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), copts = tf_opts_nortti_if_lite_protos(), features = tf_features_nolayering_check_if_ios(), deps = [ @@ -106,7 +106,7 @@ tf_cc_test( # ) tflite_flex_cc_library( name = "delegate", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), visibility = ["//visibility:public"], ) @@ -133,7 +133,7 @@ cc_library( srcs = [ "delegate_symbol.cc", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), copts = tflite_copts(), visibility = ["//visibility:public"], deps = [ @@ -155,7 +155,7 @@ cc_library( hdrs = [ "delegate.h", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), copts = tflite_copts() + tf_opts_nortti_if_android(), features = tf_features_nolayering_check_if_ios(), visibility = ["//visibility:public"], @@ -208,7 +208,7 @@ cc_library( name = "delegate_data", srcs = ["delegate_data.cc"], hdrs = ["delegate_data.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), copts = tf_opts_nortti_if_android(), features = tf_features_nolayering_check_if_ios(), visibility = ["//visibility:public"], @@ -259,7 +259,7 @@ tf_cc_test( cc_library( name = "subgraph_resource", hdrs = ["subgraph_resource.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), features = tf_features_nolayering_check_if_ios(), deps = [ "//tensorflow/lite:cc_api_experimental", @@ -312,7 +312,7 @@ cc_library( name = "util", srcs = ["util.cc"], hdrs = ["util.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), features = tf_features_nolayering_check_if_ios(), #TODO(b/206038955): Consider restrict the visibility to '//third_party/fcp/client:__subpackages__'. visibility = ["//visibility:public"], @@ -360,7 +360,7 @@ cc_library( "allowlisted_flex_ops.h", "allowlisted_flex_ops_internal.h", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), features = tf_features_nolayering_check_if_ios(), visibility = internal_visibility_allowlist(), deps = if_mobile([ @@ -403,7 +403,7 @@ cc_library( cc_library( name = "tflite_subgraph_execute", srcs = ["tflite_subgraph_execute.cc"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), copts = tf_opts_nortti_if_android(), features = tf_features_nolayering_check_if_ios(), deps = [ diff --git a/tensorflow/lite/tools/versioning/BUILD b/tensorflow/lite/tools/versioning/BUILD index 74875c04112619..a071a2615e9a93 100644 --- a/tensorflow/lite/tools/versioning/BUILD +++ b/tensorflow/lite/tools/versioning/BUILD @@ -2,7 +2,7 @@ load( "//tensorflow:tensorflow.bzl", "tf_cc_test", ) -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -20,7 +20,7 @@ cc_library( "op_version.h", "runtime_version.h", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":op_signature", "//tensorflow/core:tflite_portable_logging", @@ -62,7 +62,7 @@ cc_library( hdrs = [ "op_signature.h", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/lite:stderr_reporter", "//tensorflow/lite/core/api", @@ -101,7 +101,7 @@ cc_library( hdrs = [ "gpu_compatibility.h", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":op_signature", "//tensorflow/lite:builtin_op_data", From 52a0ee8c7d218d7586fbdd3f3fac2ca18688e306 Mon Sep 17 00:00:00 2001 From: George Necula Date: Mon, 10 Jul 2023 22:21:40 -0700 Subject: [PATCH 110/376] Improve DumpMlirOpToFile when dumping to stderr. Previously, the log would show the module dump and at the end it would say "Dumped ...". With this change we will see "Dumping ..." before the module. PiperOrigin-RevId: 547075919 --- tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc index f07af4f8b85a2c..c90a0b419b1be3 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc @@ -218,10 +218,10 @@ std::string DumpMlirOpToFile(llvm::StringRef name, mlir::Operation* op, Status result = CreateFileForDumping(name, &os, &filepath, dirname); if (!result.ok()) return std::string(result.message()); + LOG(INFO) << "Dumping MLIR operation '" << op->getName().getStringRef().str() + << "' to '" << filepath << "'"; if (pass_manager) PrintPassPipeline(*pass_manager, op, *os); op->print(*os, mlir::OpPrintingFlags().useLocalScope()); - LOG(INFO) << "Dumped MLIR operation '" << op->getName().getStringRef().str() - << "' to '" << filepath << "'"; return filepath; } From f580c3a517b0ac88e3d742f92e2e982f22bbc3ed Mon Sep 17 00:00:00 2001 From: Kuangyuan Chen Date: Mon, 10 Jul 2023 22:37:07 -0700 Subject: [PATCH 111/376] Create a cancellation manager for each request. PiperOrigin-RevId: 547078260 --- .../kernel_fallback_compat_request_state.cc | 8 -- .../kernel_fallback_compat_request_state.h | 7 +- tensorflow/core/tfrt/graph_executor/BUILD | 1 + .../tfrt/graph_executor/graph_executor.cc | 2 + .../core/tfrt/graph_executor/graph_executor.h | 2 + .../graph_executor/graph_executor_test.cc | 101 ++++++++++++++++++ 6 files changed, 111 insertions(+), 10 deletions(-) diff --git a/tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.cc b/tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.cc index 54b017d71a61b1..363f02a89358c0 100644 --- a/tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.cc +++ b/tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.cc @@ -60,13 +60,6 @@ void FallbackResourceArray::SetResource( *resource_storage_[index], resources_.back().get()); } -static CancellationManager* GetDefaultCancellationManager() { - // TODO(b/167630926): Support cancellation by hooking up with TFRT's - // mechanism. - static auto* const default_cancellation_manager = new CancellationManager; - return default_cancellation_manager; -} - KernelFallbackCompatRequestState::KernelFallbackCompatRequestState( std::function)>* runner, const tensorflow::DeviceMgr* device_manager, int64_t step_id, @@ -85,7 +78,6 @@ KernelFallbackCompatRequestState::KernelFallbackCompatRequestState( ? collective_executor_handle_->get() : nullptr), rendezvous_(std::move(rendezvous)), - default_cancellation_manager_(GetDefaultCancellationManager()), device_manager_(device_manager), runner_table_(runner_table), resource_array_(resource_array), diff --git a/tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h b/tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h index 0645f02481e690..94df91109c06bd 100644 --- a/tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h +++ b/tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h @@ -138,7 +138,10 @@ class KernelFallbackCompatRequestState { std::function)>* runner() const { return runner_; } CancellationManager* cancellation_manager() const { - return default_cancellation_manager_; + return cancellation_manager_; + } + void set_cancellation_manager(CancellationManager* cancellation_manager) { + cancellation_manager_ = cancellation_manager; } RendezvousInterface* rendezvous() const { return rendezvous_.get(); } @@ -192,7 +195,7 @@ class KernelFallbackCompatRequestState { std::unique_ptr collective_executor_handle_; CollectiveExecutor* collective_executor_ = nullptr; core::RefCountPtr rendezvous_; - CancellationManager* default_cancellation_manager_ = nullptr; + CancellationManager* cancellation_manager_ = nullptr; const tensorflow::DeviceMgr* device_manager_ = nullptr; diff --git a/tensorflow/core/tfrt/graph_executor/BUILD b/tensorflow/core/tfrt/graph_executor/BUILD index 295e86a6ca267c..1e4e07907211df 100644 --- a/tensorflow/core/tfrt/graph_executor/BUILD +++ b/tensorflow/core/tfrt/graph_executor/BUILD @@ -129,6 +129,7 @@ tf_cc_test( "//tensorflow/cc:array_ops", "//tensorflow/cc:cc_ops", "//tensorflow/cc:const_op", + "//tensorflow/core:core_cpu_base", "//tensorflow/core:test", "//tensorflow/core/framework:graph_proto_cc", "//tensorflow/core/framework:types_proto_cc", diff --git a/tensorflow/core/tfrt/graph_executor/graph_executor.cc b/tensorflow/core/tfrt/graph_executor/graph_executor.cc index c59ef950f95946..d409b9af6a1af9 100644 --- a/tensorflow/core/tfrt/graph_executor/graph_executor.cc +++ b/tensorflow/core/tfrt/graph_executor/graph_executor.cc @@ -234,6 +234,8 @@ StatusOr> CreateRequestInfo( fallback_request_state.set_client_graph_resource_context( client_graph_resource_context); fallback_request_state.set_runtime_config(&options.runtime_config); + fallback_request_state.set_cancellation_manager( + &request_info->cancellation_manager); TF_RETURN_IF_ERROR( tensorflow::SetUpTfJitRtRequestContext(&request_context_builder)); diff --git a/tensorflow/core/tfrt/graph_executor/graph_executor.h b/tensorflow/core/tfrt/graph_executor/graph_executor.h index f345d4191160e1..a50f222d351c77 100644 --- a/tensorflow/core/tfrt/graph_executor/graph_executor.h +++ b/tensorflow/core/tfrt/graph_executor/graph_executor.h @@ -66,6 +66,8 @@ struct RequestInfo { WorkQueueInterface* request_queue = nullptr; // The task runner used by tensorflow::OpKernel. std::function)> runner; + + tensorflow::CancellationManager cancellation_manager; }; struct SymbolUids { diff --git a/tensorflow/core/tfrt/graph_executor/graph_executor_test.cc b/tensorflow/core/tfrt/graph_executor/graph_executor_test.cc index 1b704c4660b0c6..064c32fe47c78a 100644 --- a/tensorflow/core/tfrt/graph_executor/graph_executor_test.cc +++ b/tensorflow/core/tfrt/graph_executor/graph_executor_test.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/cc/ops/const_op.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/grappler/utils/grappler_test.h" #include "tensorflow/core/platform/statusor.h" #include "tensorflow/core/protobuf/rewriter_config.pb.h" @@ -44,6 +45,8 @@ namespace tensorflow { namespace tfrt_stub { namespace { +using ::testing::status::StatusIs; + class GraphExecutorTest : public ::testing::TestWithParam {}; tensorflow::Status GetSimpleGraphDef(GraphDef& graph_def) { @@ -145,6 +148,104 @@ TEST_P(GraphExecutorTest, BasicWithOnlineCostAnalysis) { ::testing::ElementsAreArray({2})); } +REGISTER_OP("TestCancel") + .Input("x: T") + .Output("z: T") + .Attr("T: {int32}") + .SetShapeFn(::tensorflow::shape_inference::UnchangedShape); + +class TestCancelKernel : public OpKernel { + public: + explicit TestCancelKernel(OpKernelConstruction* context) + : OpKernel(context) {} + + void Compute(OpKernelContext* ctx) override { + auto status = absl::CancelledError(); + ctx->cancellation_manager()->StartCancelWithStatus(status); + ctx->SetStatus(status); + } +}; + +REGISTER_KERNEL_BUILDER(Name("TestCancel").Device(DEVICE_CPU), + TestCancelKernel); + +REGISTER_OP("TestIsCancelled").Output("z: T").Attr("T: {bool}").SetIsStateful(); + +class TestIsCancelledKernel : public OpKernel { + public: + explicit TestIsCancelledKernel(OpKernelConstruction* context) + : OpKernel(context) {} + + void Compute(OpKernelContext* ctx) override { + ctx->set_output( + 0, tensorflow::Tensor(ctx->cancellation_manager()->IsCancelled())); + } +}; + +REGISTER_KERNEL_BUILDER(Name("TestIsCancelled").Device(DEVICE_CPU), + TestIsCancelledKernel); + +TEST_P(GraphExecutorTest, Cancellation) { + GraphDef graph_def; + + tensorflow::GraphDefBuilder builder( + tensorflow::GraphDefBuilder::kFailImmediately); + + const tensorflow::TensorShape tensor_shape({10, 9}); + tensorflow::Node* input = tensorflow::ops::SourceOp( + "Placeholder", builder.opts() + .WithName("input") + .WithAttr("dtype", tensorflow::DT_INT32) + .WithAttr("shape", tensor_shape)); + tensorflow::ops::SourceOp("TestIsCancelled", + builder.opts() + .WithName("is_cancelled") + .WithAttr("T", tensorflow::DT_BOOL)); + tensorflow::ops::UnaryOp("TestCancel", input, + builder.opts() + .WithName("test_cancel") + .WithAttr("T", tensorflow::DT_INT32)); + + TF_ASSERT_OK(builder.ToGraphDef(&graph_def)); + + auto runtime = DefaultTfrtRuntime(/*num_threads=*/1); + GraphExecutor::Options options(runtime.get()); + options.enable_mlrt = GetParam(); + + TF_ASSERT_OK_AND_ASSIGN( + auto fallback_state, + tensorflow::tfrt_stub::FallbackState::Create( + CreateDefaultSessionOptions(options), graph_def.library())) + auto resource_context = std::make_unique(); + TF_ASSERT_OK_AND_ASSIGN( + auto graph_executor, + GraphExecutor::Create(std::move(options), *fallback_state, + std::move(resource_context), graph_def, + GetKernelRegistry())); + { + std::vector> inputs; + inputs.push_back({"input", CreateTfTensor( + /*shape=*/{1, 3}, /*data=*/{1, 1, 1})}); + + std::vector outputs; + EXPECT_THAT(graph_executor->Run(/*run_options=*/{}, inputs, + /*output_tensor_names=*/{"test_cancel:0"}, + /*target_tensor_names=*/{}, &outputs), + StatusIs(absl::StatusCode::kCancelled)); + } + + { + std::vector outputs; + TF_ASSERT_OK(graph_executor->Run(/*run_options=*/{}, /*inputs=*/{}, + /*output_tensor_names=*/{"is_cancelled:0"}, + /*target_tensor_names=*/{}, &outputs)); + ASSERT_EQ(outputs.size(), 1); + + EXPECT_THAT(GetTfTensorData(outputs[0]), + ::testing::ElementsAreArray({false})); + } +} + INSTANTIATE_TEST_SUITE_P(GraphExecutorTestSuite, GraphExecutorTest, ::testing::Bool()); From 877ee485f7c4968eab6fe4e30cee4d83d3f6abd9 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Tue, 11 Jul 2023 00:09:15 -0700 Subject: [PATCH 112/376] Fix index out of bounds access in cuda_dnn.cc We might have removed one element from data_ptrs_vec, so accessing it with sizeof...(Args)-1 can be out of bounds. What we actually want to do is to check the last element. Also apply several fixes for ClangTidy warnings. Re-enable a test that exposed this bug when run in debug compilation mode. PiperOrigin-RevId: 547096288 --- tensorflow/compiler/xla/service/gpu/BUILD | 1 - .../gpu/cudnn_fused_conv_rewriter_test.cc | 11 ------ .../xla/stream_executor/cuda/cuda_dnn.cc | 39 +++++++++---------- 3 files changed, 19 insertions(+), 32 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index e1484457245e6f..b2c7fd02999d60 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -3178,7 +3178,6 @@ xla_cc_test( "gpu", "no_oss", "noasan", - "nodebug", # TODO(b/290684889): Fails in debug mode. "nomsan", # This test runs some fusions that are only supported on Ampere+. "requires-gpu-sm80", diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc index 3ef05e5fd399ea..0e6d84cd2960fd 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc @@ -215,11 +215,6 @@ TEST_F(CudnnFusedConvRewriterTest, DontFuseBiasWithDepthwiseConv) { } TEST_F(CudnnFusedConvRewriterTest, TestElu) { - if (!GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << "Conv-Bias-Elu fusion is supported and recommended with " - "the Nvidia Ampere+ GPUs."; - } // sum = conv(x, w) + bias // select(compare(sum, 0, GT), sum, exponential-minus-one(sum)); TestMatchWithAllTypes(R"( @@ -243,12 +238,6 @@ TEST_F(CudnnFusedConvRewriterTest, TestElu) { } TEST_F(CudnnFusedConvRewriterTest, DontFuseEluWithDepthwiseConv) { - if (!GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << "Conv-Bias-Elu fusion is supported and recommended with " - "the Nvidia Ampere+ GPUs."; - } - // sum = conv(x, w) + bias // select(compare(sum, 0, GT), sum, exponential-minus-one(sum)); TestNotMatchWithAllTypes(R"( diff --git a/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc b/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc index 6249b0ee176135..5396d565d00028 100644 --- a/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc @@ -15,12 +15,18 @@ limitations under the License. #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.h" +#include +#include #include #include +#include +#include #include #include #include +#include #include +#include #include "absl/base/optimization.h" #include "absl/base/thread_annotations.h" @@ -429,7 +435,7 @@ tsl::Status CudnnSupport::Init() { return tsl::Status(absl::StatusCode::kInternal, error); } - cudnn_.reset(new CudnnAccess(cudnn_handle)); + cudnn_ = std::make_unique(cudnn_handle); LOG(INFO) << "Loaded cuDNN version " << cudnnGetVersion(); return ::tsl::OkStatus(); @@ -1689,13 +1695,11 @@ class CudnnRnnSequenceTensorDescriptor : public dnn::RnnSequenceTensorDescriptor { CudnnRnnSequenceTensorDescriptor(GpuExecutor* parent, int max_seq_length, int batch_size, int data_size, - cudnnDataType_t data_type, RNNDataDescriptor data_handle, TensorDescriptor handle) : max_seq_length_(max_seq_length), batch_size_(batch_size), data_size_(data_size), - data_type_(data_type), handle_(std::move(handle)), rnn_data_handle_(std::move(data_handle)), handles_(max_seq_length, handle_.get()) {} @@ -1719,13 +1723,13 @@ class CudnnRnnSequenceTensorDescriptor /*nbDims=*/sizeof(dims) / sizeof(dims[0]), /*dimA=*/dims, /*strideA=*/strides)); return CudnnRnnSequenceTensorDescriptor(parent, max_seq_length, batch_size, - data_size, data_type, nullptr, + data_size, nullptr, std::move(tensor_desc)); } static tsl::StatusOr Create( GpuExecutor* parent, int max_seq_length, int batch_size, int data_size, - const absl::Span& seq_lengths, bool time_major, + absl::Span seq_lengths, bool time_major, cudnnDataType_t data_type) { if (max_seq_length <= 0) { return tsl::Status(absl::StatusCode::kInvalidArgument, @@ -1754,13 +1758,13 @@ class CudnnRnnSequenceTensorDescriptor /*batchSize=*/batch_size, /*vectorSize=*/data_size, /*seqLengthArray=*/seq_lengths_array, /*paddingFill*/ (void*)&padding_fill)); - return CudnnRnnSequenceTensorDescriptor( - parent, max_seq_length, batch_size, data_size, data_type, - std::move(data_desc), std::move(tensor_desc)); + return CudnnRnnSequenceTensorDescriptor(parent, max_seq_length, batch_size, + data_size, std::move(data_desc), + std::move(tensor_desc)); } const cudnnTensorDescriptor_t* handles() const { return handles_.data(); } - const cudnnRNNDataDescriptor_t data_handle() const { + cudnnRNNDataDescriptor_t data_handle() const { return rnn_data_handle_.get(); } @@ -1773,7 +1777,6 @@ class CudnnRnnSequenceTensorDescriptor int max_seq_length_; int batch_size_; int data_size_; - cudnnDataType_t data_type_; TensorDescriptor handle_; RNNDataDescriptor rnn_data_handle_; std::vector handles_; // Copies of handle_. @@ -1788,8 +1791,7 @@ class CudnnRnnStateTensorDescriptor : public dnn::RnnStateTensorDescriptor { : handle_(CreateTensorDescriptor()), num_layers_(num_layers), batch_size_(batch_size), - data_size_(data_size), - data_type_(data_type) { + data_size_(data_size) { int dims[] = {num_layers, batch_size, data_size}; int strides[] = {dims[1] * dims[2], dims[2], 1}; CHECK_CUDNN_OK(cudnnSetTensorNdDescriptor( @@ -1809,7 +1811,6 @@ class CudnnRnnStateTensorDescriptor : public dnn::RnnStateTensorDescriptor { int num_layers_; int batch_size_; int data_size_; - cudnnDataType_t data_type_; SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnStateTensorDescriptor); }; @@ -4136,8 +4137,7 @@ GetCudnnOperationGraph(dnn::ConvolutionKind kind, dnn::DataType input_type, << "\nConv: " << conv_desc.describe() << "\nOp: " << op.describe() << "\nOpGraph: " << opGraph.describe(); - return std::unique_ptr( - new cudnn_frontend::OperationGraph(std::move(opGraph))); + return std::make_unique(std::move(opGraph)); } bool SideInputNeeded(dnn::ActivationMode activation_mode, double conv_scale, @@ -4465,8 +4465,7 @@ GetCudnnFusedOperationGraph( << (act_op.has_value() ? act_op->describe() : "(identity)") << "\nOpGraph: " << op_graph.describe(); - return std::unique_ptr( - new cudnn_frontend::OperationGraph(std::move(op_graph))); + return std::make_unique(std::move(op_graph)); } tsl::StatusOr> @@ -6210,7 +6209,7 @@ class CudnnExecutionPlanRunner size_t workspace_size = plan_.getWorkspaceSize(); RETURN_MSG_IF_CUDNN_ERROR(plan_); bool should_add_scalars = - scalar_input_uids_.size() > 0 && scalar_input_values_.size() > 0; + !scalar_input_uids_.empty() && !scalar_input_values_.empty(); CHECK(scalar_input_uids_.size() == scalar_input_values_.size()); std::array data_ptrs = {inputs.opaque()...}; @@ -6223,7 +6222,7 @@ class CudnnExecutionPlanRunner data_ptrs_vec.erase(data_ptrs_vec.begin() + 2); } - if (data_ptrs_vec[sizeof...(Args) - 1] == nullptr && + if (!data_ptrs_vec.empty() && data_ptrs_vec.back() == nullptr && !has_activation_output_) { data_ptrs_vec.pop_back(); } @@ -6426,7 +6425,7 @@ tsl::Status CreateOpRunners( // Frontend, but instead they get filtered out here. VLOG(4) << "Failed building runner from ExecutionPlan (i.e. failed " "getting its workspace size): " - << runner_or.status().ToString(); + << runner_or.status(); continue; } From d81dacd3ef3a321e76dc44273b04265144680d9e Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Tue, 11 Jul 2023 01:53:17 -0700 Subject: [PATCH 113/376] [XLA:GPU] Roll-forward cl/543680393: Fuse more inputs into Triton GEMMs. - Let the GEMM rewriter do more complex traversals of inputs and fuse elementwise operations and broadcasts of scalar constants. - Limit the number of parameters per fusion. - Reorganize GPU compiler pipeline: bf16 float normalization is now required both before and after Triton GEMM fusion. - Remove an autotuner config that for unknown reasons fails on Volta with new fusions. PiperOrigin-RevId: 547118033 --- tensorflow/compiler/xla/service/gpu/BUILD | 7 - .../xla/service/gpu/gemm_rewriter_triton.cc | 438 ++++++------------ .../xla/service/gpu/gemm_rewriter_triton.h | 49 +- .../service/gpu/gemm_rewriter_triton_test.cc | 147 +----- .../compiler/xla/service/gpu/gpu_compiler.cc | 37 +- .../xla/service/gpu/ir_emitter_triton.cc | 2 +- .../xla/service/gpu/ir_emitter_triton_test.cc | 152 ------ .../xla/service/gpu/triton_autotuner.cc | 11 +- 8 files changed, 171 insertions(+), 672 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index b2c7fd02999d60..2d7c4d2cb488ff 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -438,7 +438,6 @@ cc_library( "//tensorflow/tsl/platform:errors", "//tensorflow/tsl/platform:logging", "//tensorflow/tsl/platform:path", - "//tensorflow/tsl/platform:statusor", "//tensorflow/tsl/platform:tensor_float_32_hdr_lib", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", @@ -494,8 +493,6 @@ xla_test( "//tensorflow/compiler/xla:autotuning_proto_cc", "//tensorflow/compiler/xla:error_spec", "//tensorflow/compiler/xla/hlo/ir:hlo", - "//tensorflow/compiler/xla/service:pattern_matcher", - "//tensorflow/compiler/xla/service:pattern_matcher_gmock", "//tensorflow/compiler/xla/service/gpu/tests:gpu_codegen_test", "//tensorflow/compiler/xla/stream_executor:device_description", "//tensorflow/compiler/xla/stream_executor/cuda:cublas_plugin", @@ -1157,22 +1154,18 @@ cc_library( "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status", - "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/hlo/ir:hlo", "//tensorflow/compiler/xla/hlo/utils:hlo_query", "//tensorflow/compiler/xla/service:hlo_creation_utils", "//tensorflow/compiler/xla/service:hlo_pass", - "//tensorflow/compiler/xla/service:instruction_fusion", - "//tensorflow/compiler/xla/stream_executor:device_description", "//tensorflow/tsl/platform:errors", "//tensorflow/tsl/platform:status", "//tensorflow/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], diff --git a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.cc b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.cc index 1f862d0bb5851b..6b28352ccd61ab 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.cc @@ -22,15 +22,12 @@ limitations under the License. #include #include #include -#include #include #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" -#include "absl/log/check.h" #include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/autotuning.pb.h" @@ -40,7 +37,6 @@ limitations under the License. #include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_instructions.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_opcode.h" -#include "tensorflow/compiler/xla/hlo/ir/hlo_schedule.h" #include "tensorflow/compiler/xla/hlo/utils/hlo_query.h" #include "tensorflow/compiler/xla/layout.h" #include "tensorflow/compiler/xla/literal_util.h" @@ -50,12 +46,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/gpu/matmul_utils.h" #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" -#include "tensorflow/compiler/xla/service/instruction_fusion.h" #include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/stream_executor/device_description.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/tsl/platform/errors.h" @@ -64,25 +57,6 @@ limitations under the License. namespace xla { namespace gpu { - -bool TensorIterationSpec::operator==(const TensorIterationSpec& other) const { - for (int dim = 0; dim < TensorIterationSpec::kMaxDimsPerTensor; ++dim) { - if (dim_iteration_specs_[dim].size() != other[dim].size()) { - return false; - } - for (int fragment = 0; fragment < dim_iteration_specs_[dim].size(); - ++fragment) { - if (dim_iteration_specs_[dim][fragment].stride != - other[dim][fragment].stride || - dim_iteration_specs_[dim][fragment].count != - other[dim][fragment].count) { - return false; - } - } - } - return true; -} - namespace { // Batch dimensions of an operand of a dot instruction. @@ -121,10 +95,10 @@ int64_t NonContractingDimensionIndex(const HloInstruction& dot, } // Data types that are tested to work in the triton GEMM emitter. -bool IsSupportedDataType(PrimitiveType type, GpuVersion gpu_version) { +bool IsSupportedDataType(PrimitiveType t, GpuVersion gpu_version) { auto cuda_compute_capability = std::get(gpu_version); - switch (type) { + switch (t) { case PRED: case S8: case S16: @@ -140,19 +114,21 @@ bool IsSupportedDataType(PrimitiveType type, GpuVersion gpu_version) { } } -// Let input and output data volumes of a fusion grow by small amounts. -constexpr int64_t kIoToleranceBytes = 1024; - -// Difference of input and output data volumes of an instruction. -int64_t InputMinusOutputBytes(const HloInstruction& hlo) { - CHECK(!hlo.shape().IsTuple()); - int64_t output_size = ShapeUtil::ByteSizeOf(hlo.shape()); - int64_t input_size = 0; - for (const HloInstruction* operand : hlo.operands()) { - CHECK(!operand->shape().IsTuple()); - input_size += ShapeUtil::ByteSizeOf(operand->shape()); +Status RequireTritonFusibleConvert(const HloInstruction* input, + GpuVersion gpu_version) { + if (!IsSupportedDataType(input->operand(0)->shape().element_type(), + gpu_version)) { + return Unimplemented("unsupported data type"); } - return input_size - output_size; + // TODO(b/266862494): Can pick up almost any + // convert, but if it's reducing the data volume it should rather be fused + // to the output of the producer kernel. However not all operations support + // output fusion - then it should be fused here anyway! + if (ShapeUtil::ByteSizeOf(input->operand(0)->shape()) > + ShapeUtil::ByteSizeOf(input->shape())) { + return FailedPrecondition("narrowing conversion"); + } + return OkStatus(); } // Handles numbers of dimensions of a target HLO instruction @@ -166,13 +142,6 @@ class DimensionOrder { int64_t target_dim_number; int subdim_number; int64_t size; - bool operator==(const DimDescription& other) const { - return target_dim_number == other.target_dim_number && - subdim_number == other.subdim_number && size == other.size; - } - std::string ToString() const { - return absl::StrCat(target_dim_number, ":", subdim_number, ":", size); - } }; // Sequence describing all dimensions of HLO's output shape // in layout minor-to-major (physical) order. @@ -202,35 +171,34 @@ class DimensionOrder { // Transforms the DimensionOrder so that from a description of the output // of `hlo` it becomes a description of the input of `hlo`. - FusionDecision HandleInstruction(const HloInstruction* hlo) { + Status HandleInstruction(const HloInstruction* hlo) { VLOG(7) << hlo->ToString(); - if (hlo->opcode() == HloOpcode::kParameter || - hlo->opcode() == HloOpcode::kConstant) { - return FusionDecision{}; + if (hlo->opcode() == HloOpcode::kParameter) { + return OkStatus(); } else if (hlo->opcode() == HloOpcode::kTranspose || hlo->opcode() == HloOpcode::kCopy) { return HandleCopyOrTranspose(hlo); } else if (hlo->operand_count() > 0 && IsTritonSupportedElementwise( hlo->opcode(), hlo->operand(0)->shape().element_type())) { - return FusionDecision{}; + return OkStatus(); } else if (hlo->opcode() == HloOpcode::kBitcast) { return HandleBitcast(hlo); } else if (hlo->opcode() == HloOpcode::kReshape) { if (!ShapeUtil::ReshapeIsBitcast(hlo->operand(0)->shape(), hlo->shape())) { - return "Non-bitcast reshape."; + return Unimplemented("Non-bitcast reshape."); } return HandleBitcast(hlo); } else if (hlo_query::IsScalarConstant(hlo) || hlo_query::IsBroadcastOfScalarConstant(*hlo)) { // Dimension order collapses on a scalar, for simplicity leave it equal // to the output one for now. - return FusionDecision{}; + return OkStatus(); } else { - return "Unimplemented instruction."; + return Unimplemented("Instruction: %s", hlo->ToString()); } - return FusionDecision{}; + return OkStatus(); } // Get the raw data of the dimension order. @@ -242,32 +210,20 @@ class DimensionOrder { return splittable_dimension_index_; } - // Tells that two dimension orders describe the same tensor physical layout. - bool IsPhysicallyEquivalent(const DimensionOrder& other) const; - - std::string ToString() const { - return absl::StrJoin(dim_order_, "-", - [](std::string* out, const DimDescription& d) { - absl::StrAppend(out, d.ToString()); - }); - } - private: // See HandleInstruction() for the general description of Handle*(). - FusionDecision HandleBitcast(const HloInstruction* hlo); - FusionDecision HandleCopyOrTranspose(const HloInstruction* hlo); + Status HandleBitcast(const HloInstruction* hlo); + Status HandleCopyOrTranspose(const HloInstruction* hlo); DimOrderVector dim_order_; - const int64_t splittable_dimension_index_; + int64_t splittable_dimension_index_; }; -using DimIterationSpec = TensorIterationSpec::DimIterationSpec; - -TensorIterationSpec DimensionOrderToTensorIterationSpec( +DotFusionAnalysis::TensorIterationSpec DimensionOrderToTensorIterationSpec( const DimensionOrder& order) { const DimensionOrder::DimOrderVector& dim_order_vector = order.GetDimOrderVector(); - TensorIterationSpec tensor_spec; + DotFusionAnalysis::TensorIterationSpec tensor_spec; int64_t accumulated_stride = 1; for (int dim_order_index = 0; dim_order_index < dim_order_vector.size(); ++dim_order_index) { @@ -280,7 +236,8 @@ TensorIterationSpec DimensionOrderToTensorIterationSpec( continue; } - DimIterationSpec& dim_spec = tensor_spec[dim.target_dim_number]; + DotFusionAnalysis::DimIterationSpec& dim_spec = + tensor_spec[dim.target_dim_number]; if (dim_order_index > 0 && dim_order_vector[dim_order_index - 1].target_dim_number == dim.target_dim_number) { @@ -300,7 +257,7 @@ TensorIterationSpec DimensionOrderToTensorIterationSpec( accumulated_stride *= dim.size; } // Create all absent dimensions as degenerate ones to simplify later queries. - for (DimIterationSpec& dim_spec : tensor_spec) { + for (DotFusionAnalysis::DimIterationSpec& dim_spec : tensor_spec) { if (dim_spec.empty()) { dim_spec.push_back({/*stride=*/0, /*count=*/1, /*subfragments=*/{1}}); } @@ -308,11 +265,6 @@ TensorIterationSpec DimensionOrderToTensorIterationSpec( return tensor_spec; } -bool DimensionOrder::IsPhysicallyEquivalent(const DimensionOrder& other) const { - return DimensionOrderToTensorIterationSpec(*this) == - DimensionOrderToTensorIterationSpec(other); -} - DimensionOrder DimensionOrder::FromDotOperand(const HloInstruction& dot, const int operand_number, const int64_t split_k) { @@ -335,7 +287,7 @@ DimensionOrder DimensionOrder::FromDotOutput(const HloInstruction& dot) { return DimensionOrder(&dot); } -FusionDecision DimensionOrder::HandleBitcast(const HloInstruction* hlo) { +Status DimensionOrder::HandleBitcast(const HloInstruction* hlo) { const Shape& operand_shape = hlo->operand(0)->shape(); DimOrderVector operand_dim_order; operand_dim_order.reserve(dim_order_.size()); @@ -349,7 +301,7 @@ FusionDecision DimensionOrder::HandleBitcast(const HloInstruction* hlo) { ++out_dim) { if (operand_remaining_size >= out_dim->size) { if (operand_remaining_size % out_dim->size) { - return "Unsupported bitcast"; + return Unimplemented("Unsupported bitcast: %s", hlo->ToString()); } // Output dimension fragment completely fits into the operand one: // just copy it as is. @@ -367,7 +319,7 @@ FusionDecision DimensionOrder::HandleBitcast(const HloInstruction* hlo) { // If there is a remaining fragment of a previous operand dimension // assign it first. if (out_remaining_size % operand_remaining_size) { - return "Unsupported bitcast"; + return Unimplemented("Unsupported bitcast: %s", hlo->ToString()); } operand_dim_order.push_back( {out_dim->target_dim_number, subdim_index, operand_remaining_size}); @@ -385,7 +337,7 @@ FusionDecision DimensionOrder::HandleBitcast(const HloInstruction* hlo) { // assign the remainder of the output and carry over the remainder // of the operand. if (operand_dim_size % out_remaining_size) { - return "Unsupported bitcast"; + return Unimplemented("Unsupported bitcast: %s", hlo->ToString()); } operand_remaining_size = operand_dim_size / out_remaining_size; new_fragment_size = out_remaining_size; @@ -406,7 +358,7 @@ FusionDecision DimensionOrder::HandleBitcast(const HloInstruction* hlo) { int subdim_index = operand_dim_order.back().subdim_number + 1; while (operand_dim_iter != operand_shape.layout().minor_to_major().cend()) { if (operand_shape.dimensions(*operand_dim_iter) != 1) { - return "Unsupported bitcast"; + return Unimplemented("Unsupported bitcast: %s", hlo->ToString()); } operand_dim_order.push_back( {operand_dim_order.back().target_dim_number, subdim_index, 1}); @@ -415,11 +367,10 @@ FusionDecision DimensionOrder::HandleBitcast(const HloInstruction* hlo) { } dim_order_ = operand_dim_order; - return FusionDecision{}; + return OkStatus(); } -FusionDecision DimensionOrder::HandleCopyOrTranspose( - const HloInstruction* hlo) { +Status DimensionOrder::HandleCopyOrTranspose(const HloInstruction* hlo) { // Every HLO dimension can correspond to a group of subdimensions in // dim_order_. For the easier handling of permutations: group dim_order_ by // dimension, apply permutations, then finally remove the grouping. @@ -468,25 +419,25 @@ FusionDecision DimensionOrder::HandleCopyOrTranspose( dim_order_.push_back(subdim); } } - return FusionDecision{}; + return OkStatus(); } // Tells if the dimension order is supported by the triton GEMM emitter. // Only the dimension indicated by SplittableDimensionIndex() can be split // physically once by other dimensions. Other ones can be only split logically. // All subdimensions within a dimension have to be ordered. -FusionDecision RequireTritonGemmSupportedDimOrder(const DimensionOrder& order) { - std::array subdim_counters = { +Status RequireTritonGemmSupportedDimOrder(const DimensionOrder& order) { + std::array subdim_counters = { -1, -1, -1, -1}; - std::array split_counters = { + std::array split_counters = { -1, -1, -1, -1}; const DimensionOrder::DimOrderVector& dim_order_vector = order.GetDimOrderVector(); - VLOG(8) << order.ToString(); for (int i = 0; i < dim_order_vector.size(); i++) { const auto [dim_number, subdim_number, size] = dim_order_vector[i]; + VLOG(8) << dim_number << "\t" << subdim_number << "\t" << size; if (subdim_counters[dim_number] != subdim_number - 1) { - return "Transpose within a dimension."; + return Unimplemented("Transpose within a dimension."); } ++subdim_counters[dim_number]; if (size == 1) { @@ -496,179 +447,31 @@ FusionDecision RequireTritonGemmSupportedDimOrder(const DimensionOrder& order) { ++split_counters[dim_number]; if (dim_number == order.SplittableDimensionIndex()) { if (split_counters[dim_number] > 1) { - return "2nd split of a splittable dimension."; + return Unimplemented("2nd split of a splittable dimension."); } } else if (split_counters[dim_number] > 0) { - return "Split of a non-splittable dimension."; + return Unimplemented("Split of a non-splittable dimension."); } } } - return FusionDecision{}; -} - -// Tells if an instruction has no input into which it could be fused. -// More cases should be added here. -bool CanNotBeFusedIntoAProducer(const HloInstruction& hlo) { - return hlo_query::AllOperandsAreParametersOrConstants(hlo); -} - -// Tells that fusing an instruction is efficient. -bool IsInputWorthFusing(const HloInstruction& hlo) { - return hlo_query::AllOperandsAreParametersOrConstants(hlo) || - InputMinusOutputBytes(hlo) < kIoToleranceBytes; + return OkStatus(); } -// Checks if the instruction is possible and profitable to fuse. -// If so tries to transform dim_order describing output of `hlo` into a +// Transforms dim_order describing the output of `hlo` into a // description of its input if it is supported by the triton GEMM emitter. -FusionDecision CanFuse(const HloInstruction& hlo, DimensionOrder& dim_order, - const GpuVersion gpu_version) { - if (hlo.opcode() == HloOpcode::kTuple || - hlo.opcode() == HloOpcode::kGetTupleElement) { - return "Unsupported instruction."; - } - for (const HloInstruction* operand : hlo.operands()) { - if (!IsSupportedDataType(operand->shape().element_type(), gpu_version)) { - return "Unsupported input data type."; - } - } - if (!IsSupportedDataType(hlo.shape().element_type(), gpu_version)) { - return "Unsupported output data type."; - } - if (hlo.IsConstant()) { - return "Not fusing a constant."; - } - if (hlo.opcode() == HloOpcode::kBroadcast) { - return "Not fusing a broadcast."; - } - if (!CanNotBeFusedIntoAProducer(hlo) && !IsInputWorthFusing(hlo)) { - return "Not obviously profitable to fuse as input."; - } - if (FusionDecision decision = dim_order.HandleInstruction(&hlo); !decision) { - return decision; +Status CanFuse(const HloInstruction* hlo, DimensionOrder& dim_order, + const GpuVersion gpu_version) { + if (hlo->opcode() == HloOpcode::kConvert) { + return RequireTritonFusibleConvert(hlo, gpu_version); + } else if (hlo->IsElementwise() && hlo->opcode() != HloOpcode::kCopy) { + // Temporarily forbid fusing elementwise operations + // other than copy and convert. + return Unimplemented("Unsupported elementwise operation"); } + TF_RETURN_IF_ERROR(dim_order.HandleInstruction(hlo)); return RequireTritonGemmSupportedDimOrder(dim_order); } -// Clone an instruction into the fusion. -void Fuse(HloInstruction& hlo, - absl::flat_hash_map& - old_to_new_mapping, - std::vector& call_operands, - HloComputation::Builder& builder) { - if (old_to_new_mapping.contains(&hlo)) { - return; - } - VLOG(3) << "Fusing " << hlo.ToString(); - auto get_or_add_parameter = [&](HloInstruction& instr) { - if (auto it = old_to_new_mapping.find(&instr); - it != old_to_new_mapping.end()) { - return it->second; - } - call_operands.push_back(&instr); - return old_to_new_mapping - .insert({&instr, - builder.AddInstruction(HloInstruction::CreateParameter( - call_operands.size() - 1, instr.shape(), - absl::StrCat("parameter_", call_operands.size() - 1)))}) - .first->second; - }; - if (hlo.opcode() == HloOpcode::kParameter || - hlo.opcode() == HloOpcode::kGetTupleElement) { - get_or_add_parameter(hlo); - } else { - std::vector hlo_new_operands; - for (HloInstruction* operand : hlo.operands()) { - hlo_new_operands.push_back(get_or_add_parameter(*operand)); - } - old_to_new_mapping[&hlo] = builder.AddInstruction( - hlo.CloneWithNewOperands(hlo.shape(), hlo_new_operands)); - } -} - -// Tells how many new parameters does a fusion gain by fusing the operation as -// an input. -int64_t NumAddedParameters(const HloInstruction& hlo) { - // Non-scalar constant is equivalent to a parameter: one input, one output. - if (hlo.opcode() == HloOpcode::kConstant && - !ShapeUtil::IsScalar(hlo.shape())) { - return 0; - } - // All other instructions add all own inputs and remove own single output. - return hlo.operand_count() - 1; -} - -// Fuse an instruction with all its fusible inputs. -// If an input is not fusible stop there and make a parameter of the new -// fusion, otherwise put it onto stack and check its own inputs first. -void FuseWithInputsRecursively( - HloInstruction* root, DimensionOrder root_dim_order, - // Dimension orders describing inputs of corresponding instructions. - absl::flat_hash_map& dim_orders, - const GpuVersion gpu_version, - absl::flat_hash_map& - old_to_new_mapping, - std::vector& call_operands, - HloComputation::Builder& builder) { - absl::flat_hash_set visited; - std::stack to_fuse; - // Instructions at the edge 'to_fuse' that can either get fused too or - // become parameters of the fusion. Used to track the number of parameters - // of the fusion. - absl::flat_hash_set inputs; - // Currently only one physically unique dim order per scope is supported. - // Let it change while the scope has one input; afterwards require all - // of them to be physically compatible. - const HloInstruction* reference_dim_order_hlo = nullptr; - if (CanFuse(*root, root_dim_order, gpu_version)) { - to_fuse.push(root); - inputs.insert(root->operands().begin(), root->operands().end()); - // root_dim_order went through output -> input transformation here. - CHECK(dim_orders.insert({root, root_dim_order}).second) << root->ToString(); - } - visited.insert(root); - while (!to_fuse.empty()) { - bool top_is_ready_to_fuse = true; - HloInstruction* hlo = to_fuse.top(); - if (reference_dim_order_hlo == nullptr && hlo->operand_count() > 1) { - reference_dim_order_hlo = hlo; - } - for (HloInstruction* operand : hlo->mutable_operands()) { - if (visited.insert(operand).second) { - // Stop adding new parameters. - if (inputs.size() >= DotFusionAnalysis::kMaxParameterPerScope && - NumAddedParameters(*operand) > 0) { - continue; - } - // Operand's output is described by its consumer's input. - DimensionOrder operand_dim_order(dim_orders.at(hlo)); - // CanFuse() makes output -> input transformation of - // operand_dim_order if succeeds. - if (CanFuse(*operand, operand_dim_order, gpu_version)) { - if (reference_dim_order_hlo != nullptr && - !operand_dim_order.IsPhysicallyEquivalent( - dim_orders.at(reference_dim_order_hlo))) { - continue; - } - to_fuse.push(operand); - if (operand->opcode() != HloOpcode::kParameter) { - inputs.erase(operand); - } - inputs.insert(operand->operands().begin(), operand->operands().end()); - // Save the dimension order description of operand's input. - CHECK(dim_orders.insert({operand, operand_dim_order}).second) - << operand->ToString(); - top_is_ready_to_fuse = false; - } - } - } - if (top_is_ready_to_fuse) { - Fuse(*hlo, old_to_new_mapping, call_operands, builder); - to_fuse.pop(); - } - } -} - // Extracts into fused computations parts of HLO graph including dot() // operations that can target the triton GEMM emitter. class GemmRewriterTritonVisitor : public DfsHloRewriteVisitor { @@ -680,9 +483,8 @@ class GemmRewriterTritonVisitor : public DfsHloRewriteVisitor { // and replaces the original dot() with a call to the computation. Status HandleDot(HloInstruction* dot) override { VLOG(5) << dot->ToString(); - FusionDecision can_handle = CanTritonHandleGEMM(*dot, gpu_version_); - if (!can_handle) { - VLOG(3) << can_handle.Explain(); + + if (!CanTritonHandleGEMM(*dot, gpu_version_)) { return OkStatus(); } @@ -701,28 +503,72 @@ class GemmRewriterTritonVisitor : public DfsHloRewriteVisitor { std::string suggested_name = absl::StrCat("triton_gemm_", dot->name()); HloComputation::Builder builder( absl::StrCat(suggested_name, "_computation")); - std::vector call_operands; // Original instruction -> fused one. absl::flat_hash_map old_to_new_mapping; - - auto fuse_inputs = [&](int operand_number) { - absl::flat_hash_map dim_orders; - int operand_count_before = call_operands.size(); - // Direct dot inputs have well defined dimension orders. - FuseWithInputsRecursively( - dot->mutable_operand(operand_number), - DimensionOrder::FromDotOperand(*dot, operand_number), dim_orders, - gpu_version_, old_to_new_mapping, call_operands, builder); - return call_operands.size() - operand_count_before; - }; - // Separate traversal from LHS and RHS inputs of the dot: they use - // differently shaped tiles but may go through same HLO graph nodes. - TF_RET_CHECK(fuse_inputs(0) <= DotFusionAnalysis::kMaxParameterPerScope); - TF_RET_CHECK(fuse_inputs(1) <= DotFusionAnalysis::kMaxParameterPerScope); - - Fuse(*dot, old_to_new_mapping, call_operands, builder); - + absl::flat_hash_set visited; + std::vector call_operands; + // Traverse and fuse dot() inputs bottom-up starting from direct operands. + // If an input is not fusible stop there and make it a parameter of the new + // fusion, otherwise put it onto stack and check its own inputs first. + std::stack to_fuse; + // Dimension orders describing inputs of corresponding instructions. + absl::flat_hash_map dim_orders; + to_fuse.push(dot); + while (!to_fuse.empty()) { + bool top_is_ready_to_fuse = true; + HloInstruction* hlo = to_fuse.top(); + for (HloInstruction* operand : hlo->mutable_operands()) { + if (visited.insert(operand).second) { + DimensionOrder operand_dim_order = [&] { + // Direct dot inputs are described by default dimension orders. + if (operand == dot->operand(0)) { + return DimensionOrder::FromDotOperand(*dot, 0); + } else if (operand == dot->operand(1)) { + return DimensionOrder::FromDotOperand(*dot, 1); + } + // Otherwise operand's output is described by its consumer's input. + return DimensionOrder(dim_orders.at(hlo)); + }(); + // CanFuse() makes output -> input transformation of + // operand_dim_order if succeeds. + if (CanFuse(operand, operand_dim_order, gpu_version_).ok()) { + VLOG(3) << "Fusing " << operand->ToString(); + to_fuse.push(operand); + // Save the dimension order description of operand's input. + dim_orders.insert({operand, operand_dim_order}); + top_is_ready_to_fuse = false; + } + } + } + if (top_is_ready_to_fuse) { + if (hlo->opcode() == HloOpcode::kParameter || + hlo->opcode() == HloOpcode::kGetTupleElement) { + old_to_new_mapping[hlo] = + builder.AddInstruction(HloInstruction::CreateParameter( + call_operands.size(), hlo->shape(), + absl::StrCat("parameter_", call_operands.size()))); + call_operands.push_back(hlo); + } else { + std::vector hlo_new_operands; + for (HloInstruction* operand : hlo->operands()) { + const auto iter = old_to_new_mapping.find(operand); + if (iter != old_to_new_mapping.end()) { + hlo_new_operands.push_back(iter->second); + } else { + hlo_new_operands.push_back( + builder.AddInstruction(HloInstruction::CreateParameter( + call_operands.size(), operand->shape(), + absl::StrCat("parameter_", call_operands.size())))); + call_operands.push_back(operand); + } + } + old_to_new_mapping[hlo] = builder.AddInstruction( + hlo->CloneWithNewOperands(hlo->shape(), hlo_new_operands)); + } + to_fuse.pop(); + } + } HloComputation* computation = dot->GetModule()->AddComputationAndUnifyNamesAndIds(builder.Build(), /*is_entry=*/false); @@ -746,7 +592,7 @@ class GemmRewriterTritonVisitor : public DfsHloRewriteVisitor { } else { TF_RETURN_IF_ERROR(ReplaceInstruction(dot, dot_fusion)); } - XLA_VLOG_LINES(5, computation->ToString()); + VLOG(5) << computation->ToString(); return OkStatus(); } @@ -797,7 +643,7 @@ StatusOr MakeSplitKOperand( for (const HloInstruction* param : analysis.ScopeParameters(scope)) { // If an operand of dot does not read any parameters its K dimension // does not need analysis for fragmentation. - const DimIterationSpec* spec = + const DotFusionAnalysis::DimIterationSpec* spec = analysis.IterSpec(scope, param, contracting_dim_idx); // Split contracting dimension is not implemented yet. CHECK_EQ(spec->size(), 1); @@ -1039,8 +885,8 @@ DotFusionAnalysis::DotFusionAnalysis(const HloComputation* dot_computation, absl::flat_hash_map dim_orders; DimensionOrder dot_operand_dim_order = DimensionOrder::FromDotOperand(*dot, operand_number, split_k); - CHECK(dot_operand_dim_order.HandleInstruction(dot_operand)); - CHECK(RequireTritonGemmSupportedDimOrder(dot_operand_dim_order)) + TF_CHECK_OK(dot_operand_dim_order.HandleInstruction(dot_operand)); + TF_CHECK_OK(RequireTritonGemmSupportedDimOrder(dot_operand_dim_order)) << dot_computation->ToString(); dim_orders.insert({dot_operand, dot_operand_dim_order}); visited.insert(dot_operand); @@ -1061,18 +907,14 @@ DotFusionAnalysis::DotFusionAnalysis(const HloComputation* dot_computation, {hlo_operand, DimensionOrder(dim_orders.at(hlo))}); CHECK(inserted); DimensionOrder& hlo_operand_dim_order = it->second; - CHECK(hlo_operand_dim_order.HandleInstruction(hlo_operand)); - CHECK(RequireTritonGemmSupportedDimOrder(hlo_operand_dim_order)) + TF_CHECK_OK(hlo_operand_dim_order.HandleInstruction(hlo_operand)); + TF_CHECK_OK(RequireTritonGemmSupportedDimOrder(hlo_operand_dim_order)) << " " << dot_computation->ToString(); to_process.push(hlo_operand); } } - // For now all parameters of one scope have to use the same tiling. for (const HloInstruction* parameter : parameters_[scope]) { - CHECK(dim_orders.at(parameter).IsPhysicallyEquivalent( - dim_orders.at(*parameters_[scope].cbegin()))) - << dot_computation->ToString(); iter_specs_[scope][parameter] = DimensionOrderToTensorIterationSpec(dim_orders.at(parameter)); } @@ -1084,22 +926,22 @@ DotFusionAnalysis::DotFusionAnalysis(const HloComputation* dot_computation, .second); } -const DimIterationSpec* DotFusionAnalysis::IterSpec( +const DotFusionAnalysis::DimIterationSpec* DotFusionAnalysis::IterSpec( const DotFusionAnalysis::Scope scope, const HloInstruction* hlo, const int dimension) const { auto ret = iter_specs_.at(scope).find(hlo); if (ret != iter_specs_.at(scope).end()) { - return &ret->second[dimension]; + return &ret->second.at(dimension); } return nullptr; } -FusionDecision CanTritonHandleGEMM(const HloInstruction& dot, - const GpuVersion gpu_version) { +bool CanTritonHandleGEMM(const HloInstruction& dot, + const GpuVersion gpu_version) { if (dot.opcode() != HloOpcode::kDot || absl::c_any_of(dot.precision_config().operand_precision(), [](int x) { return x != PrecisionConfig::DEFAULT; })) { - return "Non-default precision."; + return false; } auto supported_output_type = [&](const PrimitiveType t) { @@ -1119,21 +961,21 @@ FusionDecision CanTritonHandleGEMM(const HloInstruction& dot, // TODO(b/266862493): Support more output types. if (!supported_output_type(dot.shape().element_type())) { - return "Unsupported output data type."; + return false; } if (!IsSupportedDataType(dot.operand(0)->shape().element_type(), gpu_version) || !IsSupportedDataType(dot.operand(1)->shape().element_type(), gpu_version)) { - return "Unsupported input data type."; + return false; } const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers(); // TODO(b/269580541): support multiple batch dimensions. if (dim_numbers.lhs_batch_dimensions().size() > 1) { - return "Multiple batch dimensions."; + return false; } // Cases where lhs or rhs have no non-contracting dims are not handled. @@ -1143,10 +985,10 @@ FusionDecision CanTritonHandleGEMM(const HloInstruction& dot, dim_numbers.rhs_batch_dimensions().size() + dim_numbers.rhs_contracting_dimensions().size() == dot.operand(1)->shape().rank()) { - return "No non-contracting dimensions."; + return false; } - return FusionDecision{}; + return true; } bool ShouldTritonHandleGEMM(const HloInstruction& dot, @@ -1166,7 +1008,7 @@ bool ShouldTritonHandleGEMM(const HloInstruction& dot, while (!queue.empty()) { const HloInstruction* current = queue.front(); queue.pop(); - if (!CanFuse(*current, dim_order, gpu_version)) { + if (!CanFuse(current, dim_order, gpu_version).ok()) { continue; } // Stop as soon as a profitable operation is fused. diff --git a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.h b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.h index 0afc939b43ede2..715c79d9114659 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.h +++ b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.h @@ -29,7 +29,6 @@ limitations under the License. #include "tensorflow/compiler/xla/hlo/ir/hlo_opcode.h" #include "tensorflow/compiler/xla/service/gpu/gpu_types.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" -#include "tensorflow/compiler/xla/service/instruction_fusion.h" namespace xla { namespace gpu { @@ -53,13 +52,13 @@ Status MakeDotSplitKBatch(HloInstruction* dot_fusion, const AutotuneResult::TritonGemmKey& tiling); // Filters GEMMs which can be handled using Triton. -FusionDecision CanTritonHandleGEMM(const HloInstruction&, - GpuVersion gpu_version); +bool CanTritonHandleGEMM(const HloInstruction&, GpuVersion gpu_version); // Filters GEMMs which are better to handle using Triton. bool ShouldTritonHandleGEMM(const HloInstruction&, GpuVersion gpu_version); -class TensorIterationSpec { +// Analysis of iteration of HLO shapes within a fusion around dot(). +class DotFusionAnalysis { public: // Description of basic iteration: `count` elements separated by `stride`. struct IterationSpecFragment { @@ -69,42 +68,16 @@ class TensorIterationSpec { // of several HLO dimensions. Product of subfragments equals `count`. std::vector subfragments; }; + // Description of complex iteration over a sequence of several strides. // Describes a logically contiguous dimension of a tensor physically // separated into multiple fragments by other dimensions. using DimIterationSpec = std::vector; // At most: contracting, non-contracting, split-K, another batch. - static constexpr int kMaxDimsPerTensor = 4; - using StorageType = std::array; - - const DimIterationSpec& operator[](int dimension) const { - return dim_iteration_specs_[dimension]; - } - - DimIterationSpec& operator[](int dimension) { - return dim_iteration_specs_[dimension]; - } - - // Compares physical layouts of tensors ignoring subfragments of dimensions. - bool operator==(const TensorIterationSpec& other) const; - - StorageType::iterator begin() { return dim_iteration_specs_.begin(); } - StorageType::iterator end() { return dim_iteration_specs_.end(); } - StorageType::const_iterator cbegin() const { - return dim_iteration_specs_.cbegin(); - } - StorageType::const_iterator cend() const { - return dim_iteration_specs_.cend(); - } - - private: - StorageType dim_iteration_specs_; -}; + static const int kMaxDimsPerTensor = 4; + using TensorIterationSpec = std::array; -// Analysis of iteration of HLO shapes within a fusion around dot(). -class DotFusionAnalysis { - public: // Execute analysis of dot fusion computation. // split_k indicates whether this operation was converted to the split-K // form and tells the analysis how to interpret the batch dimensions. @@ -115,15 +88,9 @@ class DotFusionAnalysis { // defined by left operand, right operand and output. enum class Scope { LHS = 0, RHS = 1, OUTPUT = 2 }; - // Every parameter requires a separate piece of shared memory for asynchronous - // loads. Multiple parameters are approximately equivalent to multiple - // pipeline stages. - static constexpr int kMaxParameterPerScope = 4; - // Scope -> HLO -> dot dimension number -> iteration spec at the HLO's output. - const TensorIterationSpec::DimIterationSpec* IterSpec(Scope scope, - const HloInstruction*, - int dimension) const; + const DimIterationSpec* IterSpec(Scope scope, const HloInstruction*, + int dimension) const; // Parameter HLO instructions used in a scope of `dot`. const absl::flat_hash_set& ScopeParameters( const Scope scope) const { diff --git a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton_test.cc b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton_test.cc index b154efabe1ef0a..d02faa5b3abdc9 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton_test.cc @@ -94,7 +94,7 @@ ENTRY e { GmockMatch(m::Fusion(m::Parameter(), m::Parameter()))); } -TEST_F(GemmRewriterTritonTest, DoNotFuseConstants) { +TEST_F(GemmRewriterTritonTest, DoNotFuseConstant) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(R"( HloModule m @@ -102,14 +102,14 @@ HloModule m ENTRY e { p0 = s8[60,5] parameter(0) c0 = f16[60,5] convert(p0) - cst1 = f16[] constant(1234) - r1 = f16[5,120] broadcast(cst1) + cst1 = f16[600] constant({...}) + r1 = f16[5,120] reshape(cst1) ROOT d = f16[60,120] dot(c0, r1), lhs_contracting_dims={1}, rhs_contracting_dims={0} })")); EXPECT_TRUE(GemmRewriterTriton(gpu_version_).Run(module.get()).value()); EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch(m::Fusion(m::Parameter(), m::Broadcast()))); + GmockMatch(m::Fusion(m::Constant(), m::Parameter()))); } using TritonDotAnalysisTest = HloTestBase; @@ -793,145 +793,6 @@ ENTRY e { EXPECT_TRUE(GemmRewriterTriton(cc).Run(module.get()).value()); } -TEST_F(GemmRewriterTritonTest, DoNotFuseIncompatibleDimOrders) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( -HloModule m - -ENTRY e { - p0 = f16[5,3] parameter(0) - p1 = f16[5,7] parameter(1) - p2 = f16[7,5] parameter(2) - t = f16[5,7] transpose(p2), dimensions={1,0} - a = f16[5,7] add(t, p1) - ROOT d = f16[3,7] dot(p0, a), - lhs_contracting_dims={0}, rhs_contracting_dims={0} -})")); - - EXPECT_TRUE(GemmRewriterTriton(gpu_version_).Run(module.get()).value()); - EXPECT_THAT( - module->entry_computation()->root_instruction(), - GmockMatch(m::Fusion(m::Parameter(), m::Parameter(), m::Transpose()))); -} - -TEST_F(GemmRewriterTritonTest, DoNotFuseTooManyParameters) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( -ENTRY e { - tmp_0 = f32[] constant(1) - tmp_1 = f32[3,49]{1,0} broadcast(tmp_0), dimensions={} - tmp_2 = f32[3,49]{1,0} parameter(6) - tmp_3 = f32[] constant(0) - tmp_4 = f32[3,49]{1,0} broadcast(tmp_3), dimensions={} - tmp_5 = pred[3,49]{1,0} compare(tmp_2, tmp_4), direction=GT - tmp_6 = f32[3,49]{1,0} convert(tmp_5) - tmp_7 = f32[3,49]{1,0} subtract(tmp_1, tmp_6) - tmp_8 = s32[] parameter(13) - tmp_9 = f32[] convert(tmp_8) - tmp_10 = f32[] maximum(tmp_9, tmp_0) - tmp_11 = f32[] divide(tmp_3, tmp_10) - tmp_12 = f32[3,49]{1,0} broadcast(tmp_11), dimensions={} - tmp_13 = pred[3,49]{1,0} parameter(7) - tmp_14 = pred[3,49]{1,0} parameter(10) - tmp_15 = pred[3,49]{1,0} and(tmp_13, tmp_14) - tmp_16 = f32[3,49]{1,0} convert(tmp_15) - tmp_17 = f32[3,49]{1,0} multiply(tmp_12, tmp_16) - tmp_18 = f32[3,49]{1,0} negate(tmp_17) - tmp_19 = f32[3,49]{1,0} multiply(tmp_7, tmp_18) - tmp_20 = f32[3,49]{1,0} parameter(19) - tmp_21 = f32[3,49]{1,0} subtract(tmp_1, tmp_20) - tmp_22 = f32[3,49]{1,0} divide(tmp_19, tmp_21) - tmp_23 = f32[3,49]{1,0} negate(tmp_22) - tmp_24 = f32[3,49]{1,0} negate(tmp_6) - tmp_25 = f32[3,49]{1,0} multiply(tmp_24, tmp_17) - tmp_26 = f32[3,49]{1,0} divide(tmp_25, tmp_20) - tmp_27 = f32[3,49]{1,0} add(tmp_23, tmp_26) - tmp_28 = f32[3,49]{1,0} parameter(18) - tmp_29 = f32[3,49]{1,0} multiply(tmp_27, tmp_28) - tmp_30 = f32[3,49]{1,0} parameter(17) - tmp_31 = f32[3,49]{1,0} multiply(tmp_29, tmp_30) - tmp_32 = f32[3,49]{1,0} parameter(16) - tmp_33 = f32[3,49]{1,0} multiply(tmp_31, tmp_32) - tmp_34 = f32[3,49]{1,0} parameter(15) - tmp_35 = f32[3,49]{1,0} add(tmp_33, tmp_34) - tmp_36 = f32[3,49]{1,0} parameter(14) - tmp_37 = f32[3,49]{1,0} add(tmp_35, tmp_36) - tmp_38 = f32[1,1]{1,0} constant({ {0} }) - tmp_39 = f32[1,1]{1,0} broadcast(tmp_38), dimensions={0,1} - tmp_40 = f32[] reshape(tmp_39) - tmp_41 = f32[3,32]{1,0} broadcast(tmp_40), dimensions={} - tmp_42 = u32[48]{0} parameter(11) - tmp_43 = u32[48]{0} parameter(5) - tmp_44 = u32[96]{0} concatenate(tmp_42, tmp_43), dimensions={0} - tmp_45 = u32[3,32]{1,0} reshape(tmp_44) - tmp_46 = u32[96]{0} reshape(tmp_45) - tmp_47 = u32[] constant(1) - tmp_48 = u32[3,32]{1,0} broadcast(tmp_47), dimensions={} - tmp_49 = u32[96]{0} reshape(tmp_48) - tmp_50 = u32[96]{0} shift-right-logical(tmp_46, tmp_49) - tmp_51 = u32[3,32]{1,0} reshape(tmp_50) - tmp_52 = u32[3,32]{1,0} or(tmp_51, tmp_48) - tmp_53 = f32[3,32]{1,0} bitcast-convert(tmp_52) - tmp_54 = f32[3,32]{1,0} broadcast(tmp_0), dimensions={} - tmp_55 = f32[3,32]{1,0} subtract(tmp_53, tmp_54) - tmp_56 = f32[1,1]{1,0} constant({ {1} }) - tmp_57 = f32[1,1]{1,0} broadcast(tmp_56), dimensions={0,1} - tmp_58 = f32[] reshape(tmp_57) - tmp_59 = f32[3,32]{1,0} broadcast(tmp_58), dimensions={} - tmp_60 = f32[3,32]{1,0} multiply(tmp_55, tmp_59) - tmp_61 = f32[3,32]{1,0} add(tmp_60, tmp_41) - tmp_62 = f32[3,32]{1,0} maximum(tmp_41, tmp_61) - tmp_63 = f32[3,32]{1,0} broadcast(tmp_3), dimensions={} - tmp_64 = pred[3,32]{1,0} compare(tmp_62, tmp_63), direction=LT - tmp_65 = f32[3,32]{1,0} convert(tmp_64) - tmp_66 = f32[3,49]{1,0} parameter(9) - tmp_67 = f32[49]{0} parameter(4) - tmp_68 = f32[3,49]{1,0} broadcast(tmp_67), dimensions={1} - tmp_69 = f32[3,49]{1,0} add(tmp_66, tmp_68) - tmp_70 = f32[1,49]{1,0} parameter(12) - tmp_71 = f32[1,49]{1,0} broadcast(tmp_0), dimensions={} - tmp_72 = f32[1,49]{1,0} divide(tmp_70, tmp_71) - tmp_73 = f32[1,49]{1,0} broadcast(tmp_72), dimensions={0,1} - tmp_74 = f32[49]{0} reshape(tmp_73) - tmp_75 = f32[3,49]{1,0} broadcast(tmp_74), dimensions={1} - tmp_76 = f32[3,49]{1,0} subtract(tmp_69, tmp_75) - tmp_77 = f32[1,49]{1,0} parameter(3) - tmp_78 = f32[1,49]{1,0} parameter(8) - tmp_79 = f32[1,49]{1,0} divide(tmp_78, tmp_71) - tmp_80 = f32[1,49]{1,0} multiply(tmp_72, tmp_72) - tmp_81 = f32[1,49]{1,0} subtract(tmp_79, tmp_80) - tmp_82 = f32[1,49]{1,0} add(tmp_81, tmp_71) - tmp_83 = f32[1,49]{1,0} rsqrt(tmp_82) - tmp_84 = f32[1,49]{1,0} multiply(tmp_77, tmp_83) - tmp_85 = f32[1,49]{1,0} broadcast(tmp_84), dimensions={0,1} - tmp_86 = f32[49]{0} reshape(tmp_85) - tmp_87 = f32[3,49]{1,0} broadcast(tmp_86), dimensions={1} - tmp_88 = f32[3,49]{1,0} multiply(tmp_76, tmp_87) - tmp_89 = f32[1,49]{1,0} parameter(2) - tmp_90 = f32[1,49]{1,0} broadcast(tmp_89), dimensions={0,1} - tmp_91 = f32[49]{0} reshape(tmp_90) - tmp_92 = f32[3,49]{1,0} broadcast(tmp_91), dimensions={1} - tmp_93 = f32[3,49]{1,0} add(tmp_88, tmp_92) - tmp_94 = f32[49,32]{1,0} parameter(1) - tmp_95 = f32[3,32]{1,0} dot(tmp_93, tmp_94), lhs_contracting_dims={1}, rhs_contracting_dims={0} - tmp_96 = f32[32]{0} parameter(0) - tmp_97 = f32[3,32]{1,0} broadcast(tmp_96), dimensions={1} - tmp_98 = f32[3,32]{1,0} add(tmp_95, tmp_97) - tmp_99 = f32[3,32]{1,0} multiply(tmp_65, tmp_98) - tmp_100 = f32[3,32]{1,0} divide(tmp_99, tmp_63) - tmp_101 = f32[3,32]{1,0} maximum(tmp_100, tmp_63) - ROOT tmp_102 = f32[49,32]{1,0} dot(tmp_37, tmp_101), lhs_contracting_dims={0}, rhs_contracting_dims={0} -})")); - - EXPECT_TRUE(GemmRewriterTriton(gpu_version_).Run(module.get()).value()); - EXPECT_EQ(module->entry_computation()->root_instruction()->opcode(), - HloOpcode::kFusion); - EXPECT_EQ(module->entry_computation()->root_instruction()->fusion_kind(), - HloInstruction::FusionKind::kCustom); - EXPECT_LE(module->entry_computation()->root_instruction()->operand_count(), - DotFusionAnalysis::kMaxParameterPerScope * 2); -} - } // namespace } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index b3944952ac68da..f490f9b127e21a 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -973,29 +973,6 @@ Status GpuCompiler::OptimizeHloPostLayoutAssignment( }); } - GpuFloatSupport bf16_support(BF16); - GpuFloatSupport f8e5m2_support(F8E5M2); - GpuFloatSupport f8e4m3fn_support(F8E4M3FN); - FloatSupport f8e4m3b11fnuz_support(F8E4M3B11FNUZ); - FloatSupport f8e5m2fnuz_support(F8E5M2FNUZ); - FloatSupport f8e4m3fnuz_support(F8E4M3FNUZ); - - auto add_float_normalization = [&](HloPassPipeline& pipeline) { - auto& sub_pipeline = - pipeline.AddPass("float_normalization"); - sub_pipeline.AddPass(&bf16_support); - sub_pipeline.AddPass(&f8e5m2_support); - sub_pipeline.AddPass(&f8e4m3fn_support); - sub_pipeline.AddPass(&f8e4m3b11fnuz_support); - sub_pipeline.AddPass(&f8e5m2fnuz_support); - sub_pipeline.AddPass(&f8e4m3fnuz_support); - // Remove `f32 -> bf16 -> f32` casts inserted by bf16 normalization. - if (debug_options.xla_gpu_simplify_all_fp_conversions()) { - sub_pipeline.AddPass(); - } - }; - add_float_normalization(pipeline); - // By default use an externally provided thread pool. tsl::thread::ThreadPool* thread_pool = options.thread_pool; std::optional overriding_thread_pool; @@ -1017,8 +994,18 @@ Status GpuCompiler::OptimizeHloPostLayoutAssignment( &pipeline, hlo_module, stream_exec, debug_options, options, gpu_target_config, autotune_results, thread_pool)); - // The Triton autotuner can insert new reductions. - add_float_normalization(pipeline); + GpuFloatSupport bf16_support(BF16); + pipeline.AddPass(&bf16_support); + GpuFloatSupport f8e5m2_support(F8E5M2); + pipeline.AddPass(&f8e5m2_support); + GpuFloatSupport f8e4m3fn_support(F8E4M3FN); + pipeline.AddPass(&f8e4m3fn_support); + FloatSupport f8e4m3b11fnuz_support(F8E4M3B11FNUZ); + pipeline.AddPass(&f8e4m3b11fnuz_support); + FloatSupport f8e5m2fnuz_support(F8E5M2FNUZ); + pipeline.AddPass(&f8e5m2fnuz_support); + FloatSupport f8e4m3fnuz_support(F8E4M3FNUZ); + pipeline.AddPass(&f8e4m3fnuz_support); // Remove `f32 -> bf16 -> f32` casts inserted by bf16 normalization. if (debug_options.xla_gpu_simplify_all_fp_conversions()) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_triton.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_triton.cc index 7c9cd87953a848..709f3e40b52c3f 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_triton.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_triton.cc @@ -792,7 +792,7 @@ StatusOr MatMulImpl( if (!analysis.ScopeParameters(DotFusionAnalysis::Scope::LHS).empty()) { const HloInstruction* lhs_param0 = *analysis.ScopeParameters(DotFusionAnalysis::Scope::LHS).begin(); - const TensorIterationSpec::DimIterationSpec* lhs_nc_iter_spec = + const DotFusionAnalysis::DimIterationSpec* lhs_nc_iter_spec = analysis.IterSpec(DotFusionAnalysis::Scope::LHS, lhs_param0, lhs_noncontracting_dim_idx); lhs_nc_split = lhs_nc_iter_spec->size() > 1; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_triton_test.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_triton_test.cc index 86d7209de81114..fc4bb7204c1632 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_triton_test.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_triton_test.cc @@ -25,14 +25,11 @@ limitations under the License. #include "tensorflow/compiler/xla/autotuning.pb.h" #include "tensorflow/compiler/xla/error_spec.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_computation.h" -#include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/gpu/gpu_device_info_for_tests.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h" #include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" -#include "tensorflow/compiler/xla/service/pattern_matcher.h" -#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h" #include "tensorflow/compiler/xla/stream_executor/device_description.h" #include "tensorflow/compiler/xla/tests/verified_hlo_module.h" #include "tensorflow/tsl/lib/core/status_test_util.h" @@ -45,8 +42,6 @@ namespace xla { namespace gpu { namespace { -namespace m = ::xla::match; - class TritonGemmNoTF32Test : public GpuCodegenTest { public: void SetUp() override { @@ -720,153 +715,6 @@ ENTRY e { EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-6, /*arel=*/1e-6})); } -TEST_F(TritonGemmTest, BinaryOperationWithSmallInputsIsFused) { - const std::string kHloText = R"( -HloModule m - -ENTRY e { - p0 = s8[7,3] parameter(0) - p1 = f32[3,16] parameter(1) - p2 = f32[3,16] parameter(2) - e = f32[3,16] exponential(p1) - a = f32[3,16] add(e, p2) - c = f32[7,3] convert(p0) - ROOT d = f32[7,16] dot(c, a), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -})"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - GetOptimizedModule(kHloText)); - - EXPECT_THAT( - module->entry_computation()->root_instruction(), - GmockMatch(m::Fusion(m::Parameter(), m::Parameter(), m::Parameter()) - .WithFusionKind(HloInstruction::FusionKind::kCustom))); - - EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-1, /*arel=*/1e-3})); -} - -TEST_F(TritonGemmTest, BinaryOperationWithLargeInputsIsNotFused) { - const std::string kHloText = R"( -HloModule m - -ENTRY e { - p0 = f16[333,1000] parameter(0) - p1 = f32[1000,333] parameter(1) - p1n = f32[1000,333] negate(p1) - p2 = f32[1000,333] parameter(2) - p2n = f32[1000,333] negate(p2) - s = f32[1000,333] subtract(p1n, p2n) - c = f32[333,1000] convert(p0) - ROOT d = f32[1000,1000] dot(s, c), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -})"; - - MatchOptimizedHlo(kHloText, R"( -; CHECK: fused_computation -; CHECK: negate -; CHECK: negate -; CHECK: ROOT -; CHECK-SAME: subtract -; CHECK: ENTRY -; CHECK: kLoop -; CHECK: kCustom -)"); - - EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-1, /*arel=*/1e-3})); -} - -TEST_F(TritonGemmTest, BinaryOperationOnLargeParametersIsFused) { - const std::string kHloText = R"( -HloModule m - -ENTRY e { - p0 = f16[1000,111] parameter(0) - p1 = f32[111,10000] parameter(1) - p2 = f32[111,10000] parameter(2) - s = f32[111,10000] subtract(p1, p2) - c = f32[1000,111] convert(p0) - ROOT d = f32[10000,1000] dot(s, c), - lhs_contracting_dims={0}, rhs_contracting_dims={1} -})"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - GetOptimizedModule(kHloText)); - - EXPECT_THAT( - module->entry_computation()->root_instruction(), - GmockMatch(m::Fusion(m::Parameter(), m::Parameter(), m::Parameter()) - .WithFusionKind(HloInstruction::FusionKind::kCustom))); - - EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-1, /*arel=*/1e-3})); -} - -TEST_F(TritonGemmTest, LinkingLibdeviceTwiceWorks) { - const std::string kHloText = R"( -HloModule m - -ENTRY e { - p0 = s8[7,3] parameter(0) - c0 = f32[7,3] convert(p0) - e0 = f32[7,3] exponential(c0) - p1 = f32[3,16] parameter(1) - e1 = f32[3,16] exponential(p1) - d0 = f32[7,16] dot(c0, e1), - lhs_contracting_dims={1}, rhs_contracting_dims={0} - d1 = f32[7,16] dot(e0, p1), - lhs_contracting_dims={1}, rhs_contracting_dims={0} - ROOT a = f32[7,16] add(d0, d1) -})"; - - MatchOptimizedHlo(kHloText, R"( -; CHECK: ENTRY -; CHECK-NEXT: parameter -; CHECK-NEXT: parameter -; CHECK-NEXT: kCustom -; CHECK-NEXT: kCustom -; CHECK-NEXT: ROOT -; CHECK-SAME: add -)"); - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - GetOptimizedModule(kHloText)); - - EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch(m::Add( - m::Fusion(m::Parameter(), m::Parameter()) - .WithFusionKind(HloInstruction::FusionKind::kCustom), - m::Fusion(m::Parameter(), m::Parameter()) - .WithFusionKind(HloInstruction::FusionKind::kCustom)))); - - EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2})); -} - -TEST_F(TritonGemmTest, BroadcastOfConstantIsNotFused) { - const std::string kHloText = R"( -HloModule m - -ENTRY e { - p0 = f16[70,30] parameter(0) - p0c = f32[70,30] convert(p0) - constant_3663 = f32[] constant(4321) - bc0 = f32[30,5] broadcast(constant_3663) - p1 = f32[30,5] parameter(1) - a = f32[30,5] add(p1, bc0) - ROOT d = f32[70,5] dot(p0c, a), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -})"; - - MatchOptimizedHlo(kHloText, R"( -; CHECK: ENTRY -; CHECK: constant -; CHECK: broadcast -; CHECK: fusion -; CHECK-SAME: kind=kCustom -)"); - - EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/2e-3, /*arel=*/2e-3})); -} - TEST_F(TritonGemmTest, Naming) { const char* hlo_text = R"( HloModule t diff --git a/tensorflow/compiler/xla/service/gpu/triton_autotuner.cc b/tensorflow/compiler/xla/service/gpu/triton_autotuner.cc index b8b8b5f6719931..440a9611a8fe27 100644 --- a/tensorflow/compiler/xla/service/gpu/triton_autotuner.cc +++ b/tensorflow/compiler/xla/service/gpu/triton_autotuner.cc @@ -418,11 +418,12 @@ std::vector GetExhaustiveMatmulAutotuneConfigs( std::vector GetFixedMatmulAutotuneConfigs( const se::CudaComputeCapability compute_capability) { std::vector configs = { - GemmKey(32, 32, 256, 1, 1, 4), GemmKey(64, 32, 32, 16, 1, 4), - GemmKey(32, 64, 64, 4, 1, 4), GemmKey(16, 16, 256, 1, 1, 4), - GemmKey(16, 128, 32, 16, 1, 4), GemmKey(16, 64, 128, 1, 1, 4), - GemmKey(16, 128, 32, 8, 1, 4), GemmKey(16, 16, 512, 1, 1, 4), - GemmKey(32, 16, 512, 1, 1, 4), GemmKey(64, 32, 64, 1, 2, 8)}; + GemmKey(32, 32, 256, 1, 1, 4), GemmKey(64, 32, 32, 16, 1, 4), + GemmKey(32, 64, 64, 4, 1, 4), GemmKey(128, 128, 64, 4, 1, 4), + GemmKey(16, 16, 256, 1, 1, 4), GemmKey(16, 128, 32, 16, 1, 4), + GemmKey(16, 64, 128, 1, 1, 4), GemmKey(16, 128, 32, 8, 1, 4), + GemmKey(16, 16, 512, 1, 1, 4), GemmKey(32, 16, 512, 1, 1, 4), + GemmKey(64, 32, 64, 1, 2, 8)}; if (compute_capability.IsAtLeast(se::CudaComputeCapability::AMPERE)) { absl::c_copy( std::vector{ From ac6e67291a4b0605f90829e69b5eafb5d4a58f4d Mon Sep 17 00:00:00 2001 From: Anlun Xu Date: Tue, 11 Jul 2023 01:58:40 -0700 Subject: [PATCH 114/376] [xla:gpu] Re-enable xla_gpu_cuda_graph_level=1 by default PiperOrigin-RevId: 547119126 --- tensorflow/compiler/xla/debug_options_flags.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/xla/debug_options_flags.cc b/tensorflow/compiler/xla/debug_options_flags.cc index 5847a16cfb209d..299635b30746e0 100644 --- a/tensorflow/compiler/xla/debug_options_flags.cc +++ b/tensorflow/compiler/xla/debug_options_flags.cc @@ -81,10 +81,10 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { // TODO(b/258036887): Enable cuda_graph_level=2. Currently blocked by CUDA 12 // integration. - opts.set_xla_gpu_cuda_graph_level(0); + opts.set_xla_gpu_cuda_graph_level(1); opts.set_xla_gpu_cuda_graph_num_runs_to_instantiate(2); opts.set_xla_gpu_enable_persistent_temp_buffers(false); - opts.set_xla_gpu_cuda_graph_min_graph_size(2); + opts.set_xla_gpu_cuda_graph_min_graph_size(5); opts.set_xla_gpu_cuda_graph_enable_concurrent_region(false); // Despite the name, fast min/max on GPUs does not seem to be any faster, and From b778b1a9513b80075685b0fa61a13a1a88f9cba1 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 11 Jul 2023 02:01:56 -0700 Subject: [PATCH 115/376] Update GraphDef version to 1554. PiperOrigin-RevId: 547119927 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 03427a1910b0ce..28e0085b3d3abf 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1553 // Updated: 2023/7/10 +#define TF_GRAPH_DEF_VERSION 1554 // Updated: 2023/7/11 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From afaf6438b96e910b00d979c2516902ab3e5ac117 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 11 Jul 2023 02:01:57 -0700 Subject: [PATCH 116/376] compat: Update forward compatibility horizon to 2023-07-11 PiperOrigin-RevId: 547119934 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 97bc7f8d44c1c1..35e0da462e4704 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 7, 10) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 7, 11) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From 2ac4fd831d9186f596b77915b1536cb289a768f2 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 11 Jul 2023 05:03:40 -0700 Subject: [PATCH 117/376] Internal change PiperOrigin-RevId: 547155426 --- .../acceleration/mini_benchmark/BUILD | 1 + .../mini_benchmark/build_defs.bzl | 4 +- .../mini_benchmark/model_modifier/BUILD | 6 +-- .../model_modifier/embedder_main.cc | 4 ++ .../mini_benchmark/model_validation_test.cc | 44 ++++++++++++++++++- .../mini_benchmark/special_rules.bzl | 6 +++ .../acceleration/mini_benchmark/validator.cc | 3 ++ tensorflow/lite/tools/benchmark/BUILD | 18 ++++++++ .../tools/benchmark/register_custom_op.cc | 23 ++++++++++ .../lite/tools/benchmark/register_custom_op.h | 23 ++++++++++ 10 files changed, 126 insertions(+), 6 deletions(-) create mode 100644 tensorflow/lite/tools/benchmark/register_custom_op.cc create mode 100644 tensorflow/lite/tools/benchmark/register_custom_op.h diff --git a/tensorflow/lite/experimental/acceleration/mini_benchmark/BUILD b/tensorflow/lite/experimental/acceleration/mini_benchmark/BUILD index 4d6f631e7d5a85..3dcaefcb2f4846 100644 --- a/tensorflow/lite/experimental/acceleration/mini_benchmark/BUILD +++ b/tensorflow/lite/experimental/acceleration/mini_benchmark/BUILD @@ -451,6 +451,7 @@ cc_library( "//tensorflow/lite/core/c:common", "//tensorflow/lite/core/kernels:builtin_ops", "//tensorflow/lite/tools:model_loader", + "//tensorflow/lite/tools/benchmark:register_custom_op", ], ) diff --git a/tensorflow/lite/experimental/acceleration/mini_benchmark/build_defs.bzl b/tensorflow/lite/experimental/acceleration/mini_benchmark/build_defs.bzl index dbab6f2bb7af66..446bbef45e1b2d 100644 --- a/tensorflow/lite/experimental/acceleration/mini_benchmark/build_defs.bzl +++ b/tensorflow/lite/experimental/acceleration/mini_benchmark/build_defs.bzl @@ -21,7 +21,7 @@ load( load("//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "add_suffix") load("//tensorflow/lite/experimental/acceleration/mini_benchmark:special_rules.bzl", "libjpeg_handle_deps") -def embedded_binary(name, binary, array_variable_name, testonly = False): +def embedded_binary(name, binary, array_variable_name, testonly = False, exec_properties = None): """Create a cc_library that embeds a binary as constant data. Args: @@ -55,6 +55,7 @@ def embedded_binary(name, binary, array_variable_name, testonly = False): srcs = [cc_name], hdrs = [h_name], testonly = testonly, + exec_properties = exec_properties, ) def validation_model( @@ -173,6 +174,7 @@ def validation_test(name, validation_model, tags = [], copts = [], deps = []): ], "//conditions:default": [], }) + libjpeg_handle_deps(), + linkstatic = 1, ) def cc_library_with_forced_in_process_benchmark_variant( diff --git a/tensorflow/lite/experimental/acceleration/mini_benchmark/model_modifier/BUILD b/tensorflow/lite/experimental/acceleration/mini_benchmark/model_modifier/BUILD index 08f0e0a52fdf59..ca96f4c8d67a8b 100644 --- a/tensorflow/lite/experimental/acceleration/mini_benchmark/model_modifier/BUILD +++ b/tensorflow/lite/experimental/acceleration/mini_benchmark/model_modifier/BUILD @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -load("//tensorflow/lite/experimental/acceleration/mini_benchmark:special_rules.bzl", "libjpeg_handle_deps") +load("//tensorflow/lite/experimental/acceleration/mini_benchmark:special_rules.bzl", "libjpeg_handle_deps", "register_selected_ops_deps") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -89,7 +89,7 @@ cc_binary( "//tensorflow/lite/tools:command_line_flags", "@com_google_absl//absl/strings", "@flatbuffers", - ] + libjpeg_handle_deps(), + ] + libjpeg_handle_deps() + register_selected_ops_deps(), ) cc_library( @@ -125,5 +125,5 @@ cc_test( "//tensorflow/lite/schema:schema_fbs", "//tensorflow/lite/tools:model_loader", "@com_google_googletest//:gtest_main", - ] + libjpeg_handle_deps(), + ] + libjpeg_handle_deps() + register_selected_ops_deps(), ) diff --git a/tensorflow/lite/experimental/acceleration/mini_benchmark/model_modifier/embedder_main.cc b/tensorflow/lite/experimental/acceleration/mini_benchmark/model_modifier/embedder_main.cc index 0287f3d1dabf7c..0e11653bb8d1dd 100644 --- a/tensorflow/lite/experimental/acceleration/mini_benchmark/model_modifier/embedder_main.cc +++ b/tensorflow/lite/experimental/acceleration/mini_benchmark/model_modifier/embedder_main.cc @@ -36,6 +36,7 @@ limitations under the License. #include "tensorflow/lite/experimental/acceleration/mini_benchmark/decode_jpeg_register.h" #include "tensorflow/lite/experimental/acceleration/mini_benchmark/model_modifier/embedder.h" #include "tensorflow/lite/schema/reflection/schema_generated.h" +#include "tensorflow/lite/tools/benchmark/register_custom_op.h" #include "tensorflow/lite/tools/command_line_flags.h" namespace tflite { @@ -116,6 +117,9 @@ int RunEmbedder(const EmbedderOptions& options) { resolver.AddCustom( "validation/decode_jpeg", ::tflite::acceleration::decode_jpeg_kernel::Register_DECODE_JPEG(), 1); + + RegisterSelectedOps(&resolver); + auto status = embedder.CreateModelWithEmbeddedValidation(&fbb, &resolver); if (!status.ok()) { std::cerr << "Creating model with embedded validation failed: " diff --git a/tensorflow/lite/experimental/acceleration/mini_benchmark/model_validation_test.cc b/tensorflow/lite/experimental/acceleration/mini_benchmark/model_validation_test.cc index 55cb8db57950da..377c281068bed7 100644 --- a/tensorflow/lite/experimental/acceleration/mini_benchmark/model_validation_test.cc +++ b/tensorflow/lite/experimental/acceleration/mini_benchmark/model_validation_test.cc @@ -226,7 +226,7 @@ TEST_F(LocalizerValidationRegressionTest, NnapiSl) { } #endif // ENABLE_NNAPI_SL_TEST -TEST_F(LocalizerValidationRegressionTest, Gpu) { +TEST_F(LocalizerValidationRegressionTest, GpuAny) { AndroidInfo android_info; auto status = RequestAndroidInfo(&android_info); ASSERT_TRUE(status.ok()); @@ -237,7 +237,47 @@ TEST_F(LocalizerValidationRegressionTest, Gpu) { fbb_.Finish(CreateComputeSettings(fbb_, ExecutionPreference_ANY, CreateTFLiteSettings(fbb_, Delegate_GPU))); #ifdef __ANDROID__ - CheckValidation("GPU"); + CheckValidation("GPUANY"); +#endif // __ANDROID__ +} + +TEST_F(LocalizerValidationRegressionTest, GpuOpenGL) { + AndroidInfo android_info; + auto status = RequestAndroidInfo(&android_info); + ASSERT_TRUE(status.ok()); + if (android_info.is_emulator) { + std::cerr << "Skipping GPU on emulator\n"; + return; + } + fbb_.Finish(CreateComputeSettings( + fbb_, ExecutionPreference_ANY, + CreateTFLiteSettings( + fbb_, Delegate_GPU, 0, + CreateGPUSettings(fbb_, /* allow_precision_loss */ false, + /* allow_quantized_inference */ true, + GPUBackend_OPENGL)))); +#ifdef __ANDROID__ + CheckValidation("GPUOPENGL"); +#endif // __ANDROID__ +} + +TEST_F(LocalizerValidationRegressionTest, GpuOpenCL) { + AndroidInfo android_info; + auto status = RequestAndroidInfo(&android_info); + ASSERT_TRUE(status.ok()); + if (android_info.is_emulator) { + std::cerr << "Skipping GPU on emulator\n"; + return; + } + fbb_.Finish(CreateComputeSettings( + fbb_, ExecutionPreference_ANY, + CreateTFLiteSettings( + fbb_, Delegate_GPU, 0, + CreateGPUSettings(fbb_, /* allow_precision_loss */ false, + /* allow_quantized_inference */ true, + GPUBackend_OPENCL)))); +#ifdef __ANDROID__ + CheckValidation("GPUOPENCL"); #endif // __ANDROID__ } diff --git a/tensorflow/lite/experimental/acceleration/mini_benchmark/special_rules.bzl b/tensorflow/lite/experimental/acceleration/mini_benchmark/special_rules.bzl index c2e0c24bdbef75..522f3d95bec451 100644 --- a/tensorflow/lite/experimental/acceleration/mini_benchmark/special_rules.bzl +++ b/tensorflow/lite/experimental/acceleration/mini_benchmark/special_rules.bzl @@ -33,3 +33,9 @@ def minibenchmark_visibility_allowlist(): return [ "//tensorflow/lite/tools/benchmark/experimental/delegate_performance:__subpackages__", ] + +def register_selected_ops_deps(): + """Return a list of dependencies for registering selected ops.""" + return [ + clean_dep("//tensorflow/lite/tools/benchmark:register_custom_op"), + ] diff --git a/tensorflow/lite/experimental/acceleration/mini_benchmark/validator.cc b/tensorflow/lite/experimental/acceleration/mini_benchmark/validator.cc index a9eee994832f49..98851a68bae1f1 100644 --- a/tensorflow/lite/experimental/acceleration/mini_benchmark/validator.cc +++ b/tensorflow/lite/experimental/acceleration/mini_benchmark/validator.cc @@ -45,6 +45,7 @@ limitations under the License. #include "tensorflow/lite/logger.h" #include "tensorflow/lite/minimal_logging.h" #include "tensorflow/lite/mutable_op_resolver.h" +#include "tensorflow/lite/tools/benchmark/register_custom_op.h" #include "tensorflow/lite/tools/model_loader.h" #ifndef TEMP_FAILURE_RETRY @@ -331,6 +332,8 @@ MinibenchmarkStatus Validator::CreateInterpreter(int* delegate_error_out, "validation/decode_jpeg", ::tflite::acceleration::decode_jpeg_kernel::Register_DECODE_JPEG(), 1); + RegisterSelectedOps(resolver_.get()); + tflite::InterpreterBuilder builder(*model_loader_->GetModel(), *resolver_); // Add delegate if not running on CPU. if (delegate_ != nullptr) { diff --git a/tensorflow/lite/tools/benchmark/BUILD b/tensorflow/lite/tools/benchmark/BUILD index 95ca87a9e870dc..3592da77f5bf0c 100644 --- a/tensorflow/lite/tools/benchmark/BUILD +++ b/tensorflow/lite/tools/benchmark/BUILD @@ -256,6 +256,24 @@ cc_library( ], ) +cc_library( + name = "register_custom_op", + srcs = [ + "register_custom_op.cc", + ], + hdrs = [ + "register_custom_op.h", + ], + copts = common_copts, + deps = [ + "//tensorflow/lite:op_resolver", + "@com_google_absl//absl/base:core_headers", + ], + alwayslink = 1, +) + +exports_files(["register_custom_op.h"]) + cc_library( name = "benchmark_utils", srcs = [ diff --git a/tensorflow/lite/tools/benchmark/register_custom_op.cc b/tensorflow/lite/tools/benchmark/register_custom_op.cc new file mode 100644 index 00000000000000..3592663d3f2f9d --- /dev/null +++ b/tensorflow/lite/tools/benchmark/register_custom_op.cc @@ -0,0 +1,23 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "absl/base/attributes.h" +#include "tensorflow/lite/op_resolver.h" + +// Version with Weak linker attribute doing nothing: if someone links this +// library with another definition of this function (presumably to actually +// register custom ops), that version will be used instead. +void ABSL_ATTRIBUTE_WEAK +RegisterSelectedOps(::tflite::MutableOpResolver* resolver) {} diff --git a/tensorflow/lite/tools/benchmark/register_custom_op.h b/tensorflow/lite/tools/benchmark/register_custom_op.h new file mode 100644 index 00000000000000..9278e31a43fbe7 --- /dev/null +++ b/tensorflow/lite/tools/benchmark/register_custom_op.h @@ -0,0 +1,23 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_TOOLS_BENCHMARK_REGISTER_CUSTOM_OP_H_ +#define TENSORFLOW_LITE_TOOLS_BENCHMARK_REGISTER_CUSTOM_OP_H_ + +#include "tensorflow/lite/op_resolver.h" + +void RegisterSelectedOps(::tflite::MutableOpResolver* resolver); + +#endif // TENSORFLOW_LITE_TOOLS_BENCHMARK_REGISTER_CUSTOM_OP_H_ From 77323f77fb3e40e5e025fd0bd469c7f472d47845 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 11 Jul 2023 05:20:38 -0700 Subject: [PATCH 118/376] Integrate LLVM at llvm/llvm-project@cf410b181f8c Updates LLVM usage to match [cf410b181f8c](https://github.com/llvm/llvm-project/commit/cf410b181f8c) PiperOrigin-RevId: 547158175 --- third_party/llvm/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index 7a772fb5657237..d544405fd05954 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "86943d863ef66d68bf79d3e2f0ec2c205814b235" - LLVM_SHA256 = "b37024a8d88985b69b240e4222932379f794906f602464c4c31c516580508a93" + LLVM_COMMIT = "cf410b181f8c546b9ae4cd65a82d08e65bacec82" + LLVM_SHA256 = "b46fea00b4d661444425f4dcd39f5eb12f6a5d8c4964e8e0f3c8e0e601490476" tf_http_archive( name = name, From ff180ecb7b7173fee06b719ebd0a7a7f8a9d1be2 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 11 Jul 2023 05:55:26 -0700 Subject: [PATCH 119/376] Update TFRT dependency to use revision http://github.com/tensorflow/runtime/commit/2a33f6928a9584cc8285b1a9e74e3336c41da8d6. PiperOrigin-RevId: 547163736 --- third_party/tf_runtime/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/tf_runtime/workspace.bzl b/third_party/tf_runtime/workspace.bzl index 8a0084f1896fb4..889fb7909d092a 100644 --- a/third_party/tf_runtime/workspace.bzl +++ b/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "3e5fa08c9a184710601dbf8a1c7b52eaa306124d" - TFRT_SHA256 = "56d5a34fa884ec6eee7a602d90ee8387099c488bf5c3dc21a45ae8e19e2e27ad" + TFRT_COMMIT = "2a33f6928a9584cc8285b1a9e74e3336c41da8d6" + TFRT_SHA256 = "ed37ce13e860d1e3b340cfd0e2e63e0f1b2e3206f0e774e28f29f191c611a199" tf_http_archive( name = "tf_runtime", From a4863eb8b9dc6a35cbe6ec18de79f4aecab0ea79 Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Tue, 11 Jul 2023 06:08:36 -0700 Subject: [PATCH 120/376] PR #4070: Use ptxas version to determine MinThreadsXRowReduction Imported from GitHub PR https://github.com/openxla/xla/pull/4070 Update for https://github.com/openxla/xla/pull/3432 to use 512 threads for ptxas versions <12.2 and 1024 threads otherwise. cc @nouiz @cheshire @cliffwoolley Copybara import of the project: -- 7336238a9bc3aac2255ea8dd0bde11c555c4aacd by Trevor Morris : Use ptxas version to determine MinThreadsXRowReduction -- 37a441b744b83ce5e2bba949e874986244ece165 by Trevor Morris : Update comment Merging this change closes #4070 PiperOrigin-RevId: 547166949 --- tensorflow/compiler/xla/service/gpu/BUILD | 2 ++ .../xla/service/gpu/hlo_fusion_analysis.cc | 5 ++-- .../xla/service/gpu/ir_emission_utils.cc | 28 +++++++++++++++++-- .../xla/service/gpu/ir_emission_utils.h | 5 ++-- .../service/gpu/tree_reduction_rewriter.cc | 2 +- 5 files changed, 35 insertions(+), 7 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 2d7c4d2cb488ff..3b66d433ffaabd 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -989,6 +989,7 @@ cc_library( hdrs = ["ir_emission_utils.h"], compatible_with = get_compatible_with_cloud(), deps = [ + ":gpu_asm_opts_util", ":target_util", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/hlo/ir:hlo", @@ -999,6 +1000,7 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:llvm_type_conversion_util", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/compiler/xla/stream_executor", + "//tensorflow/compiler/xla/stream_executor/gpu:asm_compiler", "//tensorflow/compiler/xla/translate/mhlo_to_hlo:type_to_shape", "@com_google_absl//absl/container:flat_hash_set", "@llvm-project//llvm:Core", diff --git a/tensorflow/compiler/xla/service/gpu/hlo_fusion_analysis.cc b/tensorflow/compiler/xla/service/gpu/hlo_fusion_analysis.cc index 250ff1c54af1cb..5c8aacd93b257b 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_fusion_analysis.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_fusion_analysis.cc @@ -693,7 +693,7 @@ StatusOr HloFusionAnalysis::ComputeReductionCodegenInfo( // For multi-output fusions, reduce the block size further to decrease // register pressure when multiple outputs are computed by each thread. int64_t max_block_size = - std::max(MinThreadsXRowReduction(), + std::max(MinThreadsXRowReduction(first_reduce->GetModule()->config()), static_cast(512LL / NearestPowerOfTwo(fan_out))); return std::min(max_block_size, RoundUpTo(CeilOfRatio(reduction_dimensions.dimensions[2], @@ -710,7 +710,8 @@ StatusOr HloFusionAnalysis::ComputeReductionCodegenInfo( int64_t shmem_usage = ProjectedShmemUsageBytes(reduction_dimensions, instr_index_groups); const int64_t shmem_budget = device_info_->shared_memory_per_block; - bool reduction_is_race_free = ReductionIsRaceFree(reduction_dimensions); + bool reduction_is_race_free = ReductionIsRaceFree( + first_reduce->GetModule()->config(), reduction_dimensions); bool vectorize = // Vectorization might cause us to run out of budget. (shmem_usage * 2 <= shmem_budget) && diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 083b483fcfb06d..b2fb99ac46c0f4 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_opcode.h" #include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_asm_opts_util.h" #include "tensorflow/compiler/xla/service/gpu/target_util.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_type_conversion_util.h" @@ -40,6 +41,10 @@ limitations under the License. #include "tensorflow/compiler/xla/translate/mhlo_to_hlo/type_to_shape.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#ifdef GOOGLE_CUDA +#include "tensorflow/compiler/xla/stream_executor/gpu/asm_compiler.h" +#endif // GOOGLE_CUDA + namespace xla { namespace gpu { @@ -136,6 +141,23 @@ bool IsMatrixMultiplication(const HloInstruction& dot) { return true; } +int64_t MinThreadsXRowReduction(const HloModuleConfig& hlo_module_config) { +#ifdef GOOGLE_CUDA + auto ptxas_config = + PtxOptsFromDebugOptions(hlo_module_config.debug_options()); + auto ptxas_version_tuple = + se::GetAsmCompilerVersion(ptxas_config.preferred_cuda_dir); + // ptxas versions prior to 12.2 have a very rare bug when very high register + // spilling occurs with some order of instructions, so use less threads to + // reduce register pressure. + if (!ptxas_version_tuple.ok() || + ptxas_version_tuple.value() < std::array{12, 2, 0}) { + return 512; + } +#endif // GOOGLE_CUDA + return 1024; +} + Vector3 GetReductionTiling(const ReductionDimensions& reduction_dimensions) { if (reduction_dimensions.is_row_reduction) { int64_t tile_z = std::min(reduction_dimensions.dimensions[0], @@ -777,11 +799,13 @@ Shape GetShape(mlir::Value value) { return {}; } -bool ReductionIsRaceFree(const ReductionDimensions& reduction_dimensions) { +bool ReductionIsRaceFree(const HloModuleConfig& hlo_module_config, + const ReductionDimensions& reduction_dimensions) { Vector3 reduction_tiling = GetReductionTiling(reduction_dimensions); return (reduction_dimensions.is_row_reduction && reduction_dimensions.dimensions[2] <= - MinThreadsXRowReduction() * reduction_tiling[2] && + MinThreadsXRowReduction(hlo_module_config) * + reduction_tiling[2] && reduction_dimensions.dimensions[0] <= BatchedReductionRaceFreeBound()) || (!reduction_dimensions.is_row_reduction && diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h index fb72d566d87a81..ba031cb1504aaf 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h @@ -51,7 +51,7 @@ inline constexpr int64_t WarpSize() { return 32; } // Need at least 1024 threads/block for reasonable tree reduction // performance (assuming all data fits). -inline constexpr int64_t MinThreadsXRowReduction() { return 1024; } +int64_t MinThreadsXRowReduction(const HloModuleConfig& hlo_module_config); // When doing batched row reduction, how big the batch dimension could be. inline constexpr int64_t BatchedReductionRaceFreeBound() { return 8; } @@ -175,7 +175,8 @@ Shape GetShape(mlir::Value value); // Returns whether the given reduction can be safely generated without atomics: // that is, at most one block will write to every output element. -bool ReductionIsRaceFree(const ReductionDimensions& reduction_dimensions); +bool ReductionIsRaceFree(const HloModuleConfig& hlo_module_config, + const ReductionDimensions& reduction_dimensions); // Description of how to emit a given transposition. // diff --git a/tensorflow/compiler/xla/service/gpu/tree_reduction_rewriter.cc b/tensorflow/compiler/xla/service/gpu/tree_reduction_rewriter.cc index ccf274bd2a0f56..b2f5b48fcb17e7 100644 --- a/tensorflow/compiler/xla/service/gpu/tree_reduction_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/tree_reduction_rewriter.cc @@ -99,7 +99,7 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor { bool is_row_reduction = reduction_dimensions.is_row_reduction; // Base case: everything fits. - if (ReductionIsRaceFree(reduction_dimensions)) { + if (ReductionIsRaceFree(hlo->GetModule()->config(), reduction_dimensions)) { VLOG(3) << "Base case: dimensions fit"; return OkStatus(); } From e0b9b6b88f577914b44135ae6244f635adafc0e8 Mon Sep 17 00:00:00 2001 From: Sergey Kozub Date: Tue, 11 Jul 2023 06:15:39 -0700 Subject: [PATCH 121/376] Remove unneeded argument for CalculateLaunchDimensions PiperOrigin-RevId: 547168544 --- tensorflow/compiler/xla/service/gpu/BUILD | 4 ---- .../compiler/xla/service/gpu/launch_dimensions.cc | 13 ++++--------- .../compiler/xla/service/gpu/launch_dimensions.h | 4 +--- 3 files changed, 5 insertions(+), 16 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 3b66d433ffaabd..3f6e361c00c65b 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -120,11 +120,7 @@ cc_library( compatible_with = get_compatible_with_cloud(), deps = [ ":gpu_device_info", - "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla/mlir_hlo:lhlo", - "//tensorflow/tsl/platform:logging", - "@llvm-project//mlir:IR", ], ) diff --git a/tensorflow/compiler/xla/service/gpu/launch_dimensions.cc b/tensorflow/compiler/xla/service/gpu/launch_dimensions.cc index 3c4b24b596d6d3..cc4f7f34e4c722 100644 --- a/tensorflow/compiler/xla/service/gpu/launch_dimensions.cc +++ b/tensorflow/compiler/xla/service/gpu/launch_dimensions.cc @@ -17,12 +17,8 @@ limitations under the License. #include #include -#include -#include "tensorflow/compiler/xla/debug_options_flags.h" -#include "tensorflow/compiler/xla/mlir_hlo/lhlo/IR/lhlo_ops.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/tsl/platform/logging.h" namespace xla { namespace gpu { @@ -80,7 +76,7 @@ int64_t ThreadsPerBlockRowVectorized(const Shape& shape, StatusOr CalculateLaunchDimensionsImplExperimental( const Shape& shape, GpuDeviceInfo gpu_device_info, - LaunchDimensionsConfig dim_config, mlir::Operation* op) { + LaunchDimensionsConfig dim_config) { int64_t num_elements = ShapeUtil::ElementsIn(shape); if (num_elements <= 1) { return LaunchDimensions(); @@ -205,12 +201,11 @@ StatusOr CalculateLaunchDimensionsImpl( StatusOr CalculateLaunchDimensions( const Shape& shape, GpuDeviceInfo gpu_device_info, - bool use_experimental_block_size, LaunchDimensionsConfig dim_config, - mlir::Operation* op) { - if (use_experimental_block_size && op != nullptr) { + bool use_experimental_block_size, LaunchDimensionsConfig dim_config) { + if (use_experimental_block_size) { VLOG(2) << "Experimental block size is enabled"; return CalculateLaunchDimensionsImplExperimental(shape, gpu_device_info, - dim_config, op); + dim_config); } return CalculateLaunchDimensionsImpl(shape, gpu_device_info, dim_config); } diff --git a/tensorflow/compiler/xla/service/gpu/launch_dimensions.h b/tensorflow/compiler/xla/service/gpu/launch_dimensions.h index 2140219fafee6d..05ce2b1be70411 100644 --- a/tensorflow/compiler/xla/service/gpu/launch_dimensions.h +++ b/tensorflow/compiler/xla/service/gpu/launch_dimensions.h @@ -19,7 +19,6 @@ limitations under the License. #include #include -#include "mlir/IR/Operation.h" // from @llvm-project #include "tensorflow/compiler/xla/service/gpu/gpu_device_info.h" #include "tensorflow/compiler/xla/shape.h" @@ -137,8 +136,7 @@ int64_t ThreadsPerBlockRowVectorized(const Shape& shape, // Calculates the launch dimensions used to invoke `hlo`. StatusOr CalculateLaunchDimensions( const Shape& shape, GpuDeviceInfo gpu_device_info, - bool use_experimental_block_size, LaunchDimensionsConfig dim_config = {}, - mlir::Operation* op = nullptr); + bool use_experimental_block_size, LaunchDimensionsConfig dim_config = {}); } // namespace gpu } // namespace xla From ca03e95dafa4ef6bedc51c135ebf72528cd81841 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 11 Jul 2023 06:21:27 -0700 Subject: [PATCH 122/376] Internal change PiperOrigin-RevId: 547169992 --- third_party/llvm/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index d544405fd05954..7a772fb5657237 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "cf410b181f8c546b9ae4cd65a82d08e65bacec82" - LLVM_SHA256 = "b46fea00b4d661444425f4dcd39f5eb12f6a5d8c4964e8e0f3c8e0e601490476" + LLVM_COMMIT = "86943d863ef66d68bf79d3e2f0ec2c205814b235" + LLVM_SHA256 = "b37024a8d88985b69b240e4222932379f794906f602464c4c31c516580508a93" tf_http_archive( name = name, From f983dfbf22284a8c30ae05b3100b3ddf1c5f230b Mon Sep 17 00:00:00 2001 From: Frederik Gossen Date: Tue, 11 Jul 2023 06:45:28 -0700 Subject: [PATCH 123/376] [KernelGen] JIT-compile most the MLIR-generated GPU kernels JIT-compile all MLIR-generated kernels for which the build rules can be reconfigured easily. For now, this excludes i64-indexed kernels and kernels with different input and output types. PiperOrigin-RevId: 547175063 --- tensorflow/core/kernels/mlir_generated/BUILD | 329 ++++++++++--------- tensorflow/python/kernel_tests/linalg/BUILD | 2 +- tensorflow/python/ops/BUILD | 4 +- 3 files changed, 182 insertions(+), 153 deletions(-) diff --git a/tensorflow/core/kernels/mlir_generated/BUILD b/tensorflow/core/kernels/mlir_generated/BUILD index c4c5b6f5d99435..b94f1ffd77aa32 100644 --- a/tensorflow/core/kernels/mlir_generated/BUILD +++ b/tensorflow/core/kernels/mlir_generated/BUILD @@ -639,13 +639,14 @@ gpu_kernel_library( gpu_kernel_library( name = "gpu_atan2_kernels", - op = "atan2", - tile_size = "256", - types = [ + jit_types = [ "f16", "f32", "f64", ], + op = "atan2", + tile_size = "256", + types = [], unroll_factors = "4", ) @@ -748,25 +749,27 @@ gpu_kernel_library( gpu_kernel_library( name = "gpu_ceil_kernels", - op = "ceil", - tile_size = "256", - types = [ + jit_types = [ "f16", "f32", "f64", ], + op = "ceil", + tile_size = "256", + types = [], unroll_factors = "4", ) gpu_kernel_library( name = "gpu_floor_kernels", - op = "floor", - tile_size = "256", - types = [ + jit_types = [ "f16", "f32", "f64", ], + op = "floor", + tile_size = "256", + types = [], unroll_factors = "4", ) @@ -792,26 +795,28 @@ gpu_kernel_library( gpu_kernel_library( name = "gpu_rint_kernels", - jit_types = ["f16"], - op = "rint", - tile_size = "1024", - types = [ + jit_types = [ + "f16", "f32", "f64", ], + op = "rint", + tile_size = "1024", + types = [], ) gpu_kernel_library( name = "gpu_round_kernels", - op = "round", - tile_size = "1024", - types = [ + jit_types = [ "f16", "f32", "f64", "i32", "i64", ], + op = "round", + tile_size = "1024", + types = [], ) # Predicate kernels @@ -1029,12 +1034,13 @@ gpu_kernel_library( gpu_kernel_library( name = "gpu_conj_kernels", - op = "conj", - tile_size = "256", - types = [ + jit_types = [ "c64", "c128", ], + op = "conj", + tile_size = "256", + types = [], unroll_factors = "2", ) @@ -1171,10 +1177,6 @@ gpu_kernel_library( "ui16", "ui32", "ui64", - ], - op = "maximum", - tile_size = "1024", - types = [ "f16", "f32", "f64", @@ -1182,6 +1184,9 @@ gpu_kernel_library( "i64", "ui8", ], + op = "maximum", + tile_size = "1024", + types = [], unroll_factors = "4", ) @@ -1192,10 +1197,6 @@ gpu_kernel_library( "ui16", "ui32", "ui64", - ], - op = "minimum", - tile_size = "1024", - types = [ "f16", "f32", "f64", @@ -1203,6 +1204,9 @@ gpu_kernel_library( "i64", "ui8", ], + op = "minimum", + tile_size = "1024", + types = [], unroll_factors = "4", ) @@ -1254,9 +1258,7 @@ gpu_kernel_library( gpu_kernel_library( name = "gpu_neg_kernels", - op = "neg", - tile_size = "256", - types = [ + jit_types = [ "f16", "f32", "f64", @@ -1267,6 +1269,9 @@ gpu_kernel_library( "c64", "c128", ], + op = "neg", + tile_size = "256", + types = [], unroll_factors = "4", ) @@ -1275,22 +1280,19 @@ gpu_kernel_library( jit_types = [ "i8", "i16", - ], - op = "pow", - tile_size = "1024", - types = [ "f16", "f32", "f64", "i64", ], + op = "pow", + tile_size = "1024", + types = [], ) gpu_kernel_library( name = "gpu_reciprocal_kernels", - op = "reciprocal", - tile_size = "256", - types = [ + jit_types = [ "c64", "c128", "f16", @@ -1298,6 +1300,9 @@ gpu_kernel_library( "f64", "i64", ], + op = "reciprocal", + tile_size = "256", + types = [], unroll_factors = "4", ) @@ -1306,10 +1311,6 @@ gpu_kernel_library( jit_types = [ "i8", "i16", - ], - op = "sign", - tile_size = "256", - types = [ "f16", "f32", "f64", @@ -1318,6 +1319,9 @@ gpu_kernel_library( "c64", "c128", ], + op = "sign", + tile_size = "256", + types = [], unroll_factors = "4", ) @@ -1361,80 +1365,86 @@ gpu_kernel_library( gpu_kernel_library( name = "gpu_xdivy_kernels", - op = "xdivy", - tile_size = "1024", - types = [ + jit_types = [ "f16", "f32", "f64", "c64", "c128", ], + op = "xdivy", + tile_size = "1024", + types = [], unroll_factors = "4", ) # Logarithmic and exponential kernels gpu_kernel_library( name = "gpu_exp_kernels", - op = "exp", - tile_size = "256", - types = [ + jit_types = [ "f16", "f32", "f64", "c64", "c128", ], + op = "exp", + tile_size = "256", + types = [], unroll_factors = "4", ) gpu_kernel_library( name = "gpu_expm1_kernels", - op = "expm1", - tile_size = "256", - types = [ + jit_types = [ "f16", "f32", "f64", ], + op = "expm1", + tile_size = "256", + types = [], unroll_factors = "4", ) gpu_kernel_library( name = "gpu_log_kernels", - op = "log", - tile_size = "256", - types = [ + jit_types = [ "f16", "f32", "f64", ], + op = "log", + tile_size = "256", + types = [], unroll_factors = "4", ) gpu_kernel_library( name = "gpu_log1p_kernels", - op = "log1p", - tile_size = "256", - types = [ + jit_types = [ "f16", "f32", "f64", ], + op = "log1p", + tile_size = "256", + types = [], unroll_factors = "4", ) gpu_kernel_library( name = "gpu_xlogy_kernels", - op = "xlogy", - tile_size = "1024", - types = [ + jit_types = [ "f16", "f32", "f64", "c64", "c128", ], + op = "xlogy", + tile_size = "1024", + types = [], unroll_factors = "4", # For complex XlogyOp kernels, we don't use unrolling, it would only cause # slowdowns. @@ -1446,15 +1456,16 @@ gpu_kernel_library( gpu_kernel_library( name = "gpu_xlog1py_kernels", - op = "xlog1py", - tile_size = "1024", - types = [ + jit_types = [ "f16", "f32", "f64", "c64", "c128", ], + op = "xlog1py", + tile_size = "1024", + types = [], unroll_factors = "4", # For complex Xlog1pyOp kernels, we don't use unrolling, it would only cause # slowdowns. @@ -1468,25 +1479,27 @@ gpu_kernel_library( gpu_kernel_library( name = "gpu_sqrt_kernels", - op = "sqrt", - tile_size = "256", - types = [ + jit_types = [ "f16", "f32", "f64", ], + op = "sqrt", + tile_size = "256", + types = [], unroll_factors = "4", ) gpu_kernel_library( name = "gpu_rsqrt_kernels", - op = "rsqrt", - tile_size = "256", - types = [ + jit_types = [ "f16", "f32", "f64", ], + op = "rsqrt", + tile_size = "256", + types = [], unroll_factors = "4", ) @@ -1499,28 +1512,28 @@ gpu_kernel_library( "ui16", "ui32", "ui64", - ], - op = "square", - tile_size = "1024", - types = [ "f16", "f32", "f64", "i64", ], + op = "square", + tile_size = "1024", + types = [], unroll_factors = "4", ) gpu_kernel_library( name = "gpu_squared_difference_kernels", - op = "squared_difference", - tile_size = "1024", - types = [ + jit_types = [ "f16", "f32", "f64", "i64", ], + op = "squared_difference", + tile_size = "1024", + types = [], unroll_factors = "4", ) @@ -1528,74 +1541,77 @@ gpu_kernel_library( gpu_kernel_library( name = "gpu_bitwise_and_kernels", - op = "bitwise_and", - tile_size = "1024", - types = [ + jit_types = [ "i8", "i16", "i32", "i64", ], + op = "bitwise_and", + tile_size = "1024", + types = [], unroll_factors = "4", ) gpu_kernel_library( name = "gpu_bitwise_or_kernels", - op = "bitwise_or", - tile_size = "1024", - types = [ + jit_types = [ "i8", "i16", "i32", "i64", ], + op = "bitwise_or", + tile_size = "1024", + types = [], unroll_factors = "4", ) gpu_kernel_library( name = "gpu_bitwise_xor_kernels", - op = "bitwise_xor", - tile_size = "1024", - types = [ + jit_types = [ "i8", "i16", "i32", "i64", ], + op = "bitwise_xor", + tile_size = "1024", + types = [], unroll_factors = "4", ) gpu_kernel_library( name = "gpu_invert_kernels", - op = "invert", - tile_size = "1024", - types = [ + jit_types = [ "i8", "i16", "i32", "i64", ], + op = "invert", + tile_size = "1024", + types = [], unroll_factors = "4", ) gpu_kernel_library( name = "gpu_left_shift_kernels", - op = "left_shift", - tile_size = "1024", - types = [ + jit_types = [ "i8", "i16", "i32", "i64", ], + op = "left_shift", + tile_size = "1024", + types = [], unroll_factors = "4", ) gpu_kernel_library( name = "gpu_right_shift_kernels", - op = "right_shift", - tile_size = "1024", - types = [ + jit_types = [ "i8", "i16", "i32", @@ -1605,6 +1621,9 @@ gpu_kernel_library( "ui32", "ui64", ], + op = "right_shift", + tile_size = "1024", + types = [], unroll_factors = "4", ) @@ -1612,52 +1631,57 @@ gpu_kernel_library( gpu_kernel_library( name = "gpu_logical_not_kernels", + jit_types = ["i1"], op = "logical_not", tile_size = "256", - types = ["i1"], + types = [], ) gpu_kernel_library( name = "gpu_logical_and_kernels", - op = "logical_and", - tile_size = "1024", - types = [ + jit_types = [ "i1", ], + op = "logical_and", + tile_size = "1024", + types = [], ) gpu_kernel_library( name = "gpu_logical_or_kernels", - op = "logical_or", - tile_size = "1024", - types = [ + jit_types = [ "i1", ], + op = "logical_or", + tile_size = "1024", + types = [], ) # Erf kernels gpu_kernel_library( name = "gpu_erf_kernels", - op = "erf", - tile_size = "256", - types = [ + jit_types = [ "f16", "f32", "f64", ], + op = "erf", + tile_size = "256", + types = [], unroll_factors = "4", ) gpu_kernel_library( name = "gpu_erfc_kernels", - op = "erfc", - tile_size = "256", - types = [ + jit_types = [ "f16", "f32", "f64", ], + op = "erfc", + tile_size = "256", + types = [], unroll_factors = "4", ) @@ -1665,45 +1689,49 @@ gpu_kernel_library( gpu_kernel_library( name = "gpu_polygamma_kernels", - op = "polygamma", - tile_size = "256", - types = [ + jit_types = [ "f32", "f64", ], + op = "polygamma", + tile_size = "256", + types = [], ) gpu_kernel_library( name = "gpu_digamma_kernels", - op = "digamma", - tile_size = "256", - types = [ + jit_types = [ "f16", "f32", "f64", ], + op = "digamma", + tile_size = "256", + types = [], ) gpu_kernel_library( name = "gpu_lgamma_kernels", - op = "lgamma", - tile_size = "256", - types = [ + jit_types = [ "f16", "f32", "f64", ], + op = "lgamma", + tile_size = "256", + types = [], ) gpu_kernel_library( # The zeta kernels needs many registers so tile at 256. name = "gpu_zeta_kernels", - op = "zeta", - tile_size = "256", - types = [ + jit_types = [ "f32", "f64", ], + op = "zeta", + tile_size = "256", + types = [], # TODO(b/178388085): Enable unrolling after vectorization is fixed. # unroll_factors = "4", ) @@ -1730,61 +1758,64 @@ gpu_kernel_library( "ui16", "ui32", "ui64", - ], - op = "relu", - tile_size = "256", - types = [ "f16", "f32", "f64", ], + op = "relu", + tile_size = "256", + types = [], unroll_factors = "16B", ) gpu_kernel_library( name = "gpu_elu_kernels", - op = "elu", - tile_size = "256", - types = [ + jit_types = [ "f16", "f32", "f64", ], + op = "elu", + tile_size = "256", + types = [], ) gpu_kernel_library( name = "gpu_selu_kernels", - op = "selu", - tile_size = "256", - types = [ + jit_types = [ "f16", "f32", "f64", ], + op = "selu", + tile_size = "256", + types = [], ) gpu_kernel_library( name = "gpu_sigmoid_kernels", - op = "sigmoid", - tile_size = "256", - types = [ + jit_types = [ "f16", "f32", "f64", ], + op = "sigmoid", + tile_size = "256", + types = [], ) # Kernels that support all floating-point types. [ gpu_kernel_library( name = "gpu_" + op + "_kernels", - op = op, - tile_size = "256", - types = [ + jit_types = [ "f16", "f32", "f64", ], + op = op, + tile_size = "256", + types = [], unroll_factors = "4", ) for op in [ @@ -1836,11 +1867,6 @@ gpu_kernel_library( "ui16", "ui32", "ui64", - ], - max_supported_rank = 8, - op = "select_v2", - tile_size = "256", - types = [ "i1", "i32", "i64", @@ -1850,6 +1876,10 @@ gpu_kernel_library( "c64", "c128", ], + max_supported_rank = 8, + op = "select_v2", + tile_size = "256", + types = [], ) gpu_kernel_library( @@ -1861,10 +1891,6 @@ gpu_kernel_library( "ui16", "ui32", "ui64", - ], - op = "zeros_like", - tile_size = "1024", - types = [ "i1", "i64", "f16", @@ -1873,6 +1899,9 @@ gpu_kernel_library( "c64", "c128", ], + op = "zeros_like", + tile_size = "1024", + types = [], ) gpu_kernel_library( @@ -1884,10 +1913,6 @@ gpu_kernel_library( "ui16", "ui32", "ui64", - ], - op = "ones_like", - tile_size = "1024", - types = [ "i1", "i64", "f16", @@ -1896,14 +1921,18 @@ gpu_kernel_library( "c64", "c128", ], + op = "ones_like", + tile_size = "1024", + types = [], ) gpu_kernel_library( name = "gpu_next_after_kernels", - op = "next_after", - tile_size = "1024", - types = [ + jit_types = [ "f32", "f64", ], + op = "next_after", + tile_size = "1024", + types = [], ) diff --git a/tensorflow/python/kernel_tests/linalg/BUILD b/tensorflow/python/kernel_tests/linalg/BUILD index ada2c64403c01b..42fd131137743f 100644 --- a/tensorflow/python/kernel_tests/linalg/BUILD +++ b/tensorflow/python/kernel_tests/linalg/BUILD @@ -271,7 +271,7 @@ cuda_py_strict_test( name = "linear_operator_circulant_test", size = "medium", srcs = ["linear_operator_circulant_test.py"], - shard_count = 15, + shard_count = 32, tags = [ "no_cuda11", # TODO(b/197522782): reenable test after fixing. "optonly", # times out, b/79171797 diff --git a/tensorflow/python/ops/BUILD b/tensorflow/python/ops/BUILD index a409f96095a3af..1da5270622505d 100644 --- a/tensorflow/python/ops/BUILD +++ b/tensorflow/python/ops/BUILD @@ -3127,7 +3127,7 @@ py_strict_library( cuda_py_strict_test( name = "bitwise_ops_test", - size = "small", + size = "medium", srcs = ["bitwise_ops_test.py"], main = "bitwise_ops_test.py", python_version = "PY3", @@ -3472,7 +3472,7 @@ cuda_py_strict_test( cuda_py_strict_test( name = "math_grad_test", - size = "small", + size = "medium", srcs = ["math_grad_test.py"], main = "math_grad_test.py", python_version = "PY3", From 4a09465a95ba827a0b997d11e02f7f6b697e91ad Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 11 Jul 2023 06:55:27 -0700 Subject: [PATCH 124/376] Update TFRT dependency to use revision http://github.com/tensorflow/runtime/commit/4ed29c1d398da53ee2cfa5191868e367a0e4052f. PiperOrigin-RevId: 547176965 --- third_party/tf_runtime/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/tf_runtime/workspace.bzl b/third_party/tf_runtime/workspace.bzl index 889fb7909d092a..1f02109fb9dcf5 100644 --- a/third_party/tf_runtime/workspace.bzl +++ b/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "2a33f6928a9584cc8285b1a9e74e3336c41da8d6" - TFRT_SHA256 = "ed37ce13e860d1e3b340cfd0e2e63e0f1b2e3206f0e774e28f29f191c611a199" + TFRT_COMMIT = "4ed29c1d398da53ee2cfa5191868e367a0e4052f" + TFRT_SHA256 = "7d0c7a60785da8161d63574ff395229a61a0fd204134d1f670c9ea027df6f73d" tf_http_archive( name = "tf_runtime", From d072cfedb1fd3525ee6a38f07e6a08c5143757f3 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Tue, 11 Jul 2023 07:13:25 -0700 Subject: [PATCH 125/376] [XLA:GPU][NFC] Clean up unused function and dependencies in softmax_rewriter_triton and its corresponding test file. PiperOrigin-RevId: 547180731 --- tensorflow/compiler/xla/service/gpu/BUILD | 1 + .../xla/service/gpu/softmax_rewriter_triton.cc | 13 ------------- .../xla/service/gpu/softmax_rewriter_triton_test.cc | 1 + 3 files changed, 2 insertions(+), 13 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 3f6e361c00c65b..13721c24a8285e 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -1863,6 +1863,7 @@ xla_cc_test( srcs = ["softmax_rewriter_triton_test.cc"], deps = [ ":softmax_rewriter_triton", + "//tensorflow/compiler/xla/service:pattern_matcher", "//tensorflow/compiler/xla/service:pattern_matcher_gmock", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # build_cleaner: keep diff --git a/tensorflow/compiler/xla/service/gpu/softmax_rewriter_triton.cc b/tensorflow/compiler/xla/service/gpu/softmax_rewriter_triton.cc index ee203f7a1bb98b..a5e59e1be3d16b 100644 --- a/tensorflow/compiler/xla/service/gpu/softmax_rewriter_triton.cc +++ b/tensorflow/compiler/xla/service/gpu/softmax_rewriter_triton.cc @@ -42,19 +42,6 @@ bool HasDefaultLayout(const Shape& shape) { LayoutUtil::IsMonotonicWithDim0Major(shape.layout()); } -bool IsSupportedReductionComputation(HloComputation* computation) { - static const absl::flat_hash_set* const kSupportedOpcodes = - new absl::flat_hash_set{HloOpcode::kAdd, HloOpcode::kMaximum}; - - HloInstruction* root = computation->root_instruction(); - if (root->operand_count() != 2 || - root->operand(0)->opcode() != HloOpcode::kParameter || - root->operand(1)->opcode() != HloOpcode::kParameter) { - return false; - } - return kSupportedOpcodes->contains(root->opcode()); -} - bool IsTritonSupportedInstruction(const HloInstruction* instr) { // TODO(bchetioui): expand with non-trivial instructions. if (instr->IsElementwise()) { diff --git a/tensorflow/compiler/xla/service/gpu/softmax_rewriter_triton_test.cc b/tensorflow/compiler/xla/service/gpu/softmax_rewriter_triton_test.cc index 77f49d8a423b6c..00d77002725f0e 100644 --- a/tensorflow/compiler/xla/service/gpu/softmax_rewriter_triton_test.cc +++ b/tensorflow/compiler/xla/service/gpu/softmax_rewriter_triton_test.cc @@ -14,6 +14,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/service/pattern_matcher.h" #include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" From 72095f7edeef942512adb200fde15ad59534f3d5 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 11 Jul 2023 08:18:50 -0700 Subject: [PATCH 126/376] Integrate LLVM at llvm/llvm-project@cf410b181f8c Updates LLVM usage to match [cf410b181f8c](https://github.com/llvm/llvm-project/commit/cf410b181f8c) PiperOrigin-RevId: 547195319 --- third_party/llvm/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index 7a772fb5657237..d544405fd05954 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "86943d863ef66d68bf79d3e2f0ec2c205814b235" - LLVM_SHA256 = "b37024a8d88985b69b240e4222932379f794906f602464c4c31c516580508a93" + LLVM_COMMIT = "cf410b181f8c546b9ae4cd65a82d08e65bacec82" + LLVM_SHA256 = "b46fea00b4d661444425f4dcd39f5eb12f6a5d8c4964e8e0f3c8e0e601490476" tf_http_archive( name = name, From 7a8f0039c81859ab72a73283a5ca361856897eed Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Tue, 11 Jul 2023 08:24:28 -0700 Subject: [PATCH 127/376] [XLA:GPU] Fuse more inputs into Triton GEMMs. - Let the GEMM rewriter do more complex traversals of inputs and fuse elementwise operations and broadcasts of scalar constants. - Limit the number of parameters per fusion. - Reorganize GPU compiler pipeline: bf16 float normalization is now required both before and after Triton GEMM fusion. - Remove an autotuner config that for unknown reasons fails on Volta with new fusions. PiperOrigin-RevId: 547196631 --- .../compiler/xla/debug_options_flags.cc | 6 + tensorflow/compiler/xla/service/gpu/BUILD | 7 + .../xla/service/gpu/gemm_rewriter_triton.cc | 444 ++++++++++++------ .../xla/service/gpu/gemm_rewriter_triton.h | 49 +- .../service/gpu/gemm_rewriter_triton_test.cc | 156 +++++- .../compiler/xla/service/gpu/gpu_compiler.cc | 37 +- .../xla/service/gpu/ir_emitter_triton.cc | 2 +- .../xla/service/gpu/ir_emitter_triton_test.cc | 161 +++++++ .../xla/service/gpu/triton_autotuner.cc | 11 +- tensorflow/compiler/xla/xla.proto | 4 +- 10 files changed, 705 insertions(+), 172 deletions(-) diff --git a/tensorflow/compiler/xla/debug_options_flags.cc b/tensorflow/compiler/xla/debug_options_flags.cc index 299635b30746e0..b3d288bde79532 100644 --- a/tensorflow/compiler/xla/debug_options_flags.cc +++ b/tensorflow/compiler/xla/debug_options_flags.cc @@ -138,6 +138,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_enable_cudnn_int8x32_convolution_reordering(true); opts.set_xla_gpu_triton_gemm_any(false); opts.set_xla_gpu_enable_triton_softmax_fusion(false); + opts.set_xla_gpu_triton_fusion_level(1); // Moving reduce-scatter out of while loops can increase memory footprint, so // turning it off by default. @@ -1130,6 +1131,11 @@ void MakeDebugOptionsFlags(std::vector* flag_list, "Forces any reductions during matrix multiplications to use the " "accumulator type and not the output type. The precision of the dot " "operation may not increase that much if there is output fusion.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_triton_fusion_level", + int32_setter_for(&DebugOptions::set_xla_gpu_triton_fusion_level), + debug_options->xla_gpu_triton_fusion_level(), + "Triton fusion level, higher levels mean more fused operations.")); } // NOLINT(readability/fn_size) // Allocates flag_values and flag_objects; this function must not be called more diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 13721c24a8285e..741baed5ca1e85 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -434,6 +434,7 @@ cc_library( "//tensorflow/tsl/platform:errors", "//tensorflow/tsl/platform:logging", "//tensorflow/tsl/platform:path", + "//tensorflow/tsl/platform:statusor", "//tensorflow/tsl/platform:tensor_float_32_hdr_lib", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", @@ -489,6 +490,8 @@ xla_test( "//tensorflow/compiler/xla:autotuning_proto_cc", "//tensorflow/compiler/xla:error_spec", "//tensorflow/compiler/xla/hlo/ir:hlo", + "//tensorflow/compiler/xla/service:pattern_matcher", + "//tensorflow/compiler/xla/service:pattern_matcher_gmock", "//tensorflow/compiler/xla/service/gpu/tests:gpu_codegen_test", "//tensorflow/compiler/xla/stream_executor:device_description", "//tensorflow/compiler/xla/stream_executor/cuda:cublas_plugin", @@ -1152,18 +1155,22 @@ cc_library( "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/hlo/ir:hlo", "//tensorflow/compiler/xla/hlo/utils:hlo_query", "//tensorflow/compiler/xla/service:hlo_creation_utils", "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/compiler/xla/service:instruction_fusion", + "//tensorflow/compiler/xla/stream_executor:device_description", "//tensorflow/tsl/platform:errors", "//tensorflow/tsl/platform:status", "//tensorflow/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], diff --git a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.cc b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.cc index 6b28352ccd61ab..20971738a95289 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.cc @@ -22,12 +22,15 @@ limitations under the License. #include #include #include +#include #include #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/autotuning.pb.h" @@ -37,6 +40,7 @@ limitations under the License. #include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_instructions.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_opcode.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_schedule.h" #include "tensorflow/compiler/xla/hlo/utils/hlo_query.h" #include "tensorflow/compiler/xla/layout.h" #include "tensorflow/compiler/xla/literal_util.h" @@ -46,9 +50,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/gpu/matmul_utils.h" #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" +#include "tensorflow/compiler/xla/service/instruction_fusion.h" #include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/stream_executor/device_description.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/tsl/platform/errors.h" @@ -57,6 +64,25 @@ limitations under the License. namespace xla { namespace gpu { + +bool TensorIterationSpec::operator==(const TensorIterationSpec& other) const { + for (int dim = 0; dim < TensorIterationSpec::kMaxDimsPerTensor; ++dim) { + if (dim_iteration_specs_[dim].size() != other[dim].size()) { + return false; + } + for (int fragment = 0; fragment < dim_iteration_specs_[dim].size(); + ++fragment) { + if (dim_iteration_specs_[dim][fragment].stride != + other[dim][fragment].stride || + dim_iteration_specs_[dim][fragment].count != + other[dim][fragment].count) { + return false; + } + } + } + return true; +} + namespace { // Batch dimensions of an operand of a dot instruction. @@ -95,10 +121,10 @@ int64_t NonContractingDimensionIndex(const HloInstruction& dot, } // Data types that are tested to work in the triton GEMM emitter. -bool IsSupportedDataType(PrimitiveType t, GpuVersion gpu_version) { +bool IsSupportedDataType(PrimitiveType type, GpuVersion gpu_version) { auto cuda_compute_capability = std::get(gpu_version); - switch (t) { + switch (type) { case PRED: case S8: case S16: @@ -114,21 +140,19 @@ bool IsSupportedDataType(PrimitiveType t, GpuVersion gpu_version) { } } -Status RequireTritonFusibleConvert(const HloInstruction* input, - GpuVersion gpu_version) { - if (!IsSupportedDataType(input->operand(0)->shape().element_type(), - gpu_version)) { - return Unimplemented("unsupported data type"); +// Let input and output data volumes of a fusion grow by small amounts. +constexpr int64_t kIoToleranceBytes = 1024; + +// Difference of input and output data volumes of an instruction. +int64_t InputMinusOutputBytes(const HloInstruction& hlo) { + CHECK(!hlo.shape().IsTuple()); + int64_t output_size = ShapeUtil::ByteSizeOf(hlo.shape()); + int64_t input_size = 0; + for (const HloInstruction* operand : hlo.operands()) { + CHECK(!operand->shape().IsTuple()); + input_size += ShapeUtil::ByteSizeOf(operand->shape()); } - // TODO(b/266862494): Can pick up almost any - // convert, but if it's reducing the data volume it should rather be fused - // to the output of the producer kernel. However not all operations support - // output fusion - then it should be fused here anyway! - if (ShapeUtil::ByteSizeOf(input->operand(0)->shape()) > - ShapeUtil::ByteSizeOf(input->shape())) { - return FailedPrecondition("narrowing conversion"); - } - return OkStatus(); + return input_size - output_size; } // Handles numbers of dimensions of a target HLO instruction @@ -142,6 +166,13 @@ class DimensionOrder { int64_t target_dim_number; int subdim_number; int64_t size; + bool operator==(const DimDescription& other) const { + return target_dim_number == other.target_dim_number && + subdim_number == other.subdim_number && size == other.size; + } + std::string ToString() const { + return absl::StrCat(target_dim_number, ":", subdim_number, ":", size); + } }; // Sequence describing all dimensions of HLO's output shape // in layout minor-to-major (physical) order. @@ -171,34 +202,35 @@ class DimensionOrder { // Transforms the DimensionOrder so that from a description of the output // of `hlo` it becomes a description of the input of `hlo`. - Status HandleInstruction(const HloInstruction* hlo) { + FusionDecision HandleInstruction(const HloInstruction* hlo) { VLOG(7) << hlo->ToString(); - if (hlo->opcode() == HloOpcode::kParameter) { - return OkStatus(); + if (hlo->opcode() == HloOpcode::kParameter || + hlo->opcode() == HloOpcode::kConstant) { + return FusionDecision{}; } else if (hlo->opcode() == HloOpcode::kTranspose || hlo->opcode() == HloOpcode::kCopy) { return HandleCopyOrTranspose(hlo); } else if (hlo->operand_count() > 0 && IsTritonSupportedElementwise( hlo->opcode(), hlo->operand(0)->shape().element_type())) { - return OkStatus(); + return FusionDecision{}; } else if (hlo->opcode() == HloOpcode::kBitcast) { return HandleBitcast(hlo); } else if (hlo->opcode() == HloOpcode::kReshape) { if (!ShapeUtil::ReshapeIsBitcast(hlo->operand(0)->shape(), hlo->shape())) { - return Unimplemented("Non-bitcast reshape."); + return "Non-bitcast reshape."; } return HandleBitcast(hlo); } else if (hlo_query::IsScalarConstant(hlo) || hlo_query::IsBroadcastOfScalarConstant(*hlo)) { // Dimension order collapses on a scalar, for simplicity leave it equal // to the output one for now. - return OkStatus(); + return FusionDecision{}; } else { - return Unimplemented("Instruction: %s", hlo->ToString()); + return "Unimplemented instruction."; } - return OkStatus(); + return FusionDecision{}; } // Get the raw data of the dimension order. @@ -210,20 +242,32 @@ class DimensionOrder { return splittable_dimension_index_; } + // Tells that two dimension orders describe the same tensor physical layout. + bool IsPhysicallyEquivalent(const DimensionOrder& other) const; + + std::string ToString() const { + return absl::StrJoin(dim_order_, "-", + [](std::string* out, const DimDescription& d) { + absl::StrAppend(out, d.ToString()); + }); + } + private: // See HandleInstruction() for the general description of Handle*(). - Status HandleBitcast(const HloInstruction* hlo); - Status HandleCopyOrTranspose(const HloInstruction* hlo); + FusionDecision HandleBitcast(const HloInstruction* hlo); + FusionDecision HandleCopyOrTranspose(const HloInstruction* hlo); DimOrderVector dim_order_; - int64_t splittable_dimension_index_; + const int64_t splittable_dimension_index_; }; -DotFusionAnalysis::TensorIterationSpec DimensionOrderToTensorIterationSpec( +using DimIterationSpec = TensorIterationSpec::DimIterationSpec; + +TensorIterationSpec DimensionOrderToTensorIterationSpec( const DimensionOrder& order) { const DimensionOrder::DimOrderVector& dim_order_vector = order.GetDimOrderVector(); - DotFusionAnalysis::TensorIterationSpec tensor_spec; + TensorIterationSpec tensor_spec; int64_t accumulated_stride = 1; for (int dim_order_index = 0; dim_order_index < dim_order_vector.size(); ++dim_order_index) { @@ -236,8 +280,7 @@ DotFusionAnalysis::TensorIterationSpec DimensionOrderToTensorIterationSpec( continue; } - DotFusionAnalysis::DimIterationSpec& dim_spec = - tensor_spec[dim.target_dim_number]; + DimIterationSpec& dim_spec = tensor_spec[dim.target_dim_number]; if (dim_order_index > 0 && dim_order_vector[dim_order_index - 1].target_dim_number == dim.target_dim_number) { @@ -257,7 +300,7 @@ DotFusionAnalysis::TensorIterationSpec DimensionOrderToTensorIterationSpec( accumulated_stride *= dim.size; } // Create all absent dimensions as degenerate ones to simplify later queries. - for (DotFusionAnalysis::DimIterationSpec& dim_spec : tensor_spec) { + for (DimIterationSpec& dim_spec : tensor_spec) { if (dim_spec.empty()) { dim_spec.push_back({/*stride=*/0, /*count=*/1, /*subfragments=*/{1}}); } @@ -265,6 +308,11 @@ DotFusionAnalysis::TensorIterationSpec DimensionOrderToTensorIterationSpec( return tensor_spec; } +bool DimensionOrder::IsPhysicallyEquivalent(const DimensionOrder& other) const { + return DimensionOrderToTensorIterationSpec(*this) == + DimensionOrderToTensorIterationSpec(other); +} + DimensionOrder DimensionOrder::FromDotOperand(const HloInstruction& dot, const int operand_number, const int64_t split_k) { @@ -287,7 +335,7 @@ DimensionOrder DimensionOrder::FromDotOutput(const HloInstruction& dot) { return DimensionOrder(&dot); } -Status DimensionOrder::HandleBitcast(const HloInstruction* hlo) { +FusionDecision DimensionOrder::HandleBitcast(const HloInstruction* hlo) { const Shape& operand_shape = hlo->operand(0)->shape(); DimOrderVector operand_dim_order; operand_dim_order.reserve(dim_order_.size()); @@ -301,7 +349,7 @@ Status DimensionOrder::HandleBitcast(const HloInstruction* hlo) { ++out_dim) { if (operand_remaining_size >= out_dim->size) { if (operand_remaining_size % out_dim->size) { - return Unimplemented("Unsupported bitcast: %s", hlo->ToString()); + return "Unsupported bitcast"; } // Output dimension fragment completely fits into the operand one: // just copy it as is. @@ -319,7 +367,7 @@ Status DimensionOrder::HandleBitcast(const HloInstruction* hlo) { // If there is a remaining fragment of a previous operand dimension // assign it first. if (out_remaining_size % operand_remaining_size) { - return Unimplemented("Unsupported bitcast: %s", hlo->ToString()); + return "Unsupported bitcast"; } operand_dim_order.push_back( {out_dim->target_dim_number, subdim_index, operand_remaining_size}); @@ -337,7 +385,7 @@ Status DimensionOrder::HandleBitcast(const HloInstruction* hlo) { // assign the remainder of the output and carry over the remainder // of the operand. if (operand_dim_size % out_remaining_size) { - return Unimplemented("Unsupported bitcast: %s", hlo->ToString()); + return "Unsupported bitcast"; } operand_remaining_size = operand_dim_size / out_remaining_size; new_fragment_size = out_remaining_size; @@ -358,7 +406,7 @@ Status DimensionOrder::HandleBitcast(const HloInstruction* hlo) { int subdim_index = operand_dim_order.back().subdim_number + 1; while (operand_dim_iter != operand_shape.layout().minor_to_major().cend()) { if (operand_shape.dimensions(*operand_dim_iter) != 1) { - return Unimplemented("Unsupported bitcast: %s", hlo->ToString()); + return "Unsupported bitcast"; } operand_dim_order.push_back( {operand_dim_order.back().target_dim_number, subdim_index, 1}); @@ -367,10 +415,11 @@ Status DimensionOrder::HandleBitcast(const HloInstruction* hlo) { } dim_order_ = operand_dim_order; - return OkStatus(); + return FusionDecision{}; } -Status DimensionOrder::HandleCopyOrTranspose(const HloInstruction* hlo) { +FusionDecision DimensionOrder::HandleCopyOrTranspose( + const HloInstruction* hlo) { // Every HLO dimension can correspond to a group of subdimensions in // dim_order_. For the easier handling of permutations: group dim_order_ by // dimension, apply permutations, then finally remove the grouping. @@ -419,25 +468,25 @@ Status DimensionOrder::HandleCopyOrTranspose(const HloInstruction* hlo) { dim_order_.push_back(subdim); } } - return OkStatus(); + return FusionDecision{}; } // Tells if the dimension order is supported by the triton GEMM emitter. // Only the dimension indicated by SplittableDimensionIndex() can be split // physically once by other dimensions. Other ones can be only split logically. // All subdimensions within a dimension have to be ordered. -Status RequireTritonGemmSupportedDimOrder(const DimensionOrder& order) { - std::array subdim_counters = { +FusionDecision RequireTritonGemmSupportedDimOrder(const DimensionOrder& order) { + std::array subdim_counters = { -1, -1, -1, -1}; - std::array split_counters = { + std::array split_counters = { -1, -1, -1, -1}; const DimensionOrder::DimOrderVector& dim_order_vector = order.GetDimOrderVector(); + VLOG(8) << order.ToString(); for (int i = 0; i < dim_order_vector.size(); i++) { const auto [dim_number, subdim_number, size] = dim_order_vector[i]; - VLOG(8) << dim_number << "\t" << subdim_number << "\t" << size; if (subdim_counters[dim_number] != subdim_number - 1) { - return Unimplemented("Transpose within a dimension."); + return "Transpose within a dimension."; } ++subdim_counters[dim_number]; if (size == 1) { @@ -447,31 +496,185 @@ Status RequireTritonGemmSupportedDimOrder(const DimensionOrder& order) { ++split_counters[dim_number]; if (dim_number == order.SplittableDimensionIndex()) { if (split_counters[dim_number] > 1) { - return Unimplemented("2nd split of a splittable dimension."); + return "2nd split of a splittable dimension."; } } else if (split_counters[dim_number] > 0) { - return Unimplemented("Split of a non-splittable dimension."); + return "Split of a non-splittable dimension."; } } } - return OkStatus(); + return FusionDecision{}; } -// Transforms dim_order describing the output of `hlo` into a +// Tells if an instruction has no input into which it could be fused. +// More cases should be added here. +bool CanNotBeFusedIntoAProducer(const HloInstruction& hlo) { + return hlo_query::AllOperandsAreParametersOrConstants(hlo); +} + +// Tells that fusing an instruction is efficient. +bool IsInputWorthFusing(const HloInstruction& hlo) { + return hlo_query::AllOperandsAreParametersOrConstants(hlo) || + InputMinusOutputBytes(hlo) < kIoToleranceBytes; +} + +// Checks if the instruction is possible and profitable to fuse. +// If so tries to transform dim_order describing output of `hlo` into a // description of its input if it is supported by the triton GEMM emitter. -Status CanFuse(const HloInstruction* hlo, DimensionOrder& dim_order, - const GpuVersion gpu_version) { - if (hlo->opcode() == HloOpcode::kConvert) { - return RequireTritonFusibleConvert(hlo, gpu_version); - } else if (hlo->IsElementwise() && hlo->opcode() != HloOpcode::kCopy) { - // Temporarily forbid fusing elementwise operations - // other than copy and convert. - return Unimplemented("Unsupported elementwise operation"); +FusionDecision CanFuse(const HloInstruction& hlo, DimensionOrder& dim_order, + const GpuVersion gpu_version) { + if (hlo.opcode() == HloOpcode::kTuple || + hlo.opcode() == HloOpcode::kGetTupleElement) { + return "Unsupported instruction."; + } + for (const HloInstruction* operand : hlo.operands()) { + if (!IsSupportedDataType(operand->shape().element_type(), gpu_version)) { + return "Unsupported input data type."; + } + } + if (!IsSupportedDataType(hlo.shape().element_type(), gpu_version)) { + return "Unsupported output data type."; + } + if (hlo.IsConstant()) { + return "Not fusing a constant."; + } + if (hlo.opcode() == HloOpcode::kBroadcast) { + return "Not fusing a broadcast."; + } + if (!CanNotBeFusedIntoAProducer(hlo) && !IsInputWorthFusing(hlo)) { + return "Not obviously profitable to fuse as input."; + } + if (hlo.IsElementwise() && hlo.opcode() != HloOpcode::kCopy && + hlo.opcode() != HloOpcode::kConvert && + hlo.GetModule()->config().debug_options().xla_gpu_triton_fusion_level() < + 2) { + return "Skipping most elementwise operations at low fusion levels."; + } + if (FusionDecision decision = dim_order.HandleInstruction(&hlo); !decision) { + return decision; } - TF_RETURN_IF_ERROR(dim_order.HandleInstruction(hlo)); return RequireTritonGemmSupportedDimOrder(dim_order); } +// Clone an instruction into the fusion. +void Fuse(HloInstruction& hlo, + absl::flat_hash_map& + old_to_new_mapping, + std::vector& call_operands, + HloComputation::Builder& builder) { + if (old_to_new_mapping.contains(&hlo)) { + return; + } + VLOG(3) << "Fusing " << hlo.ToString(); + auto get_or_add_parameter = [&](HloInstruction& instr) { + if (auto it = old_to_new_mapping.find(&instr); + it != old_to_new_mapping.end()) { + return it->second; + } + call_operands.push_back(&instr); + return old_to_new_mapping + .insert({&instr, + builder.AddInstruction(HloInstruction::CreateParameter( + call_operands.size() - 1, instr.shape(), + absl::StrCat("parameter_", call_operands.size() - 1)))}) + .first->second; + }; + if (hlo.opcode() == HloOpcode::kParameter || + hlo.opcode() == HloOpcode::kGetTupleElement) { + get_or_add_parameter(hlo); + } else { + std::vector hlo_new_operands; + for (HloInstruction* operand : hlo.operands()) { + hlo_new_operands.push_back(get_or_add_parameter(*operand)); + } + old_to_new_mapping[&hlo] = builder.AddInstruction( + hlo.CloneWithNewOperands(hlo.shape(), hlo_new_operands)); + } +} + +// Tells how many new parameters does a fusion gain by fusing the operation as +// an input. +int64_t NumAddedParameters(const HloInstruction& hlo) { + // Non-scalar constant is equivalent to a parameter: one input, one output. + if (hlo.opcode() == HloOpcode::kConstant && + !ShapeUtil::IsScalar(hlo.shape())) { + return 0; + } + // All other instructions add all own inputs and remove own single output. + return hlo.operand_count() - 1; +} + +// Fuse an instruction with all its fusible inputs. +// If an input is not fusible stop there and make a parameter of the new +// fusion, otherwise put it onto stack and check its own inputs first. +void FuseWithInputsRecursively( + HloInstruction* root, DimensionOrder root_dim_order, + // Dimension orders describing inputs of corresponding instructions. + absl::flat_hash_map& dim_orders, + const GpuVersion gpu_version, + absl::flat_hash_map& + old_to_new_mapping, + std::vector& call_operands, + HloComputation::Builder& builder) { + absl::flat_hash_set visited; + std::stack to_fuse; + // Instructions at the edge 'to_fuse' that can either get fused too or + // become parameters of the fusion. Used to track the number of parameters + // of the fusion. + absl::flat_hash_set inputs; + // Currently only one physically unique dim order per scope is supported. + // Let it change while the scope has one input; afterwards require all + // of them to be physically compatible. + const HloInstruction* reference_dim_order_hlo = nullptr; + if (CanFuse(*root, root_dim_order, gpu_version)) { + to_fuse.push(root); + inputs.insert(root->operands().begin(), root->operands().end()); + // root_dim_order went through output -> input transformation here. + CHECK(dim_orders.insert({root, root_dim_order}).second) << root->ToString(); + } + visited.insert(root); + while (!to_fuse.empty()) { + bool top_is_ready_to_fuse = true; + HloInstruction* hlo = to_fuse.top(); + if (reference_dim_order_hlo == nullptr && hlo->operand_count() > 1) { + reference_dim_order_hlo = hlo; + } + for (HloInstruction* operand : hlo->mutable_operands()) { + if (visited.insert(operand).second) { + // Stop adding new parameters. + if (inputs.size() >= DotFusionAnalysis::kMaxParameterPerScope && + NumAddedParameters(*operand) > 0) { + continue; + } + // Operand's output is described by its consumer's input. + DimensionOrder operand_dim_order(dim_orders.at(hlo)); + // CanFuse() makes output -> input transformation of + // operand_dim_order if succeeds. + if (CanFuse(*operand, operand_dim_order, gpu_version)) { + if (reference_dim_order_hlo != nullptr && + !operand_dim_order.IsPhysicallyEquivalent( + dim_orders.at(reference_dim_order_hlo))) { + continue; + } + to_fuse.push(operand); + if (operand->opcode() != HloOpcode::kParameter) { + inputs.erase(operand); + } + inputs.insert(operand->operands().begin(), operand->operands().end()); + // Save the dimension order description of operand's input. + CHECK(dim_orders.insert({operand, operand_dim_order}).second) + << operand->ToString(); + top_is_ready_to_fuse = false; + } + } + } + if (top_is_ready_to_fuse) { + Fuse(*hlo, old_to_new_mapping, call_operands, builder); + to_fuse.pop(); + } + } +} + // Extracts into fused computations parts of HLO graph including dot() // operations that can target the triton GEMM emitter. class GemmRewriterTritonVisitor : public DfsHloRewriteVisitor { @@ -483,8 +686,9 @@ class GemmRewriterTritonVisitor : public DfsHloRewriteVisitor { // and replaces the original dot() with a call to the computation. Status HandleDot(HloInstruction* dot) override { VLOG(5) << dot->ToString(); - - if (!CanTritonHandleGEMM(*dot, gpu_version_)) { + FusionDecision can_handle = CanTritonHandleGEMM(*dot, gpu_version_); + if (!can_handle) { + VLOG(3) << can_handle.Explain(); return OkStatus(); } @@ -503,72 +707,28 @@ class GemmRewriterTritonVisitor : public DfsHloRewriteVisitor { std::string suggested_name = absl::StrCat("triton_gemm_", dot->name()); HloComputation::Builder builder( absl::StrCat(suggested_name, "_computation")); + std::vector call_operands; // Original instruction -> fused one. absl::flat_hash_map old_to_new_mapping; - absl::flat_hash_set visited; - std::vector call_operands; - // Traverse and fuse dot() inputs bottom-up starting from direct operands. - // If an input is not fusible stop there and make it a parameter of the new - // fusion, otherwise put it onto stack and check its own inputs first. - std::stack to_fuse; - // Dimension orders describing inputs of corresponding instructions. - absl::flat_hash_map dim_orders; - to_fuse.push(dot); - while (!to_fuse.empty()) { - bool top_is_ready_to_fuse = true; - HloInstruction* hlo = to_fuse.top(); - for (HloInstruction* operand : hlo->mutable_operands()) { - if (visited.insert(operand).second) { - DimensionOrder operand_dim_order = [&] { - // Direct dot inputs are described by default dimension orders. - if (operand == dot->operand(0)) { - return DimensionOrder::FromDotOperand(*dot, 0); - } else if (operand == dot->operand(1)) { - return DimensionOrder::FromDotOperand(*dot, 1); - } - // Otherwise operand's output is described by its consumer's input. - return DimensionOrder(dim_orders.at(hlo)); - }(); - // CanFuse() makes output -> input transformation of - // operand_dim_order if succeeds. - if (CanFuse(operand, operand_dim_order, gpu_version_).ok()) { - VLOG(3) << "Fusing " << operand->ToString(); - to_fuse.push(operand); - // Save the dimension order description of operand's input. - dim_orders.insert({operand, operand_dim_order}); - top_is_ready_to_fuse = false; - } - } - } - if (top_is_ready_to_fuse) { - if (hlo->opcode() == HloOpcode::kParameter || - hlo->opcode() == HloOpcode::kGetTupleElement) { - old_to_new_mapping[hlo] = - builder.AddInstruction(HloInstruction::CreateParameter( - call_operands.size(), hlo->shape(), - absl::StrCat("parameter_", call_operands.size()))); - call_operands.push_back(hlo); - } else { - std::vector hlo_new_operands; - for (HloInstruction* operand : hlo->operands()) { - const auto iter = old_to_new_mapping.find(operand); - if (iter != old_to_new_mapping.end()) { - hlo_new_operands.push_back(iter->second); - } else { - hlo_new_operands.push_back( - builder.AddInstruction(HloInstruction::CreateParameter( - call_operands.size(), operand->shape(), - absl::StrCat("parameter_", call_operands.size())))); - call_operands.push_back(operand); - } - } - old_to_new_mapping[hlo] = builder.AddInstruction( - hlo->CloneWithNewOperands(hlo->shape(), hlo_new_operands)); - } - to_fuse.pop(); - } - } + + auto fuse_inputs = [&](int operand_number) { + absl::flat_hash_map dim_orders; + int operand_count_before = call_operands.size(); + // Direct dot inputs have well defined dimension orders. + FuseWithInputsRecursively( + dot->mutable_operand(operand_number), + DimensionOrder::FromDotOperand(*dot, operand_number), dim_orders, + gpu_version_, old_to_new_mapping, call_operands, builder); + return call_operands.size() - operand_count_before; + }; + // Separate traversal from LHS and RHS inputs of the dot: they use + // differently shaped tiles but may go through same HLO graph nodes. + TF_RET_CHECK(fuse_inputs(0) <= DotFusionAnalysis::kMaxParameterPerScope); + TF_RET_CHECK(fuse_inputs(1) <= DotFusionAnalysis::kMaxParameterPerScope); + + Fuse(*dot, old_to_new_mapping, call_operands, builder); + HloComputation* computation = dot->GetModule()->AddComputationAndUnifyNamesAndIds(builder.Build(), /*is_entry=*/false); @@ -592,7 +752,7 @@ class GemmRewriterTritonVisitor : public DfsHloRewriteVisitor { } else { TF_RETURN_IF_ERROR(ReplaceInstruction(dot, dot_fusion)); } - VLOG(5) << computation->ToString(); + XLA_VLOG_LINES(5, computation->ToString()); return OkStatus(); } @@ -643,7 +803,7 @@ StatusOr MakeSplitKOperand( for (const HloInstruction* param : analysis.ScopeParameters(scope)) { // If an operand of dot does not read any parameters its K dimension // does not need analysis for fragmentation. - const DotFusionAnalysis::DimIterationSpec* spec = + const DimIterationSpec* spec = analysis.IterSpec(scope, param, contracting_dim_idx); // Split contracting dimension is not implemented yet. CHECK_EQ(spec->size(), 1); @@ -885,8 +1045,8 @@ DotFusionAnalysis::DotFusionAnalysis(const HloComputation* dot_computation, absl::flat_hash_map dim_orders; DimensionOrder dot_operand_dim_order = DimensionOrder::FromDotOperand(*dot, operand_number, split_k); - TF_CHECK_OK(dot_operand_dim_order.HandleInstruction(dot_operand)); - TF_CHECK_OK(RequireTritonGemmSupportedDimOrder(dot_operand_dim_order)) + CHECK(dot_operand_dim_order.HandleInstruction(dot_operand)); + CHECK(RequireTritonGemmSupportedDimOrder(dot_operand_dim_order)) << dot_computation->ToString(); dim_orders.insert({dot_operand, dot_operand_dim_order}); visited.insert(dot_operand); @@ -907,14 +1067,18 @@ DotFusionAnalysis::DotFusionAnalysis(const HloComputation* dot_computation, {hlo_operand, DimensionOrder(dim_orders.at(hlo))}); CHECK(inserted); DimensionOrder& hlo_operand_dim_order = it->second; - TF_CHECK_OK(hlo_operand_dim_order.HandleInstruction(hlo_operand)); - TF_CHECK_OK(RequireTritonGemmSupportedDimOrder(hlo_operand_dim_order)) + CHECK(hlo_operand_dim_order.HandleInstruction(hlo_operand)); + CHECK(RequireTritonGemmSupportedDimOrder(hlo_operand_dim_order)) << " " << dot_computation->ToString(); to_process.push(hlo_operand); } } + // For now all parameters of one scope have to use the same tiling. for (const HloInstruction* parameter : parameters_[scope]) { + CHECK(dim_orders.at(parameter).IsPhysicallyEquivalent( + dim_orders.at(*parameters_[scope].cbegin()))) + << dot_computation->ToString(); iter_specs_[scope][parameter] = DimensionOrderToTensorIterationSpec(dim_orders.at(parameter)); } @@ -926,22 +1090,22 @@ DotFusionAnalysis::DotFusionAnalysis(const HloComputation* dot_computation, .second); } -const DotFusionAnalysis::DimIterationSpec* DotFusionAnalysis::IterSpec( +const DimIterationSpec* DotFusionAnalysis::IterSpec( const DotFusionAnalysis::Scope scope, const HloInstruction* hlo, const int dimension) const { auto ret = iter_specs_.at(scope).find(hlo); if (ret != iter_specs_.at(scope).end()) { - return &ret->second.at(dimension); + return &ret->second[dimension]; } return nullptr; } -bool CanTritonHandleGEMM(const HloInstruction& dot, - const GpuVersion gpu_version) { +FusionDecision CanTritonHandleGEMM(const HloInstruction& dot, + const GpuVersion gpu_version) { if (dot.opcode() != HloOpcode::kDot || absl::c_any_of(dot.precision_config().operand_precision(), [](int x) { return x != PrecisionConfig::DEFAULT; })) { - return false; + return "Non-default precision."; } auto supported_output_type = [&](const PrimitiveType t) { @@ -961,21 +1125,21 @@ bool CanTritonHandleGEMM(const HloInstruction& dot, // TODO(b/266862493): Support more output types. if (!supported_output_type(dot.shape().element_type())) { - return false; + return "Unsupported output data type."; } if (!IsSupportedDataType(dot.operand(0)->shape().element_type(), gpu_version) || !IsSupportedDataType(dot.operand(1)->shape().element_type(), gpu_version)) { - return false; + return "Unsupported input data type."; } const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers(); // TODO(b/269580541): support multiple batch dimensions. if (dim_numbers.lhs_batch_dimensions().size() > 1) { - return false; + return "Multiple batch dimensions."; } // Cases where lhs or rhs have no non-contracting dims are not handled. @@ -985,10 +1149,10 @@ bool CanTritonHandleGEMM(const HloInstruction& dot, dim_numbers.rhs_batch_dimensions().size() + dim_numbers.rhs_contracting_dimensions().size() == dot.operand(1)->shape().rank()) { - return false; + return "No non-contracting dimensions."; } - return true; + return FusionDecision{}; } bool ShouldTritonHandleGEMM(const HloInstruction& dot, @@ -1008,7 +1172,7 @@ bool ShouldTritonHandleGEMM(const HloInstruction& dot, while (!queue.empty()) { const HloInstruction* current = queue.front(); queue.pop(); - if (!CanFuse(current, dim_order, gpu_version).ok()) { + if (!CanFuse(*current, dim_order, gpu_version)) { continue; } // Stop as soon as a profitable operation is fused. diff --git a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.h b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.h index 715c79d9114659..0afc939b43ede2 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.h +++ b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.h @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/hlo/ir/hlo_opcode.h" #include "tensorflow/compiler/xla/service/gpu/gpu_types.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/service/instruction_fusion.h" namespace xla { namespace gpu { @@ -52,13 +53,13 @@ Status MakeDotSplitKBatch(HloInstruction* dot_fusion, const AutotuneResult::TritonGemmKey& tiling); // Filters GEMMs which can be handled using Triton. -bool CanTritonHandleGEMM(const HloInstruction&, GpuVersion gpu_version); +FusionDecision CanTritonHandleGEMM(const HloInstruction&, + GpuVersion gpu_version); // Filters GEMMs which are better to handle using Triton. bool ShouldTritonHandleGEMM(const HloInstruction&, GpuVersion gpu_version); -// Analysis of iteration of HLO shapes within a fusion around dot(). -class DotFusionAnalysis { +class TensorIterationSpec { public: // Description of basic iteration: `count` elements separated by `stride`. struct IterationSpecFragment { @@ -68,16 +69,42 @@ class DotFusionAnalysis { // of several HLO dimensions. Product of subfragments equals `count`. std::vector subfragments; }; - // Description of complex iteration over a sequence of several strides. // Describes a logically contiguous dimension of a tensor physically // separated into multiple fragments by other dimensions. using DimIterationSpec = std::vector; // At most: contracting, non-contracting, split-K, another batch. - static const int kMaxDimsPerTensor = 4; - using TensorIterationSpec = std::array; + static constexpr int kMaxDimsPerTensor = 4; + using StorageType = std::array; + + const DimIterationSpec& operator[](int dimension) const { + return dim_iteration_specs_[dimension]; + } + + DimIterationSpec& operator[](int dimension) { + return dim_iteration_specs_[dimension]; + } + + // Compares physical layouts of tensors ignoring subfragments of dimensions. + bool operator==(const TensorIterationSpec& other) const; + + StorageType::iterator begin() { return dim_iteration_specs_.begin(); } + StorageType::iterator end() { return dim_iteration_specs_.end(); } + StorageType::const_iterator cbegin() const { + return dim_iteration_specs_.cbegin(); + } + StorageType::const_iterator cend() const { + return dim_iteration_specs_.cend(); + } + + private: + StorageType dim_iteration_specs_; +}; +// Analysis of iteration of HLO shapes within a fusion around dot(). +class DotFusionAnalysis { + public: // Execute analysis of dot fusion computation. // split_k indicates whether this operation was converted to the split-K // form and tells the analysis how to interpret the batch dimensions. @@ -88,9 +115,15 @@ class DotFusionAnalysis { // defined by left operand, right operand and output. enum class Scope { LHS = 0, RHS = 1, OUTPUT = 2 }; + // Every parameter requires a separate piece of shared memory for asynchronous + // loads. Multiple parameters are approximately equivalent to multiple + // pipeline stages. + static constexpr int kMaxParameterPerScope = 4; + // Scope -> HLO -> dot dimension number -> iteration spec at the HLO's output. - const DimIterationSpec* IterSpec(Scope scope, const HloInstruction*, - int dimension) const; + const TensorIterationSpec::DimIterationSpec* IterSpec(Scope scope, + const HloInstruction*, + int dimension) const; // Parameter HLO instructions used in a scope of `dot`. const absl::flat_hash_set& ScopeParameters( const Scope scope) const { diff --git a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton_test.cc b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton_test.cc index d02faa5b3abdc9..95eaf51915d2e5 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton_test.cc @@ -94,7 +94,7 @@ ENTRY e { GmockMatch(m::Fusion(m::Parameter(), m::Parameter()))); } -TEST_F(GemmRewriterTritonTest, DoNotFuseConstant) { +TEST_F(GemmRewriterTritonTest, DoNotFuseConstants) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(R"( HloModule m @@ -102,14 +102,14 @@ HloModule m ENTRY e { p0 = s8[60,5] parameter(0) c0 = f16[60,5] convert(p0) - cst1 = f16[600] constant({...}) - r1 = f16[5,120] reshape(cst1) + cst1 = f16[] constant(1234) + r1 = f16[5,120] broadcast(cst1) ROOT d = f16[60,120] dot(c0, r1), lhs_contracting_dims={1}, rhs_contracting_dims={0} })")); EXPECT_TRUE(GemmRewriterTriton(gpu_version_).Run(module.get()).value()); EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch(m::Fusion(m::Constant(), m::Parameter()))); + GmockMatch(m::Fusion(m::Parameter(), m::Broadcast()))); } using TritonDotAnalysisTest = HloTestBase; @@ -793,6 +793,154 @@ ENTRY e { EXPECT_TRUE(GemmRewriterTriton(cc).Run(module.get()).value()); } +class GemmRewriterTritonLevel2Test : public GemmRewriterTritonTest { + public: + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); + debug_options.set_xla_gpu_triton_fusion_level(2); + return debug_options; + } +}; + +TEST_F(GemmRewriterTritonLevel2Test, DoNotFuseIncompatibleDimOrders) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +HloModule m + +ENTRY e { + p0 = f16[5,3] parameter(0) + p1 = f16[5,7] parameter(1) + p2 = f16[7,5] parameter(2) + t = f16[5,7] transpose(p2), dimensions={1,0} + a = f16[5,7] add(t, p1) + ROOT d = f16[3,7] dot(p0, a), + lhs_contracting_dims={0}, rhs_contracting_dims={0} +})")); + + EXPECT_TRUE(GemmRewriterTriton(gpu_version_).Run(module.get()).value()); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch(m::Fusion(m::Parameter(), m::Parameter(), m::Transpose()))); +} + +TEST_F(GemmRewriterTritonLevel2Test, DoNotFuseTooManyParameters) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +ENTRY e { + tmp_0 = f32[] constant(1) + tmp_1 = f32[3,49]{1,0} broadcast(tmp_0), dimensions={} + tmp_2 = f32[3,49]{1,0} parameter(6) + tmp_3 = f32[] constant(0) + tmp_4 = f32[3,49]{1,0} broadcast(tmp_3), dimensions={} + tmp_5 = pred[3,49]{1,0} compare(tmp_2, tmp_4), direction=GT + tmp_6 = f32[3,49]{1,0} convert(tmp_5) + tmp_7 = f32[3,49]{1,0} subtract(tmp_1, tmp_6) + tmp_8 = s32[] parameter(13) + tmp_9 = f32[] convert(tmp_8) + tmp_10 = f32[] maximum(tmp_9, tmp_0) + tmp_11 = f32[] divide(tmp_3, tmp_10) + tmp_12 = f32[3,49]{1,0} broadcast(tmp_11), dimensions={} + tmp_13 = pred[3,49]{1,0} parameter(7) + tmp_14 = pred[3,49]{1,0} parameter(10) + tmp_15 = pred[3,49]{1,0} and(tmp_13, tmp_14) + tmp_16 = f32[3,49]{1,0} convert(tmp_15) + tmp_17 = f32[3,49]{1,0} multiply(tmp_12, tmp_16) + tmp_18 = f32[3,49]{1,0} negate(tmp_17) + tmp_19 = f32[3,49]{1,0} multiply(tmp_7, tmp_18) + tmp_20 = f32[3,49]{1,0} parameter(19) + tmp_21 = f32[3,49]{1,0} subtract(tmp_1, tmp_20) + tmp_22 = f32[3,49]{1,0} divide(tmp_19, tmp_21) + tmp_23 = f32[3,49]{1,0} negate(tmp_22) + tmp_24 = f32[3,49]{1,0} negate(tmp_6) + tmp_25 = f32[3,49]{1,0} multiply(tmp_24, tmp_17) + tmp_26 = f32[3,49]{1,0} divide(tmp_25, tmp_20) + tmp_27 = f32[3,49]{1,0} add(tmp_23, tmp_26) + tmp_28 = f32[3,49]{1,0} parameter(18) + tmp_29 = f32[3,49]{1,0} multiply(tmp_27, tmp_28) + tmp_30 = f32[3,49]{1,0} parameter(17) + tmp_31 = f32[3,49]{1,0} multiply(tmp_29, tmp_30) + tmp_32 = f32[3,49]{1,0} parameter(16) + tmp_33 = f32[3,49]{1,0} multiply(tmp_31, tmp_32) + tmp_34 = f32[3,49]{1,0} parameter(15) + tmp_35 = f32[3,49]{1,0} add(tmp_33, tmp_34) + tmp_36 = f32[3,49]{1,0} parameter(14) + tmp_37 = f32[3,49]{1,0} add(tmp_35, tmp_36) + tmp_38 = f32[1,1]{1,0} constant({ {0} }) + tmp_39 = f32[1,1]{1,0} broadcast(tmp_38), dimensions={0,1} + tmp_40 = f32[] reshape(tmp_39) + tmp_41 = f32[3,32]{1,0} broadcast(tmp_40), dimensions={} + tmp_42 = u32[48]{0} parameter(11) + tmp_43 = u32[48]{0} parameter(5) + tmp_44 = u32[96]{0} concatenate(tmp_42, tmp_43), dimensions={0} + tmp_45 = u32[3,32]{1,0} reshape(tmp_44) + tmp_46 = u32[96]{0} reshape(tmp_45) + tmp_47 = u32[] constant(1) + tmp_48 = u32[3,32]{1,0} broadcast(tmp_47), dimensions={} + tmp_49 = u32[96]{0} reshape(tmp_48) + tmp_50 = u32[96]{0} shift-right-logical(tmp_46, tmp_49) + tmp_51 = u32[3,32]{1,0} reshape(tmp_50) + tmp_52 = u32[3,32]{1,0} or(tmp_51, tmp_48) + tmp_53 = f32[3,32]{1,0} bitcast-convert(tmp_52) + tmp_54 = f32[3,32]{1,0} broadcast(tmp_0), dimensions={} + tmp_55 = f32[3,32]{1,0} subtract(tmp_53, tmp_54) + tmp_56 = f32[1,1]{1,0} constant({ {1} }) + tmp_57 = f32[1,1]{1,0} broadcast(tmp_56), dimensions={0,1} + tmp_58 = f32[] reshape(tmp_57) + tmp_59 = f32[3,32]{1,0} broadcast(tmp_58), dimensions={} + tmp_60 = f32[3,32]{1,0} multiply(tmp_55, tmp_59) + tmp_61 = f32[3,32]{1,0} add(tmp_60, tmp_41) + tmp_62 = f32[3,32]{1,0} maximum(tmp_41, tmp_61) + tmp_63 = f32[3,32]{1,0} broadcast(tmp_3), dimensions={} + tmp_64 = pred[3,32]{1,0} compare(tmp_62, tmp_63), direction=LT + tmp_65 = f32[3,32]{1,0} convert(tmp_64) + tmp_66 = f32[3,49]{1,0} parameter(9) + tmp_67 = f32[49]{0} parameter(4) + tmp_68 = f32[3,49]{1,0} broadcast(tmp_67), dimensions={1} + tmp_69 = f32[3,49]{1,0} add(tmp_66, tmp_68) + tmp_70 = f32[1,49]{1,0} parameter(12) + tmp_71 = f32[1,49]{1,0} broadcast(tmp_0), dimensions={} + tmp_72 = f32[1,49]{1,0} divide(tmp_70, tmp_71) + tmp_73 = f32[1,49]{1,0} broadcast(tmp_72), dimensions={0,1} + tmp_74 = f32[49]{0} reshape(tmp_73) + tmp_75 = f32[3,49]{1,0} broadcast(tmp_74), dimensions={1} + tmp_76 = f32[3,49]{1,0} subtract(tmp_69, tmp_75) + tmp_77 = f32[1,49]{1,0} parameter(3) + tmp_78 = f32[1,49]{1,0} parameter(8) + tmp_79 = f32[1,49]{1,0} divide(tmp_78, tmp_71) + tmp_80 = f32[1,49]{1,0} multiply(tmp_72, tmp_72) + tmp_81 = f32[1,49]{1,0} subtract(tmp_79, tmp_80) + tmp_82 = f32[1,49]{1,0} add(tmp_81, tmp_71) + tmp_83 = f32[1,49]{1,0} rsqrt(tmp_82) + tmp_84 = f32[1,49]{1,0} multiply(tmp_77, tmp_83) + tmp_85 = f32[1,49]{1,0} broadcast(tmp_84), dimensions={0,1} + tmp_86 = f32[49]{0} reshape(tmp_85) + tmp_87 = f32[3,49]{1,0} broadcast(tmp_86), dimensions={1} + tmp_88 = f32[3,49]{1,0} multiply(tmp_76, tmp_87) + tmp_89 = f32[1,49]{1,0} parameter(2) + tmp_90 = f32[1,49]{1,0} broadcast(tmp_89), dimensions={0,1} + tmp_91 = f32[49]{0} reshape(tmp_90) + tmp_92 = f32[3,49]{1,0} broadcast(tmp_91), dimensions={1} + tmp_93 = f32[3,49]{1,0} add(tmp_88, tmp_92) + tmp_94 = f32[49,32]{1,0} parameter(1) + tmp_95 = f32[3,32]{1,0} dot(tmp_93, tmp_94), lhs_contracting_dims={1}, rhs_contracting_dims={0} + tmp_96 = f32[32]{0} parameter(0) + tmp_97 = f32[3,32]{1,0} broadcast(tmp_96), dimensions={1} + tmp_98 = f32[3,32]{1,0} add(tmp_95, tmp_97) + tmp_99 = f32[3,32]{1,0} multiply(tmp_65, tmp_98) + tmp_100 = f32[3,32]{1,0} divide(tmp_99, tmp_63) + tmp_101 = f32[3,32]{1,0} maximum(tmp_100, tmp_63) + ROOT tmp_102 = f32[49,32]{1,0} dot(tmp_37, tmp_101), lhs_contracting_dims={0}, rhs_contracting_dims={0} +})")); + + EXPECT_TRUE(GemmRewriterTriton(gpu_version_).Run(module.get()).value()); + EXPECT_EQ(module->entry_computation()->root_instruction()->opcode(), + HloOpcode::kFusion); + EXPECT_EQ(module->entry_computation()->root_instruction()->fusion_kind(), + HloInstruction::FusionKind::kCustom); + EXPECT_LE(module->entry_computation()->root_instruction()->operand_count(), + DotFusionAnalysis::kMaxParameterPerScope * 2); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index f490f9b127e21a..b3944952ac68da 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -973,6 +973,29 @@ Status GpuCompiler::OptimizeHloPostLayoutAssignment( }); } + GpuFloatSupport bf16_support(BF16); + GpuFloatSupport f8e5m2_support(F8E5M2); + GpuFloatSupport f8e4m3fn_support(F8E4M3FN); + FloatSupport f8e4m3b11fnuz_support(F8E4M3B11FNUZ); + FloatSupport f8e5m2fnuz_support(F8E5M2FNUZ); + FloatSupport f8e4m3fnuz_support(F8E4M3FNUZ); + + auto add_float_normalization = [&](HloPassPipeline& pipeline) { + auto& sub_pipeline = + pipeline.AddPass("float_normalization"); + sub_pipeline.AddPass(&bf16_support); + sub_pipeline.AddPass(&f8e5m2_support); + sub_pipeline.AddPass(&f8e4m3fn_support); + sub_pipeline.AddPass(&f8e4m3b11fnuz_support); + sub_pipeline.AddPass(&f8e5m2fnuz_support); + sub_pipeline.AddPass(&f8e4m3fnuz_support); + // Remove `f32 -> bf16 -> f32` casts inserted by bf16 normalization. + if (debug_options.xla_gpu_simplify_all_fp_conversions()) { + sub_pipeline.AddPass(); + } + }; + add_float_normalization(pipeline); + // By default use an externally provided thread pool. tsl::thread::ThreadPool* thread_pool = options.thread_pool; std::optional overriding_thread_pool; @@ -994,18 +1017,8 @@ Status GpuCompiler::OptimizeHloPostLayoutAssignment( &pipeline, hlo_module, stream_exec, debug_options, options, gpu_target_config, autotune_results, thread_pool)); - GpuFloatSupport bf16_support(BF16); - pipeline.AddPass(&bf16_support); - GpuFloatSupport f8e5m2_support(F8E5M2); - pipeline.AddPass(&f8e5m2_support); - GpuFloatSupport f8e4m3fn_support(F8E4M3FN); - pipeline.AddPass(&f8e4m3fn_support); - FloatSupport f8e4m3b11fnuz_support(F8E4M3B11FNUZ); - pipeline.AddPass(&f8e4m3b11fnuz_support); - FloatSupport f8e5m2fnuz_support(F8E5M2FNUZ); - pipeline.AddPass(&f8e5m2fnuz_support); - FloatSupport f8e4m3fnuz_support(F8E4M3FNUZ); - pipeline.AddPass(&f8e4m3fnuz_support); + // The Triton autotuner can insert new reductions. + add_float_normalization(pipeline); // Remove `f32 -> bf16 -> f32` casts inserted by bf16 normalization. if (debug_options.xla_gpu_simplify_all_fp_conversions()) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_triton.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_triton.cc index 709f3e40b52c3f..7c9cd87953a848 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_triton.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_triton.cc @@ -792,7 +792,7 @@ StatusOr MatMulImpl( if (!analysis.ScopeParameters(DotFusionAnalysis::Scope::LHS).empty()) { const HloInstruction* lhs_param0 = *analysis.ScopeParameters(DotFusionAnalysis::Scope::LHS).begin(); - const DotFusionAnalysis::DimIterationSpec* lhs_nc_iter_spec = + const TensorIterationSpec::DimIterationSpec* lhs_nc_iter_spec = analysis.IterSpec(DotFusionAnalysis::Scope::LHS, lhs_param0, lhs_noncontracting_dim_idx); lhs_nc_split = lhs_nc_iter_spec->size() > 1; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_triton_test.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_triton_test.cc index fc4bb7204c1632..e87b973b4a60c8 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_triton_test.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_triton_test.cc @@ -25,11 +25,14 @@ limitations under the License. #include "tensorflow/compiler/xla/autotuning.pb.h" #include "tensorflow/compiler/xla/error_spec.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_computation.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/gpu/gpu_device_info_for_tests.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h" #include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h" #include "tensorflow/compiler/xla/stream_executor/device_description.h" #include "tensorflow/compiler/xla/tests/verified_hlo_module.h" #include "tensorflow/tsl/lib/core/status_test_util.h" @@ -42,6 +45,8 @@ namespace xla { namespace gpu { namespace { +namespace m = ::xla::match; + class TritonGemmNoTF32Test : public GpuCodegenTest { public: void SetUp() override { @@ -715,6 +720,162 @@ ENTRY e { EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-6, /*arel=*/1e-6})); } +class TritonGemmLevel2Test : public TritonGemmTest { + public: + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); + debug_options.set_xla_gpu_triton_fusion_level(2); + return debug_options; + } +}; + +TEST_F(TritonGemmLevel2Test, BinaryOperationWithSmallInputsIsFused) { + const std::string kHloText = R"( +HloModule m + +ENTRY e { + p0 = s8[7,3] parameter(0) + p1 = f32[3,16] parameter(1) + p2 = f32[3,16] parameter(2) + e = f32[3,16] exponential(p1) + a = f32[3,16] add(e, p2) + c = f32[7,3] convert(p0) + ROOT d = f32[7,16] dot(c, a), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetOptimizedModule(kHloText)); + + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch(m::Fusion(m::Parameter(), m::Parameter(), m::Parameter()) + .WithFusionKind(HloInstruction::FusionKind::kCustom))); + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-1, /*arel=*/1e-3})); +} + +TEST_F(TritonGemmLevel2Test, BinaryOperationWithLargeInputsIsNotFused) { + const std::string kHloText = R"( +HloModule m + +ENTRY e { + p0 = f16[333,1000] parameter(0) + p1 = f32[1000,333] parameter(1) + p1n = f32[1000,333] negate(p1) + p2 = f32[1000,333] parameter(2) + p2n = f32[1000,333] negate(p2) + s = f32[1000,333] subtract(p1n, p2n) + c = f32[333,1000] convert(p0) + ROOT d = f32[1000,1000] dot(s, c), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"; + + MatchOptimizedHlo(kHloText, R"( +; CHECK: fused_computation +; CHECK: negate +; CHECK: negate +; CHECK: ROOT +; CHECK-SAME: subtract +; CHECK: ENTRY +; CHECK: kLoop +; CHECK: kCustom +)"); + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-1, /*arel=*/1e-3})); +} + +TEST_F(TritonGemmLevel2Test, BinaryOperationOnLargeParametersIsFused) { + const std::string kHloText = R"( +HloModule m + +ENTRY e { + p0 = f16[1000,111] parameter(0) + p1 = f32[111,10000] parameter(1) + p2 = f32[111,10000] parameter(2) + s = f32[111,10000] subtract(p1, p2) + c = f32[1000,111] convert(p0) + ROOT d = f32[10000,1000] dot(s, c), + lhs_contracting_dims={0}, rhs_contracting_dims={1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetOptimizedModule(kHloText)); + + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch(m::Fusion(m::Parameter(), m::Parameter(), m::Parameter()) + .WithFusionKind(HloInstruction::FusionKind::kCustom))); + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-1, /*arel=*/1e-3})); +} + +TEST_F(TritonGemmLevel2Test, LinkingLibdeviceTwiceWorks) { + const std::string kHloText = R"( +HloModule m + +ENTRY e { + p0 = s8[7,3] parameter(0) + c0 = f32[7,3] convert(p0) + e0 = f32[7,3] exponential(c0) + p1 = f32[3,16] parameter(1) + e1 = f32[3,16] exponential(p1) + d0 = f32[7,16] dot(c0, e1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + d1 = f32[7,16] dot(e0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT a = f32[7,16] add(d0, d1) +})"; + + MatchOptimizedHlo(kHloText, R"( +; CHECK: ENTRY +; CHECK-NEXT: parameter +; CHECK-NEXT: parameter +; CHECK-NEXT: kCustom +; CHECK-NEXT: kCustom +; CHECK-NEXT: ROOT +; CHECK-SAME: add +)"); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetOptimizedModule(kHloText)); + + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Add( + m::Fusion(m::Parameter(), m::Parameter()) + .WithFusionKind(HloInstruction::FusionKind::kCustom), + m::Fusion(m::Parameter(), m::Parameter()) + .WithFusionKind(HloInstruction::FusionKind::kCustom)))); + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2})); +} + +TEST_F(TritonGemmLevel2Test, BroadcastOfConstantIsNotFused) { + const std::string kHloText = R"( +HloModule m + +ENTRY e { + p0 = f16[70,30] parameter(0) + p0c = f32[70,30] convert(p0) + constant_3663 = f32[] constant(4321) + bc0 = f32[30,5] broadcast(constant_3663) + p1 = f32[30,5] parameter(1) + a = f32[30,5] add(p1, bc0) + ROOT d = f32[70,5] dot(p0c, a), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"; + + MatchOptimizedHlo(kHloText, R"( +; CHECK: ENTRY +; CHECK: constant +; CHECK: broadcast +; CHECK: fusion +; CHECK-SAME: kind=kCustom +)"); + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/2e-3, /*arel=*/2e-3})); +} + TEST_F(TritonGemmTest, Naming) { const char* hlo_text = R"( HloModule t diff --git a/tensorflow/compiler/xla/service/gpu/triton_autotuner.cc b/tensorflow/compiler/xla/service/gpu/triton_autotuner.cc index 440a9611a8fe27..b8b8b5f6719931 100644 --- a/tensorflow/compiler/xla/service/gpu/triton_autotuner.cc +++ b/tensorflow/compiler/xla/service/gpu/triton_autotuner.cc @@ -418,12 +418,11 @@ std::vector GetExhaustiveMatmulAutotuneConfigs( std::vector GetFixedMatmulAutotuneConfigs( const se::CudaComputeCapability compute_capability) { std::vector configs = { - GemmKey(32, 32, 256, 1, 1, 4), GemmKey(64, 32, 32, 16, 1, 4), - GemmKey(32, 64, 64, 4, 1, 4), GemmKey(128, 128, 64, 4, 1, 4), - GemmKey(16, 16, 256, 1, 1, 4), GemmKey(16, 128, 32, 16, 1, 4), - GemmKey(16, 64, 128, 1, 1, 4), GemmKey(16, 128, 32, 8, 1, 4), - GemmKey(16, 16, 512, 1, 1, 4), GemmKey(32, 16, 512, 1, 1, 4), - GemmKey(64, 32, 64, 1, 2, 8)}; + GemmKey(32, 32, 256, 1, 1, 4), GemmKey(64, 32, 32, 16, 1, 4), + GemmKey(32, 64, 64, 4, 1, 4), GemmKey(16, 16, 256, 1, 1, 4), + GemmKey(16, 128, 32, 16, 1, 4), GemmKey(16, 64, 128, 1, 1, 4), + GemmKey(16, 128, 32, 8, 1, 4), GemmKey(16, 16, 512, 1, 1, 4), + GemmKey(32, 16, 512, 1, 1, 4), GemmKey(64, 32, 64, 1, 2, 8)}; if (compute_capability.IsAtLeast(se::CudaComputeCapability::AMPERE)) { absl::c_copy( std::vector{ diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index 3eb4ae20db045d..f41923aeae68b6 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -570,7 +570,9 @@ message DebugOptions { bool xla_gpu_triton_gemm_disable_reduced_precision_reduction = 226; - // Next id: 229 + int32 xla_gpu_triton_fusion_level = 229; + + // Next id: 230 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend. From 0676e862f0b3d4a5d06e18854c887ee7db12a0a2 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 11 Jul 2023 09:20:27 -0700 Subject: [PATCH 128/376] Enable `ToLiteral` for asynchronous execution PiperOrigin-RevId: 547210213 --- .../xla/pjrt/gpu/se_gpu_pjrt_client_test.cc | 41 +++++++ .../xla/pjrt/pjrt_stream_executor_client.cc | 105 ++++++++++++------ 2 files changed, 109 insertions(+), 37 deletions(-) diff --git a/tensorflow/compiler/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc b/tensorflow/compiler/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc index 532dc1bef757b8..3d9fee99e156c6 100644 --- a/tensorflow/compiler/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc +++ b/tensorflow/compiler/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc @@ -271,6 +271,47 @@ TEST(StreamExecutorGpuClientTest, ToLiteralAsync) { literal->Relayout(src_literal.shape().layout()).data()); } +TEST(StreamExecutorGpuClientTest, ToLiteralAsyncBeforeBufferReady) { + TF_ASSERT_OK_AND_ASSIGN( + auto client, GetStreamExecutorGpuClient(true, /*allocator_config=*/{}, + /*node_id=*/0)); + ASSERT_GE(client->addressable_devices().size(), 1); + + auto src_literal = LiteralUtil::CreateR1({41.0f, 42.0f, 43.0f, 44.0f}); + TF_ASSERT_OK_AND_ASSIGN( + auto transfer_manager, + client->CreateBuffersForAsyncHostToDevice( + {src_literal.shape()}, client->addressable_devices()[0])); + auto buffer = transfer_manager->RetrieveBuffer(0); + + absl::Mutex mu; + auto literal = std::make_shared( + ShapeUtil::DeviceShapeToHostShape(buffer->on_device_shape())); + bool got_literal = false; + + buffer->ToLiteral(literal.get(), [&](Status s) { + absl::MutexLock l(&mu); + TF_ASSERT_OK(s); + got_literal = true; + }); + + absl::SleepFor(absl::Milliseconds(10)); + ASSERT_FALSE(got_literal); + TF_ASSERT_OK( + transfer_manager->TransferLiteralToBuffer(0, src_literal, [&]() {})); + + buffer.reset(); + + { + absl::MutexLock l(&mu); + mu.Await(absl::Condition(&got_literal)); + } + + ASSERT_TRUE(ShapeUtil::Compatible(src_literal.shape(), literal->shape())); + ASSERT_EQ(src_literal.data(), + literal->Relayout(src_literal.shape().layout()).data()); +} + TEST(StreamExecutorGpuClientTest, FromHostAsync) { TF_ASSERT_OK_AND_ASSIGN( auto client, GetStreamExecutorGpuClient(true, /*allocator_config=*/{}, diff --git a/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc index 2043ff62361fb2..11947167552f16 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc +++ b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc @@ -1088,7 +1088,21 @@ PjRtStreamExecutorClient::CreateViewOfDeviceBuffer( void* device_ptr, const Shape& shape, PjRtDevice* device, std::function on_delete_callback) { se::DeviceMemoryBase buffer(device_ptr, ShapeUtil::ByteSizeOf(shape)); - absl::Span> definition_events; + + TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, + tensorflow::down_cast(device) + ->GetLocalDeviceState()); + + absl::InlinedVector, 2> + definition_events; + definition_events.emplace_back( + std::make_shared(this->thread_pool())); + TF_ASSIGN_OR_RETURN(EventPool::Handle event, + local_device->event_pool().ThenAllocateAndRecordEvent( + local_device->compute_stream())); + definition_events.back()->SetSequencingEvent(std::move(event), + local_device->compute_stream()); + auto device_buffer = std::make_shared( /*allocator=*/nullptr, device->local_hardware_id(), std::initializer_list{buffer}, definition_events, @@ -1343,48 +1357,65 @@ PjRtFuture PjRtStreamExecutorBuffer::ToLiteral( AcquireHoldLocked(&device_buffer); } - WaitForBufferDefinitionEventsOnStream(*device_buffer, stream); - ShapedBuffer shaped_buffer = device_buffer->AsShapedBuffer(on_device_shape_); - StatusOr event_or = - local_device->event_pool().AllocateEvent(stream->parent()); - if (!event_or.ok()) { - return PjRtFuture(event_or.status()); - } - - GenericTransferManager::LiteralFromDeviceMetadata transfer_metadata; - // We never call device functions from the `done` callback. - transfer_metadata.callback_is_host_callback_safe = true; + auto promise = PjRtFuture::CreatePromise(); + auto usage_event = + std::make_shared(client_->thread_pool()); TransferManager* transfer_manager = client_->client()->backend().transfer_manager(); - TransferManager::TransferMetadata* transfer_metadata_ptr = - (dynamic_cast(transfer_manager) != nullptr) - ? &transfer_metadata - : nullptr; + auto tracked_device_buffer = device_buffer.buffer(); + + // When using the ComputeSynchronized allocation model, retain a + // reference to the device_buffer until the copy completes, to + // ensure that the buffer isn't deleted or donated while it is still + // in use. The choice of retaining a reference at the host is a + // heuristic; the alternative is to ensure, before freeing the + // buffer, that the compute stream is synchronized past the + // transfer, but it seems better to hold onto the buffer too long + // than to stall the compute stream, particularly since the + // overwhelmingly common use case of CopyToHostAsync will hold onto + // the reference long enough to read the buffer in a subsequent call + // to ToLiteral. + device_buffer.ConvertUsageHold(stream, usage_event, /*reference_held=*/true); + + auto async_to_literal = [usage_event, tracked_device_buffer, stream, + transfer_manager = std::move(transfer_manager), + on_device_shape{on_device_shape_}, literal, promise, + local_device]() mutable { + StatusOr event_or = + local_device->event_pool().AllocateEvent(stream->parent()); + if (!event_or.ok()) { + promise.Set(event_or.status()); + return; + } + WaitForBufferDefinitionEventsOnStream(*tracked_device_buffer, stream); + ShapedBuffer shaped_buffer = + tracked_device_buffer->AsShapedBuffer(on_device_shape); - auto promise = PjRtFuture::CreatePromise(); - transfer_manager->TransferLiteralFromDevice( - stream, shaped_buffer, literal, - [promise](Status status) mutable { promise.Set(status); }, - transfer_metadata_ptr); + GenericTransferManager::LiteralFromDeviceMetadata transfer_metadata; + // We never call device functions from the `done` callback. + transfer_metadata.callback_is_host_callback_safe = true; - auto usage_event = - std::make_shared(client_->thread_pool()); - local_device->event_pool().ThenRecordEvent(stream, event_or.value()); - usage_event->SetSequencingEvent(std::move(event_or).value(), stream); - // When using the ComputeSynchronized allocation model, retain a reference to - // the device_buffer until the copy completes, to ensure that the buffer isn't - // deleted or donated while it is still in use. The choice of retaining a - // reference at the host is a heuristic; the alternative is to ensure, before - // freeing the buffer, that the compute stream is synchronized past the - // transfer, but it seems better to hold onto the buffer too long than to - // stall the compute stream, particularly since the overwhelmingly common - // use case of CopyToHostAsync will hold onto the reference long enough to - // read the buffer in a subsequent call to ToLiteral. - RecordUsage(std::move(device_buffer), local_device, local_device, usage_event, - stream, - /*prefer_to_retain_reference=*/true); + TransferManager::TransferMetadata* transfer_metadata_ptr = + (dynamic_cast(transfer_manager) != nullptr) + ? &transfer_metadata + : nullptr; + + transfer_manager->TransferLiteralFromDevice( + stream, shaped_buffer, literal, + [promise](Status status) mutable { promise.Set(status); }, + transfer_metadata_ptr); + + local_device->event_pool().ThenRecordEvent(stream, event_or.value()); + usage_event->SetSequencingEvent(std::move(event_or).value(), stream); + + local_device->ThenRelease(stream, tracked_device_buffer); + }; + + tracked_device_buffer->definition_events()[0]->ExecuteOrAddToFutureTasks( + absl::StrFormat("async_to_literal_%p", literal), + std::move(async_to_literal)); return PjRtFuture( std::move(promise), From 70c449a6163acdde00168f35549ea9927e5228d6 Mon Sep 17 00:00:00 2001 From: Kuangyuan Chen Date: Tue, 11 Jul 2023 09:23:10 -0700 Subject: [PATCH 129/376] Fixing the tsan issue in stream_test PiperOrigin-RevId: 547210851 --- tensorflow/core/tfrt/runtime/stream_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/tfrt/runtime/stream_test.cc b/tensorflow/core/tfrt/runtime/stream_test.cc index 50fc07e79c1571..df377486627941 100644 --- a/tensorflow/core/tfrt/runtime/stream_test.cc +++ b/tensorflow/core/tfrt/runtime/stream_test.cc @@ -105,7 +105,7 @@ TEST(StreamTest, MultipleWriters) { {{"c", AsTensor({300})}}}; for (const auto& p : expected) { - tsl::Env::Default()->SchedClosure([&, p]() { + tsl::Env::Default()->SchedClosure([callback_id, step_id, p]() { // The stream callback may be dropped early, and in that case we ignore // the error. GetGlobalStreamCallbackRegistry() From 08eb66d4aecb2b0dd64c867204fe0c87492cc8cd Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 11 Jul 2023 09:29:04 -0700 Subject: [PATCH 130/376] Update TFRT dependency to use revision http://github.com/tensorflow/runtime/commit/62d593e29280c8ef8dee7a5477b04b89ac77c06c. PiperOrigin-RevId: 547212324 --- third_party/tf_runtime/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/tf_runtime/workspace.bzl b/third_party/tf_runtime/workspace.bzl index 1f02109fb9dcf5..49c3aa40d3faea 100644 --- a/third_party/tf_runtime/workspace.bzl +++ b/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "4ed29c1d398da53ee2cfa5191868e367a0e4052f" - TFRT_SHA256 = "7d0c7a60785da8161d63574ff395229a61a0fd204134d1f670c9ea027df6f73d" + TFRT_COMMIT = "62d593e29280c8ef8dee7a5477b04b89ac77c06c" + TFRT_SHA256 = "31e762e2cdfd4c956ba92f9f90fe7d5f0896cb8ec3d52111cc261f797b8aba65" tf_http_archive( name = "tf_runtime", From 93403820580764bb46c30aede8f9cb5b9b990545 Mon Sep 17 00:00:00 2001 From: Xinyi Wang Date: Tue, 11 Jul 2023 09:47:20 -0700 Subject: [PATCH 131/376] Disable failing test. PiperOrigin-RevId: 547217151 --- tensorflow/python/distribute/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD index e61e481e4d3814..05b88a5e37cb7b 100644 --- a/tensorflow/python/distribute/BUILD +++ b/tensorflow/python/distribute/BUILD @@ -2584,6 +2584,7 @@ distribute_py_strict_test( "multi_and_single_gpu", "no_oss", # TODO(b/249822228) "noasan", # TODO(b/237407459) + "nomsan", # TODO(b/290745680) "notpu", "notsan", # Tsan failure doesn't seem to be caused by TF. ], From b3c0aa15d4678812f866b2a344bb4d8d2807ac66 Mon Sep 17 00:00:00 2001 From: Yu Feng Date: Tue, 11 Jul 2023 09:58:05 -0700 Subject: [PATCH 132/376] Decouple Layout C++ from ShardingSpec protobuf. Less mental struggle as of when to use sharding spec, when to use string. PiperOrigin-RevId: 547220003 --- tensorflow/dtensor/cc/tensor_layout.cc | 128 +++++++----------- tensorflow/dtensor/cc/tensor_layout.h | 16 +-- tensorflow/dtensor/mlir/collectives.cc | 28 ++-- .../mlir/expansions/bias_add_spmd_expander.cc | 6 +- .../mlir/expansions/einsum_spmd_expander.cc | 95 ++++++------- .../expansions/elementwise_spmd_expander.cc | 2 +- .../expansions/expanddims_spmd_expander.cc | 9 +- .../mlir/expansions/gather_spmd_expander.cc | 49 +++---- .../mlir/expansions/in_top_k_spmd_expander.cc | 13 +- .../mlir/expansions/matmul_spmd_expander.cc | 16 +-- .../mlir/expansions/meta_spmd_expander.cc | 83 ++++++------ .../mlir/expansions/nullary_spmd_expander.cc | 3 +- .../expansions/random_op_spmd_expander.cc | 7 +- .../mlir/expansions/reduce_spmd_expander.cc | 8 +- .../mlir/expansions/scatter_spmd_expander.cc | 29 ++-- .../mlir/expansions/softmax_spmd_expander.cc | 101 +++++++------- .../mlir/expansions/split_spmd_expander.cc | 21 ++- .../mlir/expansions/squeeze_spmd_expander.cc | 21 ++- .../mlir/expansions/top_k_spmd_expander.cc | 11 +- .../mlir/expansions/where_spmd_expander.cc | 4 +- tensorflow/dtensor/mlir/ir/tf_dtensor.cc | 11 +- .../dtensor/mlir/utils/collective_lowering.cc | 21 ++- 22 files changed, 304 insertions(+), 378 deletions(-) diff --git a/tensorflow/dtensor/cc/tensor_layout.cc b/tensorflow/dtensor/cc/tensor_layout.cc index 4e3a69d7d83ff3..f38a70137eca30 100644 --- a/tensorflow/dtensor/cc/tensor_layout.cc +++ b/tensorflow/dtensor/cc/tensor_layout.cc @@ -800,27 +800,13 @@ Mesh Mesh::CreateMesh(const std::string& mesh_name, } StatusOr Layout::GetLayout( - const std::vector& sharding_spec_strs, const Mesh& mesh) { - // Re-format sharding specs. - std::vector sharding_specs; - sharding_specs.reserve(sharding_spec_strs.size()); - for (const std::string& spec_str : sharding_spec_strs) { - ShardingSpec spec; - spec.set_sharding_spec(spec_str); - sharding_specs.push_back(spec); - } - return GetLayout(sharding_specs, mesh); -} - -StatusOr Layout::GetLayout( - const std::vector& sharding_specs, const Mesh& mesh) { + const std::vector& sharding_specs, const Mesh& mesh) { Layout layout; // Append mesh, then check sharding_specs are legal. layout.mesh_ = mesh; // Check sharding_specs are either mesh dimension or special value. - for (const auto& dim : sharding_specs) { - const std::string& sharding_spec = dim.sharding_spec(); + for (const auto& sharding_spec : sharding_specs) { if (!(sharding_spec == kUnshardedDim || sharding_spec == kAny || sharding_spec == kMatch || mesh.IsMeshDim(sharding_spec) || sharding_spec == "scalar")) @@ -831,8 +817,7 @@ StatusOr Layout::GetLayout( } // Check same tensor dimensions not sharded over same mesh dimension twice. std::set dims_set; - for (const auto& dim : sharding_specs) { - const std::string& sharding_spec = dim.sharding_spec(); + for (const auto& sharding_spec : sharding_specs) { if (sharding_spec == kUnshardedDim || sharding_spec == kAny) continue; // If scalar, delete all sharding specs. if (sharding_spec == "scalar") { @@ -876,8 +861,7 @@ bool Layout::IsEmpty() const { return mesh_.IsEmpty(); } namespace { Mesh ReducedAbstractMesh(const Layout* layout) { - const std::vector& shard_spec_strs = - layout->sharding_spec_strs(); + const std::vector shard_spec_strs = layout->sharding_spec_strs(); std::vector reduced_mesh_dims; reduced_mesh_dims.reserve(layout->mesh().dims().size()); for (const MeshDimension& mesh_dim : layout->mesh().dims()) { @@ -934,12 +918,9 @@ Mesh Layout::ReducedMesh() const { namespace { Layout ReducedLayout(const Layout* layout) { - // Change format sharding specs. - std::vector shard_specs(layout->sharding_specs().size()); - for (size_t i = 0; i < shard_specs.size(); ++i) - shard_specs[i] = layout->dim(i); // Retrieve layout. - return Layout::GetLayout(shard_specs, layout->ReducedMesh()).value(); + return Layout::GetLayout(layout->sharding_spec_strs(), layout->ReducedMesh()) + .value(); } // Returns index of the given mesh dimension or mesh dim size if not found. @@ -952,16 +933,13 @@ StatusOr IndexOfMeshDimension(const Mesh& mesh, } // namespace ShardVector Layout::GetShardVector() const { - // Change format sharding specs. - std::vector shard_specs(sharding_specs().size()); - for (size_t i = 0; i < shard_specs.size(); ++i) shard_specs[i] = dim(i); // Obtain a shard position (i.e. sharded section of a tensor) from a mesh // location, using the sharding specs. auto GetShardFromDeviceLocation = [&](const DeviceLocation& loc) -> Shard { Shard shard; - for (size_t i = 0; i < shard_specs.size(); ++i) { + for (size_t i = 0; i < sharding_specs_.size(); ++i) { // If unsharded, there is only one shard, that is 1. - std::string spec = shard_specs[i].sharding_spec(); + std::string spec = sharding_specs_[i]; if (spec == Layout::kUnshardedDim) { shard.push_back(1); } else { @@ -974,11 +952,11 @@ ShardVector Layout::GetShardVector() const { }; // Obtain dims of shard vector. auto ShardVectorDims = [&]() -> std::vector { - std::vector num_shards_per_dim(shard_specs.size()); - for (size_t i = 0; i < sharding_specs().size(); ++i) { - ShardingSpec spec = sharding_specs()[i]; - if (Layout::IsShardedSpec(spec)) { - StatusOr dim_size = mesh().dim_size(spec.sharding_spec()); + std::vector num_shards_per_dim(sharding_specs_.size()); + for (size_t i = 0; i < sharding_specs_.size(); ++i) { + std::string spec = sharding_specs_[i]; + if (Layout::IsShardedDimension(spec)) { + StatusOr dim_size = mesh().dim_size(spec); num_shards_per_dim[i] = dim_size.value(); } else { num_shards_per_dim[i] = 1; @@ -1033,28 +1011,25 @@ std::map Layout::HostShardMap() const { } const std::string& Layout::sharding_spec(int idx) const { - return sharding_specs_[idx].sharding_spec(); + return sharding_specs_[idx]; } std::vector Layout::num_shards() const { std::vector num_shards; num_shards.reserve(sharding_specs_.size()); - for (const auto& sharding_spec : sharding_specs_) { - num_shards.push_back(num_shards_for_dim(sharding_spec)); + for (int64_t index = 0; index < sharding_specs_.size(); ++index) { + num_shards.push_back(num_shards_for_dim(index)); } return num_shards; } -size_t Layout::num_shards_for_dim(const ShardingSpec& dim) const { - absl::string_view name = dim.sharding_spec(); - if (name == Layout::kUnshardedDim) return 1; - if (name == Layout::kMatch) return -1; - - return mesh().dim_size(name).value(); -} size_t Layout::num_shards_for_dim(int dim) const { - return num_shards_for_dim(sharding_specs_[dim]); + const std::string spec = sharding_specs_[dim]; + if (spec == Layout::kUnshardedDim) return 1; + if (spec == Layout::kMatch) return -1; + + return mesh().dim_size(spec).value(); } bool Layout::IsFullyReplicated() const { @@ -1062,7 +1037,7 @@ bool Layout::IsFullyReplicated() const { return false; } for (const auto& sharding_spec : sharding_specs_) { - if (sharding_spec.sharding_spec() != Layout::kUnshardedDim) return false; + if (sharding_spec != Layout::kUnshardedDim) return false; } return true; } @@ -1070,7 +1045,7 @@ bool Layout::IsFullyReplicated() const { bool Layout::IsLastDimReplicated() const { return (mesh_.IsTile() && ((sharding_specs_.empty()) || - (sharding_specs_.back().sharding_spec() == Layout::kUnshardedDim))); + (sharding_specs_.back() == Layout::kUnshardedDim))); } bool Layout::IsBatchParallel() const { @@ -1082,12 +1057,12 @@ bool Layout::IsBatchParallel() const { } for (int i = 1; i < sharding_specs_.size(); ++i) { - const auto& dim = sharding_specs_[i]; - if (dim.sharding_spec() != Layout::kUnshardedDim) { + const auto& spec = sharding_specs_[i]; + if (spec != Layout::kUnshardedDim) { return false; } } - return sharding_specs_[0].sharding_spec() != Layout::kUnshardedDim; + return sharding_specs_[0] != Layout::kUnshardedDim; } // TODO(samuelslee) Replace this with the IsBatchParallel() everywhere @@ -1097,7 +1072,7 @@ bool Layout::IsBatchParallel(int non_batch_rank) const { } if (sharding_specs_.empty()) return true; for (int i = rank() - non_batch_rank; i < rank(); ++i) { - if (num_shards_for_dim(sharding_specs_[i]) != 1) return false; + if (num_shards_for_dim(i) != 1) return false; } return true; } @@ -1105,8 +1080,8 @@ bool Layout::IsBatchParallel(int non_batch_rank) const { StatusOr Layout::ToProto() const { LayoutProto proto; TF_ASSIGN_OR_RETURN(*proto.mutable_mesh_config(), mesh_.ToProto()); - for (const auto& dim : sharding_specs_) { - *proto.add_sharding_specs() = dim; + for (const auto& spec : sharding_specs_) { + proto.add_sharding_specs()->set_sharding_spec(spec); } return proto; } @@ -1115,10 +1090,8 @@ bool Layout::IsEquivalent(const Layout& b) const { if (this->rank() != b.rank()) return false; if (this->mesh() != b.mesh()) return false; for (int i = 0; i < this->rank(); ++i) { - if (this->sharding_specs_[i].sharding_spec() != - b.sharding_specs_[i].sharding_spec()) { - if ((this->num_shards_for_dim(this->sharding_specs_[i]) != 1) || - (b.num_shards_for_dim(b.sharding_specs_[i]) != 1)) + if (this->sharding_specs_[i] != b.sharding_specs_[i]) { + if ((this->num_shards_for_dim(i) != 1) || (b.num_shards_for_dim(i) != 1)) return false; } } @@ -1142,7 +1115,7 @@ std::vector Layout::GlobalShapeFromLocalShape( } std::vector stride_for_dim; - stride_for_dim.resize(sharding_specs().size()); + stride_for_dim.resize(sharding_specs_.size()); size_t stride = mesh().num_local_devices(); for (int i = 0; i < stride_for_dim.size(); i++) { stride = stride / num_shards_for_dim(i); @@ -1168,8 +1141,8 @@ std::vector Layout::GlobalShapeFromLocalShape( }; std::vector global_shape; - global_shape.reserve(sharding_specs().size()); - for (int i = 0; i < sharding_specs().size(); ++i) { + global_shape.reserve(sharding_specs_.size()); + for (int i = 0; i < sharding_specs_.size(); ++i) { global_shape.push_back(dimension_size(i)); } return global_shape; @@ -1182,7 +1155,7 @@ std::vector Layout::LocalShapeFromGlobalShape( } std::vector shards = num_shards(); std::vector local_shape; - for (int i = 0; i < sharding_specs().size(); ++i) { + for (int i = 0; i < sharding_specs_.size(); ++i) { int64_t dim_shards = shards[i]; // TODO(hthu): Shape might not be always divisible. int64_t local_size = IsDynamicSize(global_shape[i]) @@ -1200,7 +1173,7 @@ PartialTensorShape Layout::LocalShapeFromGlobalShape( } std::vector shards = num_shards(); PartialTensorShape local_shape({}); - for (int spec_index = 0; spec_index < sharding_specs().size(); ++spec_index) { + for (int spec_index = 0; spec_index < sharding_specs_.size(); ++spec_index) { int64_t dim_size = global_shape.dim_size(spec_index); int64_t local_size = IsDynamicSize(dim_size) ? dim_size : dim_size / shards[spec_index]; @@ -1213,7 +1186,7 @@ StatusOr Layout::FromProto(const LayoutProto& proto) { Layout layout; if (proto.mesh_config().single_device().empty()) { for (const auto& spec : proto.sharding_specs()) - layout.sharding_specs_.push_back(spec); + layout.sharding_specs_.push_back(spec.sharding_spec()); TF_ASSIGN_OR_RETURN(auto mesh, Mesh::ParseFromProto(proto.mesh_config())); layout.mesh_ = std::move(mesh); @@ -1299,10 +1272,7 @@ StatusOr Layout::FromString(absl::string_view layout_str) { } std::vector Layout::sharding_spec_strs() const { - std::vector sharding_spec_strs(sharding_specs().size()); - for (size_t i = 0; i < sharding_specs().size(); ++i) - sharding_spec_strs[i] = sharding_spec(i); - return sharding_spec_strs; + return sharding_specs_; } std::string Layout::ToString() const { @@ -1315,8 +1285,7 @@ std::string Layout::ToString() const { std::string layout_str = "sharding_specs:"; // Print sharding specs. - for (const ShardingSpec& dim : sharding_specs_) { - std::string dim_name = dim.sharding_spec(); + for (const auto& dim_name : sharding_specs_) { absl::StrAppend(&layout_str, dim_name + ","); } // Append mesh. @@ -1326,19 +1295,16 @@ std::string Layout::ToString() const { StatusOr Layout::GetLayoutWithReducedDims( const absl::flat_hash_set& reduced_dims, bool keep_dims) const { - dtensor::LayoutProto output_layout; - TF_ASSIGN_OR_RETURN(*output_layout.mutable_mesh_config(), mesh().ToProto()); - + std::vector sharding_specs; for (int i = 0; i < rank(); ++i) { // reduced_dims may contain negative values. if (!reduced_dims.contains(i) && !reduced_dims.contains(i - rank())) { - *output_layout.add_sharding_specs() = dim(i); + sharding_specs.push_back(sharding_spec(i)); } else if (keep_dims) { - auto* replicated_dim = output_layout.add_sharding_specs(); - replicated_dim->set_sharding_spec(kUnshardedDim); + sharding_specs.push_back(kUnshardedDim); } } - return Layout::FromProto(output_layout).value(); + return Layout::GetLayout(sharding_specs, mesh()); } Layout Layout::Truncate(int64 split_point, bool end) const { @@ -1362,9 +1328,7 @@ Layout Layout::LeftPad(int64_t rank) const { Layout output_layout(*this); auto& specs = output_layout.sharding_specs_; - ShardingSpec spec; - spec.set_sharding_spec(Layout::kUnshardedDim); - specs.insert(specs.begin(), rank - this->rank(), spec); + specs.insert(specs.begin(), rank - this->rank(), Layout::kUnshardedDim); return output_layout; } @@ -1408,7 +1372,7 @@ StatusOr GetMostShardedLayout(const std::vector& layouts) { absl::flat_hash_map> layout_map; for (const Layout& layout : layouts) { for (int i = 0; i < layout.rank(); ++i) { - const std::string& mesh_dim = layout.dim(i).sharding_spec(); + const std::string& mesh_dim = layout.sharding_spec(i); if (mesh_dim == Layout::kUnshardedDim) continue; layout_map[mesh_dim].insert(i); @@ -1461,7 +1425,7 @@ StatusOr GetLeastShardedLayout(const std::vector& layouts) { } specs.resize(rank, Layout::kAny); for (const auto& layout : layouts) { - auto current_specs = layout.sharding_spec_strs(); + const auto current_specs = layout.sharding_spec_strs(); for (int i = 0; i < rank; i++) { auto current_spec = current_specs[i]; if (specs[i] == Layout::kAny) { diff --git a/tensorflow/dtensor/cc/tensor_layout.h b/tensorflow/dtensor/cc/tensor_layout.h index 2cd59d123b9cab..4bee6860646539 100644 --- a/tensorflow/dtensor/cc/tensor_layout.h +++ b/tensorflow/dtensor/cc/tensor_layout.h @@ -340,16 +340,8 @@ class Layout { static bool IsShardedDimension(const absl::string_view name) { return !IsUnshardedDimension(name); } - static bool IsUnshardedSpec(const ShardingSpec& spec) { - return IsUnshardedDimension(spec.sharding_spec()); - } - static bool IsShardedSpec(const ShardingSpec& spec) { - return !IsUnshardedDimension(spec.sharding_spec()); - } static StatusOr GetLayout( const std::vector& sharding_spec_strs, const Mesh& mesh); - static StatusOr GetLayout( - const std::vector& sharding_specs, const Mesh& mesh); static StatusOr GetSingleDeviceLayout(const Mesh& mesh); // Makes a new layout from this one dropping the given dimensions. @@ -391,15 +383,9 @@ class Layout { const PartialTensorShape& global_shape) const; int64 rank() const { return sharding_specs_.size(); } - size_t num_shards_for_dim(const ShardingSpec& dim) const; size_t num_shards_for_dim(int) const; std::vector num_shards() const; - const ShardingSpec& dim(int64 idx) const { return sharding_specs_[idx]; } - absl::Span sharding_specs() const { - return sharding_specs_; - } - // Computes the corresponding shard vector to this layout. ShardVector GetShardVector() const; @@ -426,7 +412,7 @@ class Layout { } private: - std::vector sharding_specs_; + std::vector sharding_specs_; Mesh mesh_; }; diff --git a/tensorflow/dtensor/mlir/collectives.cc b/tensorflow/dtensor/mlir/collectives.cc index c54eb2b725366d..95c91dc086b0a3 100644 --- a/tensorflow/dtensor/mlir/collectives.cc +++ b/tensorflow/dtensor/mlir/collectives.cc @@ -165,22 +165,22 @@ bool CanUseAllToAll(const dtensor::Layout& src_layout, // all-to-all in addition to these which can be supported later. int num_split_dims = 0; int num_concat_dims = 0; - ShardingSpec split_spec; - ShardingSpec concat_spec; + std::string split_spec; + std::string concat_spec; for (int i = 0; i < src_layout.rank(); ++i) { if (src_layout.sharding_spec(i) == tgt_layout.sharding_spec(i)) continue; if (Layout::IsUnshardedDimension(src_layout.sharding_spec(i)) && Layout::IsShardedDimension(tgt_layout.sharding_spec(i))) { num_split_dims++; - split_spec = tgt_layout.dim(i); + split_spec = tgt_layout.sharding_spec(i); } else if (Layout::IsShardedDimension(src_layout.sharding_spec(i)) && Layout::IsUnshardedDimension(tgt_layout.sharding_spec(i))) { num_concat_dims++; - concat_spec = src_layout.dim(i); + concat_spec = src_layout.sharding_spec(i); } } return num_split_dims == 1 && num_concat_dims == 1 && - split_spec.sharding_spec() == concat_spec.sharding_spec(); + split_spec == concat_spec; } StatusOr EmitAllToAll( @@ -339,14 +339,14 @@ StatusOr EmitRelayout( for (int i = 0; i < src_layout.rank(); ++i) src_sharding_dims.emplace(src_layout.sharding_spec(i)); - std::vector intermediate_specs_1(src_layout.rank()); + std::vector intermediate_specs_1(src_layout.rank()); for (int i = 0; i < src_layout.rank(); ++i) { - if (Layout::IsShardedSpec(tgt_layout.dim(i)) && - !Layout::IsShardedSpec(src_layout.dim(i)) && + if (Layout::IsShardedDimension(tgt_layout.sharding_spec(i)) && + !Layout::IsShardedDimension(src_layout.sharding_spec(i)) && !src_sharding_dims.contains(tgt_layout.sharding_spec(i))) - intermediate_specs_1[i] = tgt_layout.dim(i); + intermediate_specs_1[i] = tgt_layout.sharding_spec(i); else - intermediate_specs_1[i] = src_layout.dim(i); + intermediate_specs_1[i] = src_layout.sharding_spec(i); } TF_ASSIGN_OR_RETURN( Layout intermediate_layout_1, @@ -357,11 +357,11 @@ StatusOr EmitRelayout( EmitAllScatter(builder, input, src_layout, intermediate_layout_1, newly_created_ops)); - std::vector intermediate_specs_2(src_layout.rank()); + std::vector intermediate_specs_2(src_layout.rank()); for (int i = 0; i < src_layout.rank(); ++i) { - if (Layout::IsShardedSpec(intermediate_specs_1[i]) && - intermediate_specs_1[i].sharding_spec() != tgt_layout.sharding_spec(i)) - intermediate_specs_2[i].set_sharding_spec(Layout::kUnshardedDim); + if (Layout::IsShardedDimension(intermediate_specs_1[i]) && + intermediate_specs_1[i] != tgt_layout.sharding_spec(i)) + intermediate_specs_2[i] = Layout::kUnshardedDim; else intermediate_specs_2[i] = intermediate_specs_1[i]; } diff --git a/tensorflow/dtensor/mlir/expansions/bias_add_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/bias_add_spmd_expander.cc index fa9ce8c6435c69..4b97bd0786641a 100644 --- a/tensorflow/dtensor/mlir/expansions/bias_add_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/bias_add_spmd_expander.cc @@ -69,10 +69,8 @@ StatusOr BiasAddExpander::ExpandOp(mlir::Operation* op) { // Check if output is sharded more, change input layout to match output // layout. - int64_t num_input_shards = - input_layout.num_shards_for_dim(input_layout.dim(c_dim_idx)); - int64_t num_output_shards = - output_layout.num_shards_for_dim(output_layout.dim(c_dim_idx)); + int64_t num_input_shards = input_layout.num_shards_for_dim(c_dim_idx); + int64_t num_output_shards = output_layout.num_shards_for_dim(c_dim_idx); if (num_input_shards < num_output_shards) { mlir::Value output; diff --git a/tensorflow/dtensor/mlir/expansions/einsum_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/einsum_spmd_expander.cc index 6b439730931bfe..eff66fa0354784 100644 --- a/tensorflow/dtensor/mlir/expansions/einsum_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/einsum_spmd_expander.cc @@ -30,13 +30,6 @@ limitations under the License. namespace tensorflow { namespace dtensor { -namespace { - -bool Equal(const ShardingSpec& a, const ShardingSpec& b) { - return a.sharding_spec() == b.sharding_spec(); -} - -} // namespace // Einsum, like reductions, is implemented as a local operation followed by // an all-reduce over dimensions that have been reduced. @@ -150,10 +143,10 @@ Status ExtractEquationRelations( // sharding specs we raise an error if replicate_incompatible_dimensions is // false. Otherwise we treat the dimension as if it were unsharded. // Labels with unsharded dimensions are not recorded in the output. -StatusOr> GetLabelToShardingSpec( +StatusOr> GetLabelToShardingSpec( bool replicate_incompatible_dimensions, const std::vector& layouts, const std::vector>>& mappings) { - absl::flat_hash_map label_to_sharding_spec; + absl::flat_hash_map label_to_sharding_spec; absl::flat_hash_set incompatible_labels; // For each mapping, identify the mesh dimension and whether it has been @@ -171,23 +164,23 @@ StatusOr> GetLabelToShardingSpec( layouts[index].rank()) .str()); - const ShardingSpec& sharding_spec = layouts[index].dim(offset); + const std::string& sharding_spec = layouts[index].sharding_spec(offset); if (label_to_sharding_spec.contains(mapping.first)) { - if (Layout::IsShardedSpec(sharding_spec) && - !Equal(label_to_sharding_spec[mapping.first], sharding_spec)) { + if (Layout::IsShardedDimension(sharding_spec) && + label_to_sharding_spec[mapping.first] != sharding_spec) { if (!replicate_incompatible_dimensions) return errors::InvalidArgument( llvm::formatv( "incompatible mesh dimensions in equation, label '{0}' " "is mapped to mesh dimension '{1}' and '{2}'", - mapping.first, sharding_spec.sharding_spec(), - label_to_sharding_spec[mapping.first].sharding_spec()) + mapping.first, sharding_spec, + label_to_sharding_spec[mapping.first]) .str()); else incompatible_labels.insert(mapping.first); } - } else if (Layout::IsShardedSpec(sharding_spec)) { + } else if (Layout::IsShardedDimension(sharding_spec)) { label_to_sharding_spec[mapping.first] = sharding_spec; } } @@ -205,42 +198,41 @@ StatusOr> GetLabelToShardingSpec( // multiple times. E.g. ab,bc->ac (i.e. matmul) with a and c sharded over the // same dim. In this case we mark all such dimensions as replicated. StatusOr VerifyOrFixLayout( - std::pair, absl::flat_hash_map> + std::pair, absl::flat_hash_map> pair, const Mesh& mesh) { - std::vector sharding_specs = pair.first; + std::vector sharding_specs = pair.first; absl::flat_hash_map dimension_use_count = pair.second; for (int i = 0; i < sharding_specs.size(); ++i) - if (Layout::IsShardedSpec(sharding_specs[i]) && - dimension_use_count[sharding_specs[i].sharding_spec()] > 1) - sharding_specs[i].set_sharding_spec(Layout::kUnshardedDim); + if (Layout::IsShardedDimension(sharding_specs[i]) && + dimension_use_count[sharding_specs[i]] > 1) + sharding_specs[i] = Layout::kUnshardedDim; return Layout::GetLayout(sharding_specs, mesh); } // Construct a layout on a given mesh from the label to tensor dimension map // and the label to mesh_dimension map. -std::pair, absl::flat_hash_map> +std::pair, absl::flat_hash_map> GetSpecsFromLabelsAndMap( const absl::flat_hash_map>& label_to_index, - const absl::flat_hash_map& label_to_sharding_spec) { + const absl::flat_hash_map& label_to_sharding_spec) { int layout_rank = 0; for (const auto& label_and_indices : label_to_index) layout_rank += label_and_indices.second.size(); - std::vector sharding_specs(layout_rank); + std::vector sharding_specs(layout_rank); absl::flat_hash_map dimension_use_count; absl::flat_hash_set dimension_use_set; for (const auto& label_and_indices : label_to_index) { const auto& loc = label_to_sharding_spec.find(label_and_indices.first); if (loc != label_to_sharding_spec.end()) { - const ShardingSpec& sharding_spec = loc->second; + const std::string& sharding_spec = loc->second; for (int index : label_and_indices.second) sharding_specs[index] = sharding_spec; - dimension_use_count[sharding_spec.sharding_spec()] += - label_and_indices.second.size(); + dimension_use_count[sharding_spec] += label_and_indices.second.size(); } else { for (int index : label_and_indices.second) - sharding_specs[index].set_sharding_spec(Layout::kUnshardedDim); + sharding_specs[index] = Layout::kUnshardedDim; } } return std::make_pair(sharding_specs, dimension_use_count); @@ -353,11 +345,11 @@ StatusOr> EinsumSPMDExpander::ComputeLayoutBackward( for (size_t i = 0; i < num_inputs; ++i) { absl::flat_hash_map> labels_to_indices = input_mappings[i]; - std::pair, absl::flat_hash_map> + std::pair, absl::flat_hash_map> sharding_specs_and_dim_count = GetSpecsFromLabelsAndMap( labels_to_indices, output_label_to_sharding_spec); - std::vector sharding_specs = + std::vector sharding_specs = sharding_specs_and_dim_count.first; absl::flat_hash_map dim_count = sharding_specs_and_dim_count.second; @@ -367,7 +359,7 @@ StatusOr> EinsumSPMDExpander::ComputeLayoutBackward( char label = label_to_indices.first; if (labels_for_any.contains(label)) { int index = label_to_indices.second[0]; - sharding_specs[index].set_sharding_spec(Layout::kAny); + sharding_specs[index] = Layout::kAny; } } TF_ASSIGN_OR_RETURN( @@ -422,13 +414,12 @@ Status EinsumSPMDExpander::MaybeRelayoutInputs( for (const char label : all_labels) { if (input_label_to_sharding_spec.contains(label) && output_label_to_sharding_spec.contains(label) && - !Equal(input_label_to_sharding_spec[label], - output_label_to_sharding_spec.find(label)->second)) + input_label_to_sharding_spec[label] != + output_label_to_sharding_spec.find(label)->second) return errors::InvalidArgument( "for label ", label, " input and output layouts are sharded on ", - " non-equal dimensions ", - input_label_to_sharding_spec[label].sharding_spec(), " and ", - output_label_to_sharding_spec.find(label)->second.sharding_spec(), + " non-equal dimensions ", input_label_to_sharding_spec[label], + " and ", output_label_to_sharding_spec.find(label)->second, "respectively"); } @@ -438,23 +429,21 @@ Status EinsumSPMDExpander::MaybeRelayoutInputs( for (const auto& input_mapping : input_mappings) for (const auto& char_and_positions : input_mapping) if (char_and_positions.second.size() > 1) - input_label_to_sharding_spec[char_and_positions.first] - .set_sharding_spec(Layout::kUnshardedDim); + input_label_to_sharding_spec[char_and_positions.first] = + Layout::kUnshardedDim; absl::flat_hash_map> sharding_dim_to_non_contracting_labels; absl::flat_hash_map> sharding_dim_to_contracting_labels; for (const auto& label_and_spec : input_label_to_sharding_spec) { - if (Layout::IsShardedSpec(label_and_spec.second)) { + if (Layout::IsShardedDimension(label_and_spec.second)) { if (contracting_labels.contains(label_and_spec.first)) - sharding_dim_to_contracting_labels[label_and_spec.second - .sharding_spec()] - .insert(label_and_spec.first); + sharding_dim_to_contracting_labels[label_and_spec.second].insert( + label_and_spec.first); else - sharding_dim_to_non_contracting_labels[label_and_spec.second - .sharding_spec()] - .insert(label_and_spec.first); + sharding_dim_to_non_contracting_labels[label_and_spec.second].insert( + label_and_spec.first); } } @@ -469,12 +458,11 @@ Status EinsumSPMDExpander::MaybeRelayoutInputs( if (!contracting_labels.contains(label) && output_label_to_sharding_spec.contains(label) && !input_label_to_sharding_spec.contains(label)) { - const ShardingSpec& sharding_spec = + const std::string& string_spec = output_label_to_sharding_spec.find(label)->second; - const std::string& string_spec = sharding_spec.sharding_spec(); if (!sharding_dim_to_non_contracting_labels.contains(string_spec) && !sharding_dim_to_contracting_labels.contains(string_spec)) { - input_label_to_sharding_spec[label] = sharding_spec; + input_label_to_sharding_spec[label] = string_spec; sharding_dim_to_non_contracting_labels[string_spec].insert(label); } } @@ -503,8 +491,7 @@ Status EinsumSPMDExpander::MaybeRelayoutInputs( // keep this stable with respect to ordering. for (const char label : sharding_dim_to_non_contracting_labels[dim]) { if (output_label_to_sharding_spec.contains(label) && - output_label_to_sharding_spec.find(label)->second.sharding_spec() == - dim) { + output_label_to_sharding_spec.find(label)->second == dim) { label_to_keep = label; break; } else if (label < label_to_keep) { @@ -513,8 +500,7 @@ Status EinsumSPMDExpander::MaybeRelayoutInputs( } for (const char label : sharding_dim_to_non_contracting_labels[dim]) if (label != label_to_keep) - input_label_to_sharding_spec[label].set_sharding_spec( - Layout::kUnshardedDim); + input_label_to_sharding_spec[label] = Layout::kUnshardedDim; sharding_dim_to_non_contracting_labels[dim].clear(); sharding_dim_to_non_contracting_labels[dim].insert(label_to_keep); } @@ -530,8 +516,8 @@ Status EinsumSPMDExpander::MaybeRelayoutInputs( assert(spec_and_labels.second.size() == 1); assert(sharding_dim_to_non_contracting_labels[spec_and_labels.first] .size() == 1); - input_label_to_sharding_spec[*spec_and_labels.second.begin()] - .set_sharding_spec(Layout::kUnshardedDim); + input_label_to_sharding_spec[*spec_and_labels.second.begin()] = + Layout::kUnshardedDim; } } @@ -557,8 +543,7 @@ Status EinsumSPMDExpander::MaybeRelayoutInputs( output_layout.mesh())); for (const auto& contracting : contracting_labels) - reduce_dims.emplace( - input_label_to_sharding_spec[contracting].sharding_spec()); + reduce_dims.emplace(input_label_to_sharding_spec[contracting]); return OkStatus(); } diff --git a/tensorflow/dtensor/mlir/expansions/elementwise_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/elementwise_spmd_expander.cc index e8e24c9b5a2efb..216025d62f5ee5 100644 --- a/tensorflow/dtensor/mlir/expansions/elementwise_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/elementwise_spmd_expander.cc @@ -164,7 +164,7 @@ ElementwiseSPMDExpander::ComputeLayoutBackward( TF_ASSIGN_OR_RETURN(auto operand_shape, GetShape(operand)); Layout output_layout_truncated = output_layout.Truncate( - output_layout.sharding_specs().size() - operand_shape.size(), + output_layout.sharding_spec_strs().size() - operand_shape.size(), /*end=*/true); auto inferred_operand_layout_strs = output_layout_truncated.sharding_spec_strs(); diff --git a/tensorflow/dtensor/mlir/expansions/expanddims_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/expanddims_spmd_expander.cc index d2c58de60a3423..1151d46cb9159b 100644 --- a/tensorflow/dtensor/mlir/expansions/expanddims_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/expanddims_spmd_expander.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include "absl/types/optional.h" #include "tensorflow/dtensor/mlir/collectives.h" @@ -47,14 +48,14 @@ StatusOr ExpandDimsExpander::ExpandOp(mlir::Operation* op) { ExtractConstIntFromValue(expand_dims_op.getDim())); if (dim < 0) dim += global_output_shape.size(); - std::vector sharding_specs(global_output_shape.size()); + std::vector sharding_specs(global_output_shape.size()); for (int i = 0; i < global_output_shape.size(); ++i) { if (i < dim) - sharding_specs[i] = operand_layout->dim(i); + sharding_specs[i] = operand_layout->sharding_spec(i); else if (i == dim) - sharding_specs[i].set_sharding_spec(Layout::kUnshardedDim); + sharding_specs[i] = Layout::kUnshardedDim; else - sharding_specs[i] = operand_layout->dim(i - 1); + sharding_specs[i] = operand_layout->sharding_spec(i - 1); } TF_ASSIGN_OR_RETURN(const Layout current_output_layout, Layout::GetLayout(sharding_specs, output_layout->mesh())); diff --git a/tensorflow/dtensor/mlir/expansions/gather_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/gather_spmd_expander.cc index f6bd92b4ea9770..fd42942f7e4130 100644 --- a/tensorflow/dtensor/mlir/expansions/gather_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/gather_spmd_expander.cc @@ -209,26 +209,28 @@ StatusOr GatherNdGetOutputLayoutFromInput( // replace them with replicated. // If sharding dimension is used by both params and indices, the params // layout will be respected as generally params is larger than indices. - std::vector output_specs(params_rank - index_dimensions + - indices_rank - 1); + std::vector output_specs(params_rank - index_dimensions + + indices_rank - 1); absl::flat_hash_set used_dimensions; const int params_offset = -index_dimensions + indices_rank - 1; for (int i = index_dimensions; i < params_rank; ++i) { - if (params_layout && Layout::IsShardedSpec(params_layout->dim(i))) { - const ShardingSpec& params_spec = params_layout->dim(i); + if (params_layout && + Layout::IsShardedDimension(params_layout->sharding_spec(i))) { + const auto& params_spec = params_layout->sharding_spec(i); output_specs[i + params_offset] = params_spec; - used_dimensions.emplace(params_spec.sharding_spec()); + used_dimensions.emplace(params_spec); } else { - output_specs[i + params_offset].set_sharding_spec(Layout::kUnshardedDim); + output_specs[i + params_offset] = Layout::kUnshardedDim; } } for (int i = 0; i < indices_rank - 1; ++i) { - if (indices_layout && Layout::IsShardedSpec(indices_layout->dim(i)) && + if (indices_layout && + Layout::IsShardedDimension(indices_layout->sharding_spec(i)) && !used_dimensions.contains(indices_layout->sharding_spec(i))) - output_specs[i] = indices_layout->dim(i); + output_specs[i] = indices_layout->sharding_spec(i); else - output_specs[i].set_sharding_spec(Layout::kUnshardedDim); + output_specs[i] = Layout::kUnshardedDim; } return Layout::GetLayout(output_specs, mesh); } @@ -242,20 +244,20 @@ Status GatherNdGetInputLayoutFromOutput(const Layout& output_layout, // indices_layout (with the last dimensions replicated) and the remaining // dimensions to params_layout (with the first index_dimensions dimensions // replicated). - std::vector params_specs(params_rank); - std::vector indices_specs(indices_rank); + std::vector params_specs(params_rank); + std::vector indices_specs(indices_rank); for (int i = 0; i < index_dimensions; ++i) - params_specs[i].set_sharding_spec(Layout::kUnshardedDim); + params_specs[i] = Layout::kUnshardedDim; const int params_offset = -index_dimensions + indices_rank - 1; for (int i = index_dimensions; i < params_rank; ++i) - params_specs[i] = output_layout.dim(i + params_offset); + params_specs[i] = output_layout.sharding_spec(i + params_offset); for (int i = 0; i < indices_rank - 1; ++i) - indices_specs[i] = output_layout.dim(i); + indices_specs[i] = output_layout.sharding_spec(i); - indices_specs[indices_rank - 1].set_sharding_spec(Layout::kUnshardedDim); + indices_specs[indices_rank - 1] = Layout::kUnshardedDim; TF_ASSIGN_OR_RETURN(*params_layout, Layout::GetLayout(params_specs, mesh)); TF_ASSIGN_OR_RETURN(*indices_layout, Layout::GetLayout(indices_specs, mesh)); @@ -309,21 +311,20 @@ StatusOr GatherNdSPMDExpander::ExpandOp(mlir::Operation* op) { // Step 2) llvm::DenseSet used_dimensions; - for (const ShardingSpec& spec : pre_output_layout.sharding_specs()) - if (Layout::IsShardedSpec(spec)) - used_dimensions.insert(spec.sharding_spec()); + for (const auto& spec : pre_output_layout.sharding_spec_strs()) + if (Layout::IsShardedDimension(spec)) used_dimensions.insert(spec); - std::vector sharding_specs(output_layout.rank()); + std::vector sharding_specs(output_layout.rank()); for (int i = 0; i < sharding_specs.size(); ++i) { - if (Layout::IsShardedSpec(pre_output_layout.dim(i))) - sharding_specs[i] = pre_output_layout.dim(i); + if (Layout::IsShardedDimension(pre_output_layout.sharding_spec(i))) + sharding_specs[i] = pre_output_layout.sharding_spec(i); // Merge in sharded dimensions from the output which aren't already used // by the pre_output_layout. - else if (Layout::IsShardedSpec(output_layout.dim(i)) && + else if (Layout::IsShardedDimension(output_layout.sharding_spec(i)) && !used_dimensions.contains(output_layout.sharding_spec(i))) - sharding_specs[i] = output_layout.dim(i); + sharding_specs[i] = output_layout.sharding_spec(i); else - sharding_specs[i].set_sharding_spec(Layout::kUnshardedDim); + sharding_specs[i] = Layout::kUnshardedDim; } // Step 3) diff --git a/tensorflow/dtensor/mlir/expansions/in_top_k_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/in_top_k_spmd_expander.cc index 8771b956097519..caae0ade1822ae 100644 --- a/tensorflow/dtensor/mlir/expansions/in_top_k_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/in_top_k_spmd_expander.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/dtensor/mlir/expansions/in_top_k_spmd_expander.h" #include +#include #include "absl/types/optional.h" #include "mlir/IR/IRMapping.h" // from @llvm-project @@ -35,9 +36,9 @@ namespace { // layout, ensuring that the 2nd dimension is replicated. StatusOr GetSuggestedPredictionsLayout(const Layout& layout) { // predictions is a rank-2 tensor (batch_size x num_classes) - std::vector layout_specs(2); - layout_specs[0].set_sharding_spec(layout.sharding_spec(0)); - layout_specs[1].set_sharding_spec(Layout::kUnshardedDim); + std::vector layout_specs(2); + layout_specs[0] = layout.sharding_spec(0); + layout_specs[1] = Layout::kUnshardedDim; return Layout::GetLayout(layout_specs, layout.mesh()); } @@ -46,10 +47,10 @@ StatusOr GetSuggestedPredictionsLayout(const Layout& layout) { // of "other_layout". StatusOr MatchBatchDim(const Layout& layout, const Layout& other_layout) { - std::vector layout_specs(layout.rank()); - layout_specs[0].set_sharding_spec(other_layout.sharding_spec(0)); + std::vector layout_specs(layout.rank()); + layout_specs[0] = other_layout.sharding_spec(0); for (int i = 1; i < layout.rank(); ++i) { - layout_specs[i].set_sharding_spec(layout.sharding_spec(i)); + layout_specs[i] = layout.sharding_spec(i); } return Layout::GetLayout(layout_specs, layout.mesh()); diff --git a/tensorflow/dtensor/mlir/expansions/matmul_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/matmul_spmd_expander.cc index 6cd8ac7f174599..af109f98a15a07 100644 --- a/tensorflow/dtensor/mlir/expansions/matmul_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/matmul_spmd_expander.cc @@ -156,21 +156,19 @@ StatusOr MatMulSPMDExpander::OutputLayoutAndReducedDims( // Input layouts are [batch...],a,b;[batch...],b,c // Output layout is [batch...],a,c - const auto& batch_sharding_specs = batch_layout.sharding_specs(); - std::vector output_dims(batch_sharding_specs.begin(), - batch_sharding_specs.end()); + const auto& batch_sharding_specs = batch_layout.sharding_spec_strs(); + std::vector output_dims(batch_sharding_specs.begin(), + batch_sharding_specs.end()); if (Layout::IsShardedDimension(left_layout.sharding_spec(0)) && left_layout.sharding_spec(0) == right_layout.sharding_spec(1)) { // If a and c above are the same and sharded, we should output a replicated // layout during propagation. This is so we don't create an illegal layout. output_dims.resize(output_dims.size() + 2); - output_dims[output_dims.size() - 2].set_sharding_spec( - Layout::kUnshardedDim); - output_dims[output_dims.size() - 1].set_sharding_spec( - Layout::kUnshardedDim); + output_dims[output_dims.size() - 2] = Layout::kUnshardedDim; + output_dims[output_dims.size() - 1] = Layout::kUnshardedDim; } else { - output_dims.emplace_back(left_layout.dim(0)); - output_dims.emplace_back(right_layout.dim(1)); + output_dims.emplace_back(left_layout.sharding_spec(0)); + output_dims.emplace_back(right_layout.sharding_spec(1)); } return Layout::GetLayout(output_dims, left_layout.mesh()); diff --git a/tensorflow/dtensor/mlir/expansions/meta_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/meta_spmd_expander.cc index d6896a5540d832..a100300d5320a3 100644 --- a/tensorflow/dtensor/mlir/expansions/meta_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/meta_spmd_expander.cc @@ -197,20 +197,21 @@ StatusOr UnpackSPMDExpander::ExpandOp(mlir::Operation* op) { TF_ASSIGN_OR_RETURN( int axis, CanonicalizeAxis(unpack.getAxis(), /*packed_rank=*/input_rank)); - if (input_layout->num_shards_for_dim(input_layout->dim(axis)) != 1) { + if (input_layout->num_shards_for_dim(axis) != 1) { // If the axis being unpacked is sharded, relayout to replicated along that // axis since each device needs to split across it. - std::vector new_layout_specs(input_rank); + std::vector new_layout_specs(input_rank); for (int input_index = 0; input_index < input_rank; ++input_index) { if (input_index == axis) { - new_layout_specs[input_index].set_sharding_spec(Layout::kUnshardedDim); + new_layout_specs[input_index] = Layout::kUnshardedDim; } else { - new_layout_specs[input_index] = input_layout->dim(input_index); + new_layout_specs[input_index] = + input_layout->sharding_spec(input_index); } } TF_ASSIGN_OR_RETURN( Layout new_input_layout, - Layout::GetLayout(std::move(new_layout_specs), input_layout->mesh())); + Layout::GetLayout(new_layout_specs, input_layout->mesh())); TF_ASSIGN_OR_RETURN( mlir::Value new_input, EmitRelayout(unpack.getOperand(), *input_layout, new_input_layout)); @@ -250,15 +251,12 @@ Status VerifyPaddedDimensionNotSharded(const Layout& layout, const auto input_shape = input_type.getShape(); const auto output_shape = input_type.getShape(); - for (const auto& dim_shard_and_index : - llvm::enumerate(layout.sharding_specs())) { - const int index = dim_shard_and_index.index(); - const auto& tensor_dimension = dim_shard_and_index.value(); + for (int index = 0; index < layout.rank(); ++index) { const int input_shape_for_dim = input_shape[index]; const int output_shape_for_dim = output_shape[index]; if ((input_shape_for_dim == -1 || output_shape_for_dim == -1 || output_shape_for_dim != input_shape_for_dim) && - layout.num_shards_for_dim(tensor_dimension) > 1) { + layout.num_shards_for_dim(index) > 1) { return errors::InvalidArgument( "Padding over sharded dimension is not allowed."); } @@ -346,11 +344,10 @@ namespace { Status VerifyTileOperandLayout(const Layout& operand_layout, llvm::ArrayRef static_multiples) { for (const auto& tensor_dim_and_multiple : - llvm::zip(operand_layout.sharding_specs(), static_multiples)) { - const auto& tensor_dimension = std::get<0>(tensor_dim_and_multiple); - const int64_t multiple_factor = std::get<1>(tensor_dim_and_multiple); - if (multiple_factor > 1 && - operand_layout.num_shards_for_dim(tensor_dimension) > 1) + llvm::enumerate(static_multiples)) { + const auto& index = tensor_dim_and_multiple.index(); + const int64_t multiple_factor = tensor_dim_and_multiple.value(); + if (multiple_factor > 1 && operand_layout.num_shards_for_dim(index) > 1) return errors::InvalidArgument( "tile op with input sharded at dimension where `multiple` > 1 is not " "supported."); @@ -486,12 +483,11 @@ StatusOr> TileSPMDExpander::ComputeLayoutForward( const Layout input_layout = input_layouts.lookup(0); std::vector output_layout_specs; for (const auto& multiple_and_dim_sharding : - llvm::zip(static_multiple, input_layout.sharding_specs())) { + llvm::zip(static_multiple, input_layout.sharding_spec_strs())) { const int multiple = std::get<0>(multiple_and_dim_sharding); const auto& tensor_dimension = std::get<1>(multiple_and_dim_sharding); - output_layout_specs.push_back(multiple == 1 - ? tensor_dimension.sharding_spec() - : Layout::kUnshardedDim); + output_layout_specs.push_back(multiple == 1 ? tensor_dimension + : Layout::kUnshardedDim); } TF_ASSIGN_OR_RETURN(const Layout output_layout, @@ -542,12 +538,11 @@ StatusOr> TileSPMDExpander::ComputeLayoutBackward( const Layout output_layout = output_layouts.lookup(0); std::vector input_layout_specs; for (const auto& multiple_and_dim_sharding : - llvm::zip(static_multiple, output_layout.sharding_specs())) { + llvm::zip(static_multiple, output_layout.sharding_spec_strs())) { const int multiple = std::get<0>(multiple_and_dim_sharding); const auto& tensor_dimension = std::get<1>(multiple_and_dim_sharding); - input_layout_specs.push_back(multiple == 1 - ? tensor_dimension.sharding_spec() - : Layout::kUnshardedDim); + input_layout_specs.push_back(multiple == 1 ? tensor_dimension + : Layout::kUnshardedDim); } TF_ASSIGN_OR_RETURN(const Layout input_layout, Layout::GetLayout(input_layout_specs, mesh)); @@ -648,8 +643,8 @@ StatusOr MakeLayoutForReshape( // first entry of the input segment divides the output shape on the first // entry of the output segment, we request a sharded layout on that axis. for (int i = 0; i < input_segment_start.size(); ++i) { - const int num_shards = input_layout.num_shards_for_dim( - input_layout.dim(input_segment_start[i])); + const int num_shards = + input_layout.num_shards_for_dim(input_segment_start[i]); if (output_shape[output_segment_start[i]] % num_shards == 0) layout_specs[output_segment_start[i]] = input_layout.sharding_spec(input_segment_start[i]); @@ -711,8 +706,8 @@ StatusOr ReshapeSPMDExpander::ExpandOp(mlir::Operation* op) { // inserted. For example, reshape a [2, 4, 3] shape tensor with layout // ['not_sharded', 'x', 'not_sharded'] to [2, 12] shape tensor fully // replicated can be supported. - std::vector tgt_input_layout(input_layout->rank()); - std::vector tgt_output_layout(output_layout->rank()); + std::vector tgt_input_layout(input_layout->rank()); + std::vector tgt_output_layout(output_layout->rank()); for (int i = 0; i < input_segment_start.size(); ++i) { const int input_start = input_segment_start[i]; @@ -724,14 +719,13 @@ StatusOr ReshapeSPMDExpander::ExpandOp(mlir::Operation* op) { // Between this segment and the last segment, if there is a gap, insert // dimensions of size 1 and kUnshardedDim as output layout dim. for (int j = prev_input_segment_end; j < input_start; ++j) - tgt_input_layout[j].set_sharding_spec(Layout::kUnshardedDim); + tgt_input_layout[j] = Layout::kUnshardedDim; for (int j = prev_output_segment_end; j < output_start; ++j) { local_reshape_const.emplace_back(1); - tgt_output_layout[j].set_sharding_spec(Layout::kUnshardedDim); + tgt_output_layout[j] = Layout::kUnshardedDim; } - const int num_input_shards = - input_layout->num_shards_for_dim(input_layout->dim(input_start)); + const int num_input_shards = input_layout->num_shards_for_dim(input_start); // Decide on the sharding of the input for this segment. // If the input is already sharded, we try to keep this sharding (unless @@ -740,31 +734,32 @@ StatusOr ReshapeSPMDExpander::ExpandOp(mlir::Operation* op) { // we could 'preshard' the input on this dimension before the reshape. // This is unlikely to have any major gains in performance. if (global_output_shape[output_start] % num_input_shards != 0) { - tgt_input_layout[input_start].set_sharding_spec(Layout::kUnshardedDim); - tgt_output_layout[output_start].set_sharding_spec(Layout::kUnshardedDim); + tgt_input_layout[input_start] = Layout::kUnshardedDim; + tgt_output_layout[output_start] = Layout::kUnshardedDim; local_reshape_const.emplace_back(global_output_shape[output_start]); } else { - tgt_input_layout[input_start] = input_layout->dim(input_start); - tgt_output_layout[output_start] = input_layout->dim(input_start); + tgt_input_layout[input_start] = input_layout->sharding_spec(input_start); + tgt_output_layout[output_start] = + input_layout->sharding_spec(input_start); local_reshape_const.emplace_back(global_output_shape[output_start] / num_input_shards); } for (int j = input_start + 1; j < input_segment_end[i]; ++j) - tgt_input_layout[j].set_sharding_spec(Layout::kUnshardedDim); + tgt_input_layout[j] = Layout::kUnshardedDim; for (int j = output_start + 1; j < output_segment_end[i]; ++j) { local_reshape_const.emplace_back(global_output_shape[j]); - tgt_output_layout[j].set_sharding_spec(Layout::kUnshardedDim); + tgt_output_layout[j] = Layout::kUnshardedDim; } } // Fill any remaining dimensions of size 1 and sharding dim on the end of the // layout. for (int j = input_segment_end.back(); j < tgt_input_layout.size(); ++j) - tgt_input_layout[j].set_sharding_spec(Layout::kUnshardedDim); + tgt_input_layout[j] = Layout::kUnshardedDim; for (int j = output_segment_end.back(); j < tgt_output_layout.size(); ++j) { local_reshape_const.emplace_back(1); - tgt_output_layout[j].set_sharding_spec(Layout::kUnshardedDim); + tgt_output_layout[j] = Layout::kUnshardedDim; } TF_ASSIGN_OR_RETURN( @@ -889,8 +884,8 @@ StatusOr TransposeSPMDExpander::ExpandOp( TF_RETURN_IF_ERROR(ExtractConstVectorFromValue(transpose.getPerm(), &perm)); for (const auto& p : llvm::enumerate(perm)) { - if (operand_layout->dim(p.value()).sharding_spec() != - output_layout->dim(p.index()).sharding_spec()) { + if (operand_layout->sharding_spec(p.value()) != + output_layout->sharding_spec(p.index())) { return errors::InvalidArgument( "TransposeOp SPMD needs communication is not supported yet. \n " "operand layout: ", @@ -969,12 +964,12 @@ Status RelayoutOneHotInput(const absl::optional& input_layout, " SPMD expansion. Consider adding Relayout() op to specify the " "layout."); - std::vector sharding_specs(input_layout->rank()); + std::vector sharding_specs(input_layout->rank()); for (int i = 0; i < input_layout->rank(); ++i) { if (i < axis) - sharding_specs[i] = output_layout->dim(i); + sharding_specs[i] = output_layout->sharding_spec(i); else - sharding_specs[i] = output_layout->dim(i + 1); + sharding_specs[i] = output_layout->sharding_spec(i + 1); } TF_ASSIGN_OR_RETURN(const Layout new_input_layout, Layout::GetLayout(sharding_specs, input_layout->mesh())); diff --git a/tensorflow/dtensor/mlir/expansions/nullary_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/nullary_spmd_expander.cc index dafb83d289e6aa..20ca961d55c8f4 100644 --- a/tensorflow/dtensor/mlir/expansions/nullary_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/nullary_spmd_expander.cc @@ -63,8 +63,7 @@ StatusOr NullarySPMDExpander::ExpandOp(mlir::Operation* op) { auto shape = dense.getType().getShape(); std::vector new_shape(dense.getType().getRank()); for (int i = 0; i < op_layouts[0]->rank(); ++i) { - const int num_shards = - op_layouts[0]->num_shards_for_dim(op_layouts[0]->dim(i)); + const int num_shards = op_layouts[0]->num_shards_for_dim(i); if (shape[i] % num_shards != 0) return errors::InvalidArgument( "has output dimension size ", shape[i], diff --git a/tensorflow/dtensor/mlir/expansions/random_op_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/random_op_spmd_expander.cc index d6ec2663fc951b..32e6f4315849ba 100644 --- a/tensorflow/dtensor/mlir/expansions/random_op_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/random_op_spmd_expander.cc @@ -85,10 +85,9 @@ StatusOr GetDeviceSeed(const Layout& layout, mlir::Operation* op) { // to use as the attribute attached to the squeeze op. llvm::SmallVector layout_dims; llvm::SmallSet layout_dims_set; - for (const ShardingSpec& spec : layout.sharding_specs()) { - if (Layout::IsUnshardedSpec(spec)) continue; - layout_dims.emplace_back( - layout.mesh().GetMeshDimIndexWithName(spec.sharding_spec())); + for (const auto& spec : layout.sharding_spec_strs()) { + if (Layout::IsUnshardedDimension(spec)) continue; + layout_dims.emplace_back(layout.mesh().GetMeshDimIndexWithName(spec)); layout_dims_set.insert(layout_dims.back()); } llvm::sort(layout_dims); diff --git a/tensorflow/dtensor/mlir/expansions/reduce_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/reduce_spmd_expander.cc index 448d7652aa7fbb..217b41c986bbf0 100644 --- a/tensorflow/dtensor/mlir/expansions/reduce_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/reduce_spmd_expander.cc @@ -82,11 +82,11 @@ Status AssertReplicated(mlir::Value operand) { absl::flat_hash_set ReducedMeshDimensions( const dtensor::Layout& input, const dtensor::Layout& output) { absl::flat_hash_set mesh_dims; - for (const auto& dim : input.sharding_specs()) { - mesh_dims.insert(dim.sharding_spec()); + for (const auto& dim : input.sharding_spec_strs()) { + mesh_dims.insert(dim); } - for (const auto& dim : output.sharding_specs()) { - mesh_dims.erase(dim.sharding_spec()); + for (const auto& dim : output.sharding_spec_strs()) { + mesh_dims.erase(dim); } return mesh_dims; } diff --git a/tensorflow/dtensor/mlir/expansions/scatter_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/scatter_spmd_expander.cc index f469ca35e069d8..de11f7494e5750 100644 --- a/tensorflow/dtensor/mlir/expansions/scatter_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/scatter_spmd_expander.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include "llvm/Support/FormatVariadic.h" #include "mlir/IR/BuiltinTypes.h" // from @llvm-project @@ -41,33 +42,32 @@ StatusOr GetOutputLayout(const absl::optional& tensor_layout, // to replicated. The remainder are set from tensor_layout and updates_layout // with tensor_layout taking priority, as it is generally larger than updates // (as unsharding updates is faster). - std::vector output_specs(tensor_rank); + std::vector output_specs(tensor_rank); // The number of dimensions at the start of the tensor input that are used // for the index, also the size of the second dimension of the indices tensor. const int index_dimensions = tensor_rank - (updates_rank - 1); - for (int i = 0; i < tensor_rank; ++i) - output_specs[i].set_sharding_spec(Layout::kUnshardedDim); + for (int i = 0; i < tensor_rank; ++i) output_specs[i] = Layout::kUnshardedDim; absl::flat_hash_set used_mesh_dims; if (tensor_layout) { for (int i = index_dimensions; i < tensor_rank; ++i) { - output_specs[i] = tensor_layout->dim(i); - if (Layout::IsShardedSpec(output_specs[i])) - used_mesh_dims.emplace(output_specs[i].sharding_spec()); + output_specs[i] = tensor_layout->sharding_spec(i); + if (Layout::IsShardedDimension(output_specs[i])) + used_mesh_dims.emplace(output_specs[i]); } } if (updates_layout) { for (int i = index_dimensions; i < tensor_rank; ++i) { - const ShardingSpec& update_spec = - updates_layout->dim(i - index_dimensions + 1); + const auto& update_spec = + updates_layout->sharding_spec(i - index_dimensions + 1); - if (Layout::IsUnshardedSpec(output_specs[i]) && - Layout::IsShardedSpec(update_spec) && - !used_mesh_dims.contains(update_spec.sharding_spec())) + if (Layout::IsUnshardedDimension(output_specs[i]) && + Layout::IsShardedDimension(update_spec) && + !used_mesh_dims.contains(update_spec)) output_specs[i] = update_spec; } } @@ -122,13 +122,14 @@ StatusOr TensorScatterOpExpand(mlir::Operation* op) { GetOutputLayout(tensor_layout, tensor_rank, updates_layout, updates_rank, tensor_layout->mesh())); - std::vector updates_specs(updates_rank); - updates_specs[0].set_sharding_spec(Layout::kUnshardedDim); + std::vector updates_specs(updates_rank); + updates_specs[0] = Layout::kUnshardedDim; const int index_dimensions = tensor_rank - (updates_rank - 1); for (int i = 0; i < updates_rank - 1; ++i) - updates_specs[i + 1] = pre_output_layout.dim(index_dimensions + i); + updates_specs[i + 1] = + pre_output_layout.sharding_spec(index_dimensions + i); TF_ASSIGN_OR_RETURN(Layout new_updates_layout, Layout::GetLayout(updates_specs, updates_layout->mesh())); diff --git a/tensorflow/dtensor/mlir/expansions/softmax_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/softmax_spmd_expander.cc index d28847e2cf896f..ec798f7faf0532 100644 --- a/tensorflow/dtensor/mlir/expansions/softmax_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/softmax_spmd_expander.cc @@ -71,7 +71,7 @@ StatusOr ComputeGlobalReduce( // Then an all reduce. absl::flat_hash_set reduced_sharding_specs; for (const int dim : reduced_dims) - if (Layout::IsShardedSpec(input_layout.dim(dim))) + if (Layout::IsShardedDimension(input_layout.sharding_spec(dim))) reduced_sharding_specs.emplace(input_layout.sharding_spec(dim)); TF_ASSIGN_OR_RETURN( mlir::Operation * global_reduce, @@ -192,12 +192,12 @@ StatusOr ComputeShardedSoftmax(mlir::OpBuilder& builder, // 1) Left truncated to match the size of global_shape. // 2) Has unsharded dimensions where ever global_shape is 1. StatusOr GetBroadcastedLayout(llvm::ArrayRef global_shape, - const std::vector& specs, + const std::vector& specs, const Mesh& mesh) { - std::vector new_specs(global_shape.size()); + std::vector new_specs(global_shape.size()); for (int i = 0; i < global_shape.size(); ++i) { if (global_shape[i] == 1) - new_specs[i].set_sharding_spec(Layout::kUnshardedDim); + new_specs[i] = Layout::kUnshardedDim; else new_specs[i] = specs[i + specs.size() - global_shape.size()]; } @@ -248,11 +248,10 @@ StatusOr ComputeOneHot(mlir::OpBuilder& builder, "expected feature input to have at least rank 1, but found rank 0"); const int64_t local_classes = features_type.getShape().back(); - const int64_t classes = - local_classes * - desired_layout.num_shards_for_dim(desired_layout.sharding_specs().back()); + const int64_t classes = local_classes * desired_layout.num_shards_for_dim( + desired_layout.rank() - 1); - int64_t num_shards = desired_layout.num_shards_for_dim(desired_layout.dim(1)); + int64_t num_shards = desired_layout.num_shards_for_dim(1); if (classes % num_shards) return errors::InvalidArgument("unable to shard onehot with size ", classes, " over dimension with ", num_shards, @@ -399,9 +398,9 @@ StatusOr SoftmaxLossOpSPMDExpander::MaybeRelayoutInputs( // This layout represents the 'internal layout' that the softmax will be // operating on. Inputs will be relayout'ed to this layout and outputs will be // relayout'ed from this layout to their desired layout. - std::vector internal_layout(2); - internal_layout[0].set_sharding_spec(Layout::kUnshardedDim); - internal_layout[1].set_sharding_spec(Layout::kUnshardedDim); + std::vector internal_layout(2); + internal_layout[0] = Layout::kUnshardedDim; + internal_layout[1] = Layout::kUnshardedDim; // Choose an internal layout, ideally this layout would be chosen so that // the relayout costs for the inputs (from features_layout/labels_layout to @@ -412,32 +411,34 @@ StatusOr SoftmaxLossOpSPMDExpander::MaybeRelayoutInputs( // Pick a batch sharding, first from features, then labels, loss and backprop. // Due to possible broadcasting on features and labels, they will only // have a batch dim if they are rank 2. - if (features_layout.rank() == 2) internal_layout[0] = features_layout.dim(0); + if (features_layout.rank() == 2) + internal_layout[0] = features_layout.sharding_spec(0); if (((labels_layout.rank() == 2) || (is_sparse && labels_layout.rank() == 1)) && - Layout::IsUnshardedSpec(internal_layout[0])) - internal_layout[0] = labels_layout.dim(0); - if (Layout::IsUnshardedSpec(internal_layout[0])) - internal_layout[0] = loss_layout.dim(0); - if (Layout::IsUnshardedSpec(internal_layout[0])) - internal_layout[0] = backprop_layout.dim(0); + Layout::IsUnshardedDimension(internal_layout[0])) + internal_layout[0] = labels_layout.sharding_spec(0); + if (Layout::IsUnshardedDimension(internal_layout[0])) + internal_layout[0] = loss_layout.sharding_spec(0); + if (Layout::IsUnshardedDimension(internal_layout[0])) + internal_layout[0] = backprop_layout.sharding_spec(0); // Pick a class sharding, first from features, then labels and backprop. // The class dim for features and labels is always the last dim if it exists. // Note that loss and backprop have fixed ranks 1 and 2 respectively where as // ranks of features and labels may involved broadcasting. if (features_layout.rank() > 0 && - (internal_layout[0].sharding_spec() != + (internal_layout[0] != features_layout.sharding_spec(features_layout.rank() - 1))) - internal_layout[1] = features_layout.dim(features_layout.rank() - 1); + internal_layout[1] = + features_layout.sharding_spec(features_layout.rank() - 1); if (!is_sparse && labels_layout.rank() > 0 && - Layout::IsUnshardedSpec(internal_layout[1]) && - (internal_layout[0].sharding_spec() != + Layout::IsUnshardedDimension(internal_layout[1]) && + (internal_layout[0] != labels_layout.sharding_spec(labels_layout.rank() - 1))) - internal_layout[1] = labels_layout.dim(labels_layout.rank() - 1); - if (Layout::IsUnshardedSpec(internal_layout[1]) && - (internal_layout[0].sharding_spec() != backprop_layout.sharding_spec(1))) - internal_layout[1] = backprop_layout.dim(1); + internal_layout[1] = labels_layout.sharding_spec(labels_layout.rank() - 1); + if (Layout::IsUnshardedDimension(internal_layout[1]) && + (internal_layout[0] != backprop_layout.sharding_spec(1))) + internal_layout[1] = backprop_layout.sharding_spec(1); TF_ASSIGN_OR_RETURN( llvm::ArrayRef features_global_shape, @@ -464,7 +465,7 @@ StatusOr SoftmaxLossOpSPMDExpander::MaybeRelayoutInputs( if (is_sparse) { // If we are sparse, then the only possible dimension is the batch_dim. - std::vector sparse_specs = {internal_layout[0]}; + std::vector sparse_specs = {internal_layout[0]}; TF_ASSIGN_OR_RETURN(new_labels_layout, GetBroadcastedLayout(labels_global_shape, sparse_specs, labels_layout.mesh())); @@ -560,7 +561,7 @@ StatusOr SoftmaxLossOpSPMDExpander::ExpandOp( assert(internal_layout.rank() == 2); // If the class dim is unshared, we can emit a local op. - if (Layout::IsUnshardedSpec(internal_layout.dim(1))) { + if (Layout::IsUnshardedDimension(internal_layout.sharding_spec(1))) { op = InferSPMDExpandedLocalShape(op); return MaybeRelayoutOutputs(op, op->getResult(0), op->getResult(1), internal_layout, output_layouts[0], @@ -662,31 +663,32 @@ SoftmaxLossOpSPMDExpander::ComputeLayoutForward( labels_layout.emplace(input_layouts.lookup(1)); // We need to compute shardings for two dimensions: batch and class. - std::vector layout_specs(2); - layout_specs[0].set_sharding_spec(Layout::kUnshardedDim); - layout_specs[1].set_sharding_spec(Layout::kUnshardedDim); + std::vector layout_specs(2); + layout_specs[0] = Layout::kUnshardedDim; + layout_specs[1] = Layout::kUnshardedDim; // First pick the batch dimension, set it to the batch dimension of features // if it exists otherwise to the batch dimesion of labels. if (features_layout && (features_layout->rank() == 2)) - layout_specs[0] = features_layout->dim(0); + layout_specs[0] = features_layout->sharding_spec(0); if (labels_layout && (labels_layout->rank() == 2 || (is_sparse && labels_layout->rank() == 1)) && - Layout::IsUnshardedSpec(layout_specs[0])) - layout_specs[0] = labels_layout->dim(0); + Layout::IsUnshardedDimension(layout_specs[0])) + layout_specs[0] = labels_layout->sharding_spec(0); - // The class dim for features and labels is always the last dim if it - // exists. + // The class sharding_spec for features and labels is always the last + // sharding_spec if it exists. if (features_layout && (features_layout->rank() > 0) && - (layout_specs[0].sharding_spec() != + (layout_specs[0] != features_layout->sharding_spec(features_layout->rank() - 1))) - layout_specs[1] = features_layout->dim(features_layout->rank() - 1); + layout_specs[1] = + features_layout->sharding_spec(features_layout->rank() - 1); if (!is_sparse && labels_layout && (labels_layout->rank() > 0) && - Layout::IsUnshardedSpec(layout_specs[1]) && - (layout_specs[0].sharding_spec() != + Layout::IsUnshardedDimension(layout_specs[1]) && + (layout_specs[0] != labels_layout->sharding_spec(labels_layout->rank() - 1))) - layout_specs[1] = labels_layout->dim(labels_layout->rank() - 1); + layout_specs[1] = labels_layout->sharding_spec(labels_layout->rank() - 1); TF_ASSIGN_OR_RETURN(const Layout backprop_layout, Layout::GetLayout(layout_specs, mesh)); @@ -711,20 +713,19 @@ SoftmaxLossOpSPMDExpander::ComputeLayoutBackward( // We need to compute two possible shardings: // One for the batch dimension and one for the class dimension. - std::vector layout_specs(2); - layout_specs[0].set_sharding_spec(Layout::kUnshardedDim); - layout_specs[1].set_sharding_spec(Layout::kUnshardedDim); + std::vector layout_specs(2); + layout_specs[0] = Layout::kUnshardedDim; + layout_specs[1] = Layout::kUnshardedDim; // Respect the loss layout if it is set, otherwise use the backprop // layout for the batch_dim. - if (loss_layout) layout_specs[0] = loss_layout->dim(0); - if (backprop_layout && Layout::IsUnshardedSpec(layout_specs[0])) - layout_specs[0] = backprop_layout->dim(0); + if (loss_layout) layout_specs[0] = loss_layout->sharding_spec(0); + if (backprop_layout && Layout::IsUnshardedDimension(layout_specs[0])) + layout_specs[0] = backprop_layout->sharding_spec(0); // Only backprop has class dim so use that if it is available. - if (backprop_layout && - backprop_layout->sharding_spec(1) != layout_specs[0].sharding_spec()) - layout_specs[1] = backprop_layout->dim(1); + if (backprop_layout && backprop_layout->sharding_spec(1) != layout_specs[0]) + layout_specs[1] = backprop_layout->sharding_spec(1); TF_ASSIGN_OR_RETURN(const auto features_shape, GetShapeOfValue(op->getOperand(0))); diff --git a/tensorflow/dtensor/mlir/expansions/split_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/split_spmd_expander.cc index f72b22d98df62f..8b271cccb469d9 100644 --- a/tensorflow/dtensor/mlir/expansions/split_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/split_spmd_expander.cc @@ -17,6 +17,8 @@ limitations under the License. #include #include +#include +#include #include "absl/types/optional.h" #include "mlir/IR/Value.h" // from @llvm-project @@ -40,25 +42,22 @@ StatusOr MergeLayoutsForSplitOutput( int64_t split_dim, const llvm::DenseMap& layouts) { assert(!layouts.empty()); const Layout& first_layout = layouts.begin()->getSecond(); - std::vector sharding_specs( - first_layout.sharding_specs().begin(), - first_layout.sharding_specs().end()); + std::vector sharding_specs = first_layout.sharding_spec_strs(); // Merge remaining layouts. If there is a conflicting sharding, then set the // dim to replicated. for (auto it = layouts.begin(); it != layouts.end(); ++it) { const Layout& output_layout = it->getSecond(); for (int dim = 0; dim < output_layout.rank(); ++dim) { - if (Layout::IsShardedDimension(output_layout.dim(dim).sharding_spec()) && - Layout::IsShardedDimension(sharding_specs[dim].sharding_spec()) && - output_layout.dim(dim).sharding_spec() != - sharding_specs[dim].sharding_spec()) { - sharding_specs[dim].set_sharding_spec(Layout::kUnshardedDim); + if (Layout::IsShardedDimension(output_layout.sharding_spec(dim)) && + Layout::IsShardedDimension(sharding_specs[dim]) && + output_layout.sharding_spec(dim) != sharding_specs[dim]) { + sharding_specs[dim] = Layout::kUnshardedDim; } } } // Force the split_dim to be unsharded. - sharding_specs[split_dim].set_sharding_spec(Layout::kUnshardedDim); + sharding_specs[split_dim] = Layout::kUnshardedDim; return Layout::GetLayout(sharding_specs, first_layout.mesh()); } @@ -89,7 +88,7 @@ StatusOr SplitSPMDExpander::ExpandOp(mlir::Operation* op) { const int64_t split_dim, GetAdjustedSplitDim(split_op.getSplitDim(), split_op.getValue())); - if (Layout::IsShardedDimension(input_layout.dim(split_dim).sharding_spec())) { + if (Layout::IsShardedDimension(input_layout.sharding_spec(split_dim))) { return errors::InvalidArgument( "Spliting over sharded dimension is not supported."); } @@ -142,7 +141,7 @@ StatusOr SplitVSPMDExpander::ExpandOp(mlir::Operation* op) { const int64_t split_dim, GetAdjustedSplitDim(split_v_op.getSplitDim(), split_v_op.getValue())); - if (Layout::IsShardedDimension(input_layout.dim(split_dim).sharding_spec())) { + if (Layout::IsShardedDimension(input_layout.sharding_spec(split_dim))) { return errors::InvalidArgument( "Spliting over sharded dimension is not supported."); } diff --git a/tensorflow/dtensor/mlir/expansions/squeeze_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/squeeze_spmd_expander.cc index ebfd6d067196d4..c40e08814c36a1 100644 --- a/tensorflow/dtensor/mlir/expansions/squeeze_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/squeeze_spmd_expander.cc @@ -15,7 +15,9 @@ limitations under the License. #include "tensorflow/dtensor/mlir/expansions/squeeze_spmd_expander.h" +#include #include +#include #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "tensorflow/core/platform/errors.h" @@ -55,16 +57,16 @@ StatusOr> SqueezeSPMDExpander::ComputeLayoutForward( TF_ASSIGN_OR_RETURN(auto shape, ExtractGlobalInputShape(op->getOpOperand(0))); std::set squeeze_dims = GetSqueezeDims(op, /*rank=*/shape.size()); - std::vector layout_specs; + std::vector layout_specs; layout_specs.reserve(input_layout.rank()); for (int64 i = 0; i < input_layout.rank(); ++i) { if (squeeze_dims.empty()) { if (shape[i] > 1) { - layout_specs.push_back(input_layout.dim(i)); + layout_specs.push_back(input_layout.sharding_spec(i)); } } else { if (squeeze_dims.find(i) == squeeze_dims.end()) { - layout_specs.push_back(input_layout.dim(i)); + layout_specs.push_back(input_layout.sharding_spec(i)); } } } @@ -85,24 +87,21 @@ SqueezeSPMDExpander::ComputeLayoutBackward( TF_ASSIGN_OR_RETURN(auto shape, ExtractGlobalInputShape(op->getOpOperand(0))); std::set squeeze_dims = GetSqueezeDims(op, /*rank=*/shape.size()); - ShardingSpec unsharded_spec; - unsharded_spec.set_sharding_spec(Layout::kUnshardedDim); - - std::vector layout_specs; + std::vector layout_specs; layout_specs.reserve(output_layout.rank()); size_t j = 0; for (size_t i = 0; i < shape.size(); ++i) { if (squeeze_dims.empty()) { if (shape[i] > 1) { - layout_specs.push_back(output_layout.dim(j++)); + layout_specs.push_back(output_layout.sharding_spec(j++)); } else { - layout_specs.push_back(unsharded_spec); + layout_specs.push_back(Layout::kUnshardedDim); } } else { if (squeeze_dims.find(i) == squeeze_dims.end()) { - layout_specs.push_back(output_layout.dim(j++)); + layout_specs.push_back(output_layout.sharding_spec(j++)); } else { - layout_specs.push_back(unsharded_spec); + layout_specs.push_back(Layout::kUnshardedDim); } } } diff --git a/tensorflow/dtensor/mlir/expansions/top_k_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/top_k_spmd_expander.cc index 6b8a5002489da0..2de0f5851b0575 100644 --- a/tensorflow/dtensor/mlir/expansions/top_k_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/top_k_spmd_expander.cc @@ -15,6 +15,9 @@ limitations under the License. #include "tensorflow/dtensor/mlir/expansions/top_k_spmd_expander.h" +#include +#include + #include "mlir/IR/IRMapping.h" // from @llvm-project #include "tensorflow/core/platform/errors.h" #include "tensorflow/dtensor/cc/dstatus.h" @@ -29,14 +32,12 @@ namespace dtensor { // layout -> layout[:-1] + unsharded StatusOr GetSuggestedLayout(const Layout& input_layout) { - std::vector layout_specs(input_layout.rank()); + std::vector layout_specs(input_layout.rank()); for (int i = 0; i < input_layout.rank() - 1; ++i) { - layout_specs[i].set_sharding_spec(input_layout.sharding_spec(i)); + layout_specs[i] = input_layout.sharding_spec(i); } - layout_specs[input_layout.rank() - 1].set_sharding_spec( - Layout::kUnshardedDim); - + layout_specs[input_layout.rank() - 1] = Layout::kUnshardedDim; return Layout::GetLayout(layout_specs, input_layout.mesh()); } diff --git a/tensorflow/dtensor/mlir/expansions/where_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/where_spmd_expander.cc index 0ee755d440332c..c93ecf9e4ca679 100644 --- a/tensorflow/dtensor/mlir/expansions/where_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/where_spmd_expander.cc @@ -119,7 +119,7 @@ StatusOr> WhereOpSPMDExpander::ComputeLayoutForward( // Append an unsharded sharding spec for the index dimension generated by the // Where op. std::vector layout_specs; - layout_specs.push_back(layout.dim(0).sharding_spec()); + layout_specs.push_back(layout.sharding_spec(0)); layout_specs.push_back(Layout::kUnshardedDim); TF_ASSIGN_OR_RETURN(Layout new_layout, Layout::GetLayout(layout_specs, layout.mesh())); @@ -138,7 +138,7 @@ WhereOpSPMDExpander::ComputeLayoutBackward( std::vector layout_specs; layout_specs.reserve(layout.rank() - 1); for (int i = 0; i < layout.rank() - 1; i++) { - layout_specs.push_back(layout.dim(i).sharding_spec()); + layout_specs.push_back(layout.sharding_spec(i)); } TF_ASSIGN_OR_RETURN(Layout new_layout, Layout::GetLayout(layout_specs, layout.mesh())); diff --git a/tensorflow/dtensor/mlir/ir/tf_dtensor.cc b/tensorflow/dtensor/mlir/ir/tf_dtensor.cc index c9b4fe235ec7db..14e66c98b6cb2c 100644 --- a/tensorflow/dtensor/mlir/ir/tf_dtensor.cc +++ b/tensorflow/dtensor/mlir/ir/tf_dtensor.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/dtensor/mlir/ir/tf_dtensor.h" #include +#include #include "llvm/Support/FormatVariadic.h" #include "mlir/IR/BuiltinTypes.h" // from @llvm-project @@ -210,8 +211,8 @@ mlir::LogicalResult DTensorAllToAllOp::verify() { int32_t num_split_dims = 0; int32_t num_concat_dims = 0; - tensorflow::dtensor::ShardingSpec split_spec; - tensorflow::dtensor::ShardingSpec concat_spec; + std::string split_spec; + std::string concat_spec; for (int32_t i = 0; i < input_layout.rank(); ++i) { if (input_layout.sharding_spec(i) == output_layout.sharding_spec(i)) continue; @@ -220,17 +221,17 @@ mlir::LogicalResult DTensorAllToAllOp::verify() { tensorflow::dtensor::Layout::IsShardedDimension( output_layout.sharding_spec(i))) { num_split_dims++; - split_spec = output_layout.dim(i); + split_spec = output_layout.sharding_spec(i); } else if (tensorflow::dtensor::Layout::IsShardedDimension( input_layout.sharding_spec(i)) && tensorflow::dtensor::Layout::IsUnshardedDimension( output_layout.sharding_spec(i))) { num_concat_dims++; - concat_spec = input_layout.dim(i); + concat_spec = input_layout.sharding_spec(i); } } if (num_split_dims != 1 || num_concat_dims != 1 || - split_spec.sharding_spec() != concat_spec.sharding_spec()) { + split_spec != concat_spec) { return op.emitOpError() << "must have one mesh dimension which is being " "unsharded in one axis and sharded in another"; } diff --git a/tensorflow/dtensor/mlir/utils/collective_lowering.cc b/tensorflow/dtensor/mlir/utils/collective_lowering.cc index 9f874b73d92e00..5c03db58e8e9e3 100644 --- a/tensorflow/dtensor/mlir/utils/collective_lowering.cc +++ b/tensorflow/dtensor/mlir/utils/collective_lowering.cc @@ -960,9 +960,8 @@ mlir::LogicalResult LowerAllGatherOpToCollective( } for (int i = 0; i < src_layout.rank(); i++) { - if (src_layout.num_shards_for_dim(src_layout.dim(i)) == - tgt_layout.num_shards_for_dim(tgt_layout.dim(i)) || - src_layout.num_shards_for_dim(src_layout.dim(i)) == 1) { + if (src_layout.num_shards_for_dim(i) == tgt_layout.num_shards_for_dim(i) || + src_layout.num_shards_for_dim(i) == 1) { continue; } @@ -970,8 +969,7 @@ mlir::LogicalResult LowerAllGatherOpToCollective( perm_for_transpose[0] = perm_for_transpose[i]; perm_for_transpose[i] = temp; - num_shards_per_dim.push_back( - src_layout.num_shards_for_dim(src_layout.dim(i))); + num_shards_per_dim.push_back(src_layout.num_shards_for_dim(i)); previous_sharded_dim[i] = last_sharded_dim; last_sharded_dim = i; @@ -1013,9 +1011,9 @@ mlir::LogicalResult LowerAllGatherOpToCollective( prev_op_result = reshape_op->getResult(0); for (int i = src_layout.rank() - 1; i >= 0; i--) { - if (src_layout.num_shards_for_dim(src_layout.dim(i)) == - tgt_layout.num_shards_for_dim(tgt_layout.dim(i)) || - src_layout.num_shards_for_dim(src_layout.dim(i)) == 1) { + if (src_layout.num_shards_for_dim(i) == + tgt_layout.num_shards_for_dim(i) || + src_layout.num_shards_for_dim(i) == 1) { continue; } @@ -1058,7 +1056,7 @@ mlir::LogicalResult LowerAllGatherOp(mlir::TF::DTensorAllGatherOp all_gather) { llvm::SmallVector concat_dims; for (int64 i = 0; i < src_layout.rank(); ++i) - if (src_layout.num_shards_for_dim(src_layout.dim(i)) > 1 && + if (src_layout.num_shards_for_dim(i) > 1 && Layout::IsUnshardedDimension(tgt_layout.sharding_spec(i))) concat_dims.push_back(i); @@ -1366,9 +1364,8 @@ mlir::LogicalResult LowerAllToAllHelper( absl::flat_hash_set dims_to_gather; for (int i = 0; i < src_layout.rank(); i++) { - if (src_layout.num_shards_for_dim(src_layout.dim(i)) == - tgt_layout.num_shards_for_dim(tgt_layout.dim(i)) || - src_layout.num_shards_for_dim(src_layout.dim(i)) == 1) { + if (src_layout.num_shards_for_dim(i) == tgt_layout.num_shards_for_dim(i) || + src_layout.num_shards_for_dim(i) == 1) { continue; } dims_to_gather.insert(src_layout.sharding_spec(i)); From e039d120001d10289556e82226a19a4146135ef4 Mon Sep 17 00:00:00 2001 From: Justin Szaday Date: Tue, 11 Jul 2023 10:00:54 -0700 Subject: [PATCH 133/376] Update multi-device expansion to handle resources. PiperOrigin-RevId: 547220844 --- tensorflow/dtensor/cc/constants.h | 2 + tensorflow/dtensor/cc/dtensor_device.cc | 2 +- tensorflow/dtensor/cc/dtensor_device_util.cc | 35 ++- tensorflow/dtensor/cc/dtensor_device_util.h | 3 + .../mlir/dtensor_multi_device_expansion.cc | 200 ++++++++++++++---- 5 files changed, 195 insertions(+), 47 deletions(-) diff --git a/tensorflow/dtensor/cc/constants.h b/tensorflow/dtensor/cc/constants.h index 3ad6d6c39a706b..b14ea2438af7f2 100644 --- a/tensorflow/dtensor/cc/constants.h +++ b/tensorflow/dtensor/cc/constants.h @@ -58,6 +58,8 @@ static constexpr char kNewResourceLayoutIndices[] = // Attribute carries layout for newly inferred layout of resource handle. static constexpr char kNewResourceArgLayouts[] = "_inferred_resource_layouts"; +static constexpr char kNumLocalOutputsAttr[] = "_num_local_outputs"; + // Attribute carries input layout information for shape op. static constexpr char kShapeOpInputLayout[] = "_shape_input_layout"; diff --git a/tensorflow/dtensor/cc/dtensor_device.cc b/tensorflow/dtensor/cc/dtensor_device.cc index d413b37141c3dc..deb26a332d7581 100644 --- a/tensorflow/dtensor/cc/dtensor_device.cc +++ b/tensorflow/dtensor/cc/dtensor_device.cc @@ -1839,8 +1839,8 @@ void DTensorDevice::ExecuteMultiDeviceOperation( int output_offset = 0; for (int i = 0; i < num_output_layouts; i++) { const Layout& output_layout = function.output_layouts[i]; + const int num_devices = function.num_local_outputs[i]; std::vector layout_outputs; - const int num_devices = output_layout.num_devices(); for (int j = 0; j < num_devices; j++) { const int output_idx = output_offset + j; layout_outputs.emplace_back(std::move(eager_outputs[output_idx])); diff --git a/tensorflow/dtensor/cc/dtensor_device_util.cc b/tensorflow/dtensor/cc/dtensor_device_util.cc index 319fc9bc521f6c..8cc8fa32c6277c 100644 --- a/tensorflow/dtensor/cc/dtensor_device_util.cc +++ b/tensorflow/dtensor/cc/dtensor_device_util.cc @@ -922,6 +922,21 @@ Status PrepareGraphForMlir( return OkStatus(); } +StatusOr> GetNumLocalOutputs(Node* node) { + const AttrValue* num_local_outputs = + (node->attrs()).Find(kNumLocalOutputsAttr); + if (num_local_outputs == nullptr) { + return absl::InvalidArgumentError("missing num_local_outputs attribute"); + } else { + const AttrValue_ListValue& list = num_local_outputs->list(); + std::vector res; + res.reserve(list.i_size()); + std::copy(list.i().begin(), list.i().end(), std::back_inserter(res)); + return res; + } +} + +namespace { Status SetMultiDeviceFunctionOutputs( TranslatedFunction& function, Node* node, const std::vector& global_output_shapes) { @@ -929,6 +944,8 @@ Status SetMultiDeviceFunctionOutputs( if (serialized_layouts == nullptr) { return absl::InvalidArgumentError("missing layout attribute"); } + TF_ASSIGN_OR_RETURN(std::vector num_local_outputs, + GetNumLocalOutputs(node)); const auto& serialized_layout_list = serialized_layouts->list(); for (int i = 0; i < serialized_layout_list.s_size(); i++) { const auto& serialized_layout = serialized_layout_list.s(i); @@ -936,17 +953,26 @@ Status SetMultiDeviceFunctionOutputs( Layout::FromString(serialized_layout)); function.output_layouts.emplace_back(std::move(layout)); } - for (int i = 0; i < function.output_layouts.size(); i++) { - const Layout& output_layout = function.output_layouts[i]; + int num_output_layouts = function.output_layouts.size(); + for (int i = 0; i < num_output_layouts; i++) { + const Layout* output_layout = &(function.output_layouts[i]); + if (output_layout->IsEmpty()) { + const auto search = function.resource_input_layouts.find(i); + if (search != function.resource_input_layouts.end()) { + output_layout = &(search->second); + } + } PartialTensorShape local_shape = - output_layout.LocalShapeFromGlobalShape(global_output_shapes[i]); - const int num_devices = output_layout.num_devices(); + output_layout->LocalShapeFromGlobalShape(global_output_shapes[i]); + const int64_t num_devices = num_local_outputs[i]; for (int j = 0; j < num_devices; j++) { function.local_output_shapes.emplace_back(local_shape); } } + function.num_local_outputs = std::move(num_local_outputs); return OkStatus(); } +} // namespace // Returns set of functions to run to execute DTensor computation. StatusOr IdentifyAllFunctionsToExecute( @@ -1033,6 +1059,7 @@ StatusOr IdentifyAllFunctionsToExecute( function.local_output_shapes.emplace_back( output_layout.LocalShapeFromGlobalShape( global_output_shapes[global_index])); + function.num_local_outputs.emplace_back(1); } } diff --git a/tensorflow/dtensor/cc/dtensor_device_util.h b/tensorflow/dtensor/cc/dtensor_device_util.h index 00ef825c92ca26..6bc5b2c4295aa9 100644 --- a/tensorflow/dtensor/cc/dtensor_device_util.h +++ b/tensorflow/dtensor/cc/dtensor_device_util.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_DTENSOR_CC_DTENSOR_DEVICE_UTIL_H_ #include +#include #include #include #include @@ -114,6 +115,8 @@ struct TranslatedFunction { std::vector local_output_shapes; // Output data types. std::vector output_dtypes; + // Number of local outputs for each layout. + std::vector num_local_outputs; }; struct ExecutionFunctions { diff --git a/tensorflow/dtensor/mlir/dtensor_multi_device_expansion.cc b/tensorflow/dtensor/mlir/dtensor_multi_device_expansion.cc index eb8d83034038dd..92bafef7f7e252 100644 --- a/tensorflow/dtensor/mlir/dtensor_multi_device_expansion.cc +++ b/tensorflow/dtensor/mlir/dtensor_multi_device_expansion.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include #include +#include #include #include "absl/container/flat_hash_map.h" @@ -92,29 +93,28 @@ mlir::BlockArgument InsertArgumentForDevice(mlir::OpBuilder& builder, // Returns the user of all the ops in the span iff it is a single return op. // Otherwise, returns nullptr; for example, if there are multiple return ops. -template -mlir::func::ReturnOp GetReturnOpFromUsers(absl::Span ops) { - mlir::func::ReturnOp return_op; - - for (Operation op : ops) { +template +mlir::LogicalResult GetReturnOpFromUsers(Operations&& ops, + mlir::func::ReturnOp* return_op) { + for (mlir::Operation* op : ops) { for (mlir::Operation* user : op->getUsers()) { // TODO(twelve): Determine whether we should follow identity ops. if (mlir::func::ReturnOp op = llvm::dyn_cast_or_null(user)) { - if (return_op) { - if (return_op != op) { - return nullptr; + if (*return_op) { + if (*return_op != op) { + return mlir::failure(); } } else { - return_op = op; + *return_op = op; } } else { - return nullptr; + return mlir::failure(); } } } - return return_op; + return mlir::success(); } // Returns the devices for a given mesh. @@ -132,6 +132,62 @@ StatusOr> GetExpandedArguments( ExpandedArgumentMap& expanded_arguments, mlir::BlockArgument argument, const Mesh* target_mesh = nullptr); +StatusOr>> GetResourceLayouts( + mlir::Operation* op) { + if (op->hasAttr(kNewResourceArgLayouts)) { + auto attrs = op->getAttrOfType(kNewResourceArgLayouts); + std::vector layouts; + layouts.reserve(attrs.size()); + for (mlir::Attribute attr : attrs) { + auto string_attr = attr.cast(); + auto layout = Layout::FromString(string_attr.str()); + if (layout.ok()) { + layouts.emplace_back(std::move(layout.value())); + } else { + return layout.status(); + } + } + return layouts; + } else { + return std::nullopt; + } +} + +bool IsResource(mlir::Value value) { + return getElementTypeOrSelf(value.getType()).isa(); +} + +StatusOr> FindResourceLayout(mlir::BlockArgument arg) { + uint32_t arg_num = arg.getArgNumber(); + for (mlir::Operation* user : arg.getUsers()) { + auto resource_layouts = GetResourceLayouts(user); + if (resource_layouts.ok()) { + const auto& opt = resource_layouts.value(); + if (!opt || opt->empty()) { + continue; + } + } else { + return resource_layouts.status(); + } + + auto resource_indices = user->getAttrOfType( + kNewResourceLayoutIndices); + if (!resource_indices) { + return absl::InvalidArgumentError( + absl::StrCat("missing ", kNewResourceLayoutIndices)); + } + + for (auto [i, index] : llvm::enumerate(resource_indices)) { + uint64_t index_value = index.getZExtValue(); + if (index_value == arg_num) { + return (resource_layouts.value())->at(i); + } + } + } + + return std::nullopt; +} + mlir::tf_device::ClusterFuncOp ExtractDeviceClusterFromFunctionCall( mlir::TF::StatefulPartitionedCallOp op) { mlir::tf_device::ClusterFuncOp result; @@ -350,10 +406,24 @@ StatusOr> GetExpandedArguments( mesh = *target_mesh; } } else { - TF_ASSIGN_OR_RETURN(const std::optional layout, + TF_ASSIGN_OR_RETURN(std::optional layout, ExtractLayoutFromOperand(arg)); if (layout) { mesh = layout->mesh(); + + if (mesh->IsEmpty()) { + if (target_mesh) { + mesh = *target_mesh; + } else if (IsResource(arg)) { + TF_ASSIGN_OR_RETURN(layout, FindResourceLayout(arg)); + if (layout) { + mesh = layout->mesh(); + } else { + return absl::InvalidArgumentError( + "Could not find resource layout!"); + } + } + } } } if (mesh.has_value()) { @@ -401,20 +471,49 @@ mlir::FunctionType GetFunctionType(mlir::OpBuilder& builder, return builder.getFunctionType(input_types, result_types); } +struct InferredResourceAttributes { + mlir::Attribute layouts; + mlir::Attribute indices; + + InferredResourceAttributes(mlir::Attribute layouts_, mlir::Attribute indices_) + : layouts(layouts_), indices(indices_) {} +}; + // Build a new main function that calls the multi-device/translated function. -mlir::LogicalResult BuildOuterMainFunc( - mlir::ModuleOp module, mlir::func::FuncOp old_main_func, - mlir::func::FuncOp translated_func, mlir::func::ReturnOp return_op, - absl::Span call_ops) { +template +mlir::LogicalResult BuildOuterMainFunc(mlir::ModuleOp module, + mlir::func::FuncOp old_main_func, + mlir::func::FuncOp translated_func, + mlir::func::ReturnOp return_op, + mlir::ArrayAttr num_local_outputs_attr, + Operations&& call_ops) { + using CallOp = typename std::decay_t::value_type; llvm::SmallVector output_layouts; - for (mlir::TF::StatefulPartitionedCallOp call_op : call_ops) { + std::optional resource_attrs; + for (CallOp call_op : call_ops) { // Then extract all their output layouts. - mlir::ArrayAttr layouts = - call_op->getAttr(kLayoutAttr).dyn_cast_or_null(); + mlir::Attribute layout_attr = call_op->getAttr(kLayoutAttr); + mlir::ArrayAttr layouts = layout_attr.dyn_cast_or_null(); if (!layouts) { call_op.emitOpError() << "Could not find op's layouts."; return mlir::failure(); } + // Set the resource layouts. + mlir::Attribute resource_layouts_attr = + call_op->getAttr(kNewResourceArgLayouts); + mlir::Attribute resource_indices_attr = + call_op->getAttr(kNewResourceLayoutIndices); + if (resource_indices_attr && resource_layouts_attr) { + if (resource_attrs) { + // TODO(twelve): Determine how to merge inferred resource attrs if there + // are multiple sets of them. (when can that happen?) + call_op.emitOpError() + << "Multiple sets of inferred resource attributes!"; + return mlir::failure(); + } else { + resource_attrs.emplace(resource_layouts_attr, resource_indices_attr); + } + } // Here, we assume that the output layouts and the results are in the same // ordering--this property should be guaranteed as long as all the results // have been expanded (produced by ExpandOperation). @@ -442,25 +541,33 @@ mlir::LogicalResult BuildOuterMainFunc( // Get the type of the translated function. mlir::FunctionType func_type = translated_func.getFunctionType(); - // Then build a call op targeting it (reflecting its result types). - auto expanded_call_op = builder.create( - call_ops[0].getLoc(), func_type.getResults(), inputs, - translated_func.getSymName(), - /*config=*/builder.getStringAttr(""), - /*config_proto=*/builder.getStringAttr(""), - /*executor_type=*/builder.getStringAttr("")); + // Then build a call op targeting it (reflecting its result types) + auto expanded_call_op = + builder.create(call_ops[0].getLoc(), func_type.getResults(), + inputs, translated_func.getSymName(), + /*config=*/builder.getStringAttr(""), + /*config_proto=*/builder.getStringAttr(""), + /*executor_type=*/builder.getStringAttr("")); // Set the output layout attribute on the new call op. llvm::ArrayRef output_layouts_ref(output_layouts); mlir::ArrayAttr output_layouts_attr = builder.getArrayAttr(output_layouts_ref); expanded_call_op->setAttr(kLayoutAttr, output_layouts_attr); + expanded_call_op->setAttr(kNumLocalOutputsAttr, num_local_outputs_attr); + + if (resource_attrs) { + expanded_call_op->setAttr(kNewResourceArgLayouts, resource_attrs->layouts); + expanded_call_op->setAttr(kNewResourceLayoutIndices, + resource_attrs->indices); + } // Return all the values from the new call op. mlir::Operation::result_range outputs = expanded_call_op.getResults(); - if (return_op) { - builder.create(return_op.getLoc(), outputs); - } else if (!outputs.empty()) { + if (return_op || outputs.empty()) { + mlir::Location loc = return_op ? return_op.getLoc() : main_func.getLoc(); + builder.create(loc, outputs); + } else { call_ops[0]->emitOpError("Call had results, but they were not used."); return mlir::failure(); } @@ -542,9 +649,8 @@ struct DTensorMultiDeviceExpansion }); // Ensure that all the call ops return results via the same op. - mlir::func::ReturnOp return_op = GetReturnOpFromUsers( - absl::Span(stateful_call_ops)); - if (!return_op && !stateful_call_ops.empty()) { + mlir::func::ReturnOp return_op; + if (GetReturnOpFromUsers(stateful_call_ops, &return_op).failed()) { stateful_call_ops[0]->emitOpError( "Calls must be used by exactly one return op."); return; @@ -581,16 +687,26 @@ struct DTensorMultiDeviceExpansion } std::vector results; - for (unsigned i = 0; i < return_op->getNumOperands(); ++i) { - ExpandedResultsMap::iterator search = expanded_results.find(i); - if (search == expanded_results.end()) { - results.emplace_back(return_op->getOperand(i)); - } else { - std::vector& values = search->second; - results.insert(results.end(), values.begin(), values.end()); + llvm::SmallVector num_local_outputs; + if (return_op) { + for (unsigned i = 0; i < return_op->getNumOperands(); ++i) { + ExpandedResultsMap::iterator search = expanded_results.find(i); + int num_outputs; + if (search == expanded_results.end()) { + results.emplace_back(return_op->getOperand(i)); + num_outputs = 1; + } else { + std::vector& values = search->second; + results.insert(results.end(), values.begin(), values.end()); + num_outputs = values.size(); + } + num_local_outputs.emplace_back(builder.getI64IntegerAttr(num_outputs)); } } + mlir::ArrayAttr num_local_outputs_attr = + builder.getArrayAttr(num_local_outputs); + // update the operands of the translated return op translated_terminator_op->setOperands(results); // and, update the function's type accordingly @@ -598,9 +714,9 @@ struct DTensorMultiDeviceExpansion builder, translated_func, absl::Span(results))); UpdateEntryFuncAttr(builder, translated_func); - mlir::LogicalResult status = BuildOuterMainFunc( - module, main_func, translated_func, return_op, - absl::Span(stateful_call_ops)); + mlir::LogicalResult status = + BuildOuterMainFunc(module, main_func, translated_func, return_op, + num_local_outputs_attr, stateful_call_ops); if (mlir::failed(status)) { return; } From fd5078bbb052660ad90d2b1e2ad252998cf98f14 Mon Sep 17 00:00:00 2001 From: Fiona Lang Date: Tue, 11 Jul 2023 10:01:04 -0700 Subject: [PATCH 134/376] Update ops.Tensor references to //third_party/tensorflow/python/framework/tensor.py. PiperOrigin-RevId: 547220887 --- tensorflow/python/framework/BUILD | 33 ++-- tensorflow/python/framework/constant_op.py | 5 +- tensorflow/python/framework/extension_type.py | 17 +- .../python/framework/extension_type_field.py | 9 +- .../framework/extension_type_field_test.py | 39 ++-- .../python/framework/extension_type_test.py | 176 +++++++++--------- tensorflow/python/framework/function.py | 3 +- tensorflow/python/framework/importer.py | 3 +- tensorflow/python/framework/op_def_library.py | 5 +- .../python/framework/op_def_library_test.py | 6 +- tensorflow/python/framework/ops_test.py | 36 ++-- .../framework/python_api_dispatcher_test.py | 12 +- .../python_api_parameter_converter_test.py | 4 +- .../framework/python_tensor_converter_test.py | 22 +-- tensorflow/python/framework/smart_cond.py | 4 +- .../python/framework/sparse_tensor_test.py | 26 +-- tensorflow/python/framework/subscribe.py | 7 +- tensorflow/python/framework/tensor_util.py | 2 +- tensorflow/python/framework/test_util.py | 23 +-- tensorflow/python/framework/test_util_test.py | 5 +- tensorflow/python/framework/weak_tensor.py | 4 +- .../python/framework/weak_tensor_test.py | 8 +- 22 files changed, 234 insertions(+), 215 deletions(-) diff --git a/tensorflow/python/framework/BUILD b/tensorflow/python/framework/BUILD index 20f54f63c34c7f..856015a16a00cf 100644 --- a/tensorflow/python/framework/BUILD +++ b/tensorflow/python/framework/BUILD @@ -275,6 +275,7 @@ py_strict_library( srcs_version = "PY3", deps = [ ":ops", + ":tensor", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:variables", "//tensorflow/python/platform:tf_logging", @@ -400,6 +401,7 @@ py_strict_library( deps = [ ":dtypes", ":ops", + ":tensor", ":tensor_conversion_registry", ":tensor_shape", ":tensor_util", @@ -528,6 +530,7 @@ py_strict_library( ":dtypes", ":graph_to_function_def", ":ops", + ":tensor", "//tensorflow/core:protos_all_py", "//tensorflow/python/client:pywrap_tf_session", "//tensorflow/python/eager:context", @@ -665,6 +668,7 @@ py_strict_library( ":op_def_library_pybind", ":op_def_registry", ":ops", + ":tensor", ":tensor_shape", "//tensorflow/core:protos_all_py", "//tensorflow/core/config:flags_py", @@ -906,7 +910,7 @@ tf_py_strict_test( ":_pywrap_python_tensor_converter", ":constant_op", ":dtypes", - ":ops", + ":tensor", ":tensor_shape", ":test_lib", "//tensorflow/core:protos_all_py", @@ -1051,7 +1055,7 @@ tf_py_strict_test( deps = [ ":_pywrap_python_api_dispatcher", ":constant_op", - ":ops", + ":tensor", ":test_lib", "//tensorflow/python/ops/ragged:ragged_factory_ops", "//tensorflow/python/ops/ragged:ragged_tensor", @@ -1136,7 +1140,7 @@ tf_py_strict_test( ":constant_op", ":dtypes", ":indexed_slices", - ":ops", + ":tensor", ":test_lib", "//tensorflow/core:protos_all_py", "//tensorflow/python/eager:context", @@ -1532,7 +1536,7 @@ py_strict_library( srcs_version = "PY3", visibility = visibility + ["//tensorflow_model_optimization:__subpackages__"], deps = [ - ":ops", + ":tensor", ":tensor_util", "//tensorflow/python/ops:cond", "//tensorflow/python/ops:control_flow_case", @@ -1602,7 +1606,7 @@ py_strict_library( ":dtypes", ":errors", ":extension_type", - ":ops", + ":tensor", "//tensorflow/python/eager:context", "//third_party/py/numpy", ], @@ -1620,7 +1624,7 @@ tf_py_strict_test( ":dtypes", ":errors", ":ops", - ":tensor_spec", + ":tensor", ":test_lib", ":weak_tensor", "//tensorflow/python/eager:backprop", @@ -1687,9 +1691,8 @@ py_strict_library( ":dtypes", ":extension_type_field", ":immutable_dict", - ":ops", + ":tensor", ":tensor_shape", - ":tensor_spec", ":type_spec", ":type_spec_registry", "//tensorflow/core:protos_all_py", @@ -1715,6 +1718,7 @@ py_strict_library( ":dtypes", ":immutable_dict", ":ops", + ":tensor", ":tensor_shape", ":type_spec", "//tensorflow/python/util:type_annotations", @@ -1961,6 +1965,7 @@ py_strict_library( ":ops", ":random_seed", ":sparse_tensor", + ":tensor", ":tensor_shape", ":tensor_util", ":tfrt_utils", @@ -2473,9 +2478,9 @@ tf_py_strict_test( ":indexed_slices", ":ops", ":sparse_tensor", + ":tensor", ":tensor_conversion_registry", ":tensor_shape", - ":tensor_spec", ":tensor_util", ":test_lib", ":test_ops", @@ -2577,8 +2582,8 @@ tf_py_strict_test( ":extension_type_field", ":immutable_dict", ":ops", + ":tensor", ":tensor_shape", - ":tensor_spec", ":test_lib", ":type_spec", ":type_spec_registry", @@ -2619,9 +2624,8 @@ tf_py_strict_test( ":constant_op", ":dtypes", ":extension_type_field", - ":ops", + ":tensor", ":tensor_shape", - ":tensor_spec", ":test_lib", "//tensorflow/python/ops/ragged:ragged_factory_ops", "//tensorflow/python/ops/ragged:ragged_tensor", @@ -2724,8 +2728,8 @@ tf_py_strict_test( ":errors", ":ops", ":sparse_tensor", + ":tensor", ":tensor_shape", - ":tensor_spec", ":test_lib", ":type_utils", "//tensorflow/core:protos_all_py", @@ -2837,6 +2841,7 @@ cuda_py_strict_test( ":indexed_slices", ":ops", ":random_seed", + ":tensor", ":test_lib", ":test_ops", "//tensorflow/core:protos_all_py", @@ -2928,8 +2933,8 @@ tf_py_strict_test( ":op_def_library", ":op_def_library_pybind", ":ops", + ":tensor", ":tensor_shape", - ":tensor_spec", ":test_lib", "//tensorflow/core:protos_all_py", "//tensorflow/python/eager:def_function", diff --git a/tensorflow/python/framework/constant_op.py b/tensorflow/python/framework/constant_op.py index 20f314afc7b191..4004485469c33d 100644 --- a/tensorflow/python/framework/constant_op.py +++ b/tensorflow/python/framework/constant_op.py @@ -26,6 +26,7 @@ from tensorflow.python.eager import execute from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_conversion_registry from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util @@ -314,7 +315,7 @@ def _constant_eager_impl(ctx, value, dtype, shape, verify_shape): def is_constant(tensor_or_op): - if isinstance(tensor_or_op, ops.Tensor): + if isinstance(tensor_or_op, tensor_lib.Tensor): op = tensor_or_op.op else: op = tensor_or_op @@ -400,7 +401,7 @@ class _ConstantTensorCodec: """Codec for Tensor.""" def can_encode(self, pyobj): - return isinstance(pyobj, ops.Tensor) + return isinstance(pyobj, tensor_lib.Tensor) def do_encode(self, tensor_value, encode_fn): """Returns an encoded `TensorProto` for the given `tf.Tensor`.""" diff --git a/tensorflow/python/framework/extension_type.py b/tensorflow/python/framework/extension_type.py index 16ae82831121b4..d83e5b3b6d401e 100644 --- a/tensorflow/python/framework/extension_type.py +++ b/tensorflow/python/framework/extension_type.py @@ -25,9 +25,8 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import extension_type_field from tensorflow.python.framework import immutable_dict -from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import type_spec from tensorflow.python.framework import type_spec_registry from tensorflow.python.ops import array_ops @@ -146,7 +145,7 @@ class ExtensionType( >>> class Toy(ExtensionType): ... name: str - ... price: ops.Tensor + ... price: tensor.Tensor ... features: typing.Mapping[str, tf.Tensor] >>> class ToyStore(ExtensionType): @@ -307,7 +306,7 @@ def __eq__(self, other): def __ne__(self, other): eq = self.__eq__(other) - if isinstance(eq, ops.Tensor): + if isinstance(eq, tensor.Tensor): return math_ops.logical_not(eq) else: return not eq @@ -448,7 +447,7 @@ def _to_components(self, value): # TypeSpec API. if self._tf_extension_type_is_packed: return value._tf_extension_type_packed_variant # pylint: disable=protected-access - tensor_or_composite = (ops.Tensor, composite_tensor.CompositeTensor) + tensor_or_composite = (tensor.Tensor, composite_tensor.CompositeTensor) # Retireve fields by the order of spec dict to preserve field ordering. This # is needed as nest.flatten would sort dictionary entries by key. value_tuple = tuple(value.__dict__[key] for key in self.__dict__) @@ -490,7 +489,7 @@ def _from_components(self, components): # TypeSpec API. @property def _component_specs(self): # TypeSpec API. if self._tf_extension_type_is_packed: - return tensor_spec.TensorSpec((), dtypes.variant) + return tensor.TensorSpec((), dtypes.variant) components = [] @@ -864,9 +863,9 @@ def _deserialize_for_reduce(value_type, serialization): def _replace_tensor_with_spec(value): - if isinstance(value, ops.Tensor): + if isinstance(value, tensor.Tensor): # Note: we intentionally exclude `value.name` from the `TensorSpec`. - return tensor_spec.TensorSpec(value.shape, value.dtype) + return tensor.TensorSpec(value.shape, value.dtype) if hasattr(value, '_type_spec'): return value._type_spec # pylint: disable=protected-access return value @@ -1265,7 +1264,7 @@ def _convert_anonymous_fields(value, for_spec=False): ) if ( - isinstance(value, (ops.Tensor, composite_tensor.CompositeTensor)) + isinstance(value, (tensor.Tensor, composite_tensor.CompositeTensor)) and not for_spec ): return value diff --git a/tensorflow/python/framework/extension_type_field.py b/tensorflow/python/framework/extension_type_field.py index 80774535f39421..afd84fb7d9d1e7 100644 --- a/tensorflow/python/framework/extension_type_field.py +++ b/tensorflow/python/framework/extension_type_field.py @@ -23,6 +23,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import immutable_dict from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import type_spec from tensorflow.python.util import type_annotations @@ -154,7 +155,7 @@ def validate_field_value_type(value_type, if value_type in (int, float, str, bytes, bool, None, _NoneType, dtypes.DType): return - elif (value_type in (ops.Tensor, tensor_shape.TensorShape) or + elif (value_type in (tensor.Tensor, tensor_shape.TensorShape) or (isinstance(value_type, type) and _issubclass(value_type, composite_tensor.CompositeTensor))): if in_mapping_key: @@ -287,7 +288,7 @@ def _convert_value(value, expected_type, path, if expected_type is None: expected_type = _NoneType - if expected_type is ops.Tensor: + if expected_type is tensor.Tensor: return _convert_tensor(value, path, context) elif (isinstance(expected_type, type) and _issubclass(expected_type, composite_tensor.CompositeTensor)): @@ -324,13 +325,13 @@ def _convert_tensor(value, path, context): """Converts `value` to a `Tensor`.""" if context == _ConversionContext.SPEC: if not (isinstance(value, type_spec.TypeSpec) and - value.value_type is ops.Tensor): + value.value_type is tensor.Tensor): raise TypeError( f'{"".join(path)}: expected a TensorSpec, got ' f'{type(value).__name__!r}') return value - if not isinstance(value, ops.Tensor): + if not isinstance(value, tensor.Tensor): if context == _ConversionContext.DEFAULT: # TODO(edloper): Convert the value to a numpy array? (Note: we can't just # use `np.array(value)`, since the default dtypes for TF and numpy are diff --git a/tensorflow/python/framework/extension_type_field_test.py b/tensorflow/python/framework/extension_type_field_test.py index a892ce9097df9f..f352c899d0e042 100644 --- a/tensorflow/python/framework/extension_type_field_test.py +++ b/tensorflow/python/framework/extension_type_field_test.py @@ -22,9 +22,8 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import extension_type_field -from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import test_util from tensorflow.python.ops.ragged import ragged_factory_ops from tensorflow.python.ops.ragged import ragged_tensor @@ -46,12 +45,12 @@ class ExtensionTypeFieldTest(test_util.TensorFlowTestCase, # Without default values: ('x', int), ('f', float), - ('t', ops.Tensor), + ('t', tensor.Tensor), # With default values: ('x', int, 33), ('y', float, 33.8), - ('t', ops.Tensor, [[1, 2], [3, 4]]), - ('t', ops.Tensor, lambda: constant_op.constant([[1, 2], [3, 4]])), + ('t', tensor.Tensor, [[1, 2], [3, 4]]), + ('t', tensor.Tensor, lambda: constant_op.constant([[1, 2], [3, 4]])), ('r', ragged_tensor.RaggedTensor, lambda: ragged_factory_ops.constant([[1, 2], [3]])), ('seq', typing.Tuple[typing.Union[int, float], ...], (33, 12.8, 9, 0)), @@ -75,7 +74,7 @@ def testConstruction( default = converted_default self.assertEqual(field.name, name) self.assertEqual(field.value_type, value_type) - if isinstance(default, (ops.Tensor, ragged_tensor.RaggedTensor)): + if isinstance(default, (tensor.Tensor, ragged_tensor.RaggedTensor)): self.assertAllEqual(field.default, default) else: self.assertEqual(field.default, default) @@ -91,13 +90,13 @@ def testConstruction( ('seq', _TUPLE[typing.Union[int, float], ...], [33, 12.8, 'zero'], (r'default value for seq\[2\]: expected ' r"typing.Union\[int, float\], got 'str'")), - ('t', tensor_spec.TensorSpec(None, dtypes.int32), + ('t', tensor.TensorSpec(None, dtypes.int32), lambda: constant_op.constant(0.0), 'Unsupported type annotation TensorSpec.*'), ('x', dict, {}, "In field 'x': Unsupported type annotation 'dict'"), ('y', typing.Union[int, list], 3, "In field 'y': Unsupported type annotation 'list'"), - ('z', typing.Mapping[ops.Tensor, int], {}, + ('z', typing.Mapping[tensor.Tensor, int], {}, "In field 'z': Mapping had a key 'Tensor' with type 'type'"), ]) def testConstructionError(self, name, value_type, default, error): @@ -150,7 +149,7 @@ class ValidateFieldPyTypeTest(test_util.TensorFlowTestCase, dict(tp=type(None)), dict(tp=dtypes.DType), dict(tp=tensor_shape.TensorShape), - dict(tp=ops.Tensor), + dict(tp=tensor.Tensor), dict(tp='A', allow_forward_references=True), # Generic types dict(tp=typing.Union[int, float]), @@ -185,7 +184,7 @@ def testValidPytype(self, tp, allow_forward_references=False): error="Unsupported type annotation 'dict'"), dict(tp='A', error='Unresolved forward reference .*'), dict(tp=typing.Union[int, 'A'], error='Unresolved forward reference .*'), - dict(tp=typing.Mapping[ops.Tensor, int], + dict(tp=typing.Mapping[tensor.Tensor, int], error="Mapping had a key 'Tensor' with type 'type'"), dict( tp=typing.Mapping[tensor_shape.TensorShape, int], @@ -223,8 +222,8 @@ def testConvertFieldsMismatch(self, field_values, error): ('foo', str), (None, None), (True, bool), - ([1, 2, 3], ops.Tensor), - (lambda: constant_op.constant([1, 2, 3]), ops.Tensor), + ([1, 2, 3], tensor.Tensor), + (lambda: constant_op.constant([1, 2, 3]), tensor.Tensor), (lambda: ragged_factory_ops.constant([[1, 2], [3]]), ragged_tensor.RaggedTensor), ([1, 2, 3], typing.Tuple[int, ...], (1, 2, 3)), @@ -252,7 +251,7 @@ def testConvertValue(self, value, value_type, expected=None): if expected is None: expected = value converted = extension_type_field._convert_value(value, value_type, ('x',)) - if isinstance(converted, (ops.Tensor, ragged_tensor.RaggedTensor)): + if isinstance(converted, (tensor.Tensor, ragged_tensor.RaggedTensor)): self.assertAllEqual(converted, expected) else: self.assertEqual(converted, expected) @@ -263,7 +262,7 @@ def testConvertValue(self, value, value_type, expected=None): ('foo', str), (None, None), (True, bool), - (tensor_spec.TensorSpec([5]), ops.Tensor), + (tensor.TensorSpec([5]), tensor.Tensor), (ragged_tensor.RaggedTensorSpec([5, None]), ragged_tensor.RaggedTensor), ([1, 2, 3], typing.Tuple[int, ...], (1, 2, 3)), ((1, 2, 3), typing.Tuple[int, int, int], (1, 2, 3)), @@ -292,7 +291,7 @@ def testConvertValueForSpec(self, value, value_type, expected=None): converted = extension_type_field._convert_value( value, value_type, ('x',), extension_type_field._ConversionContext.SPEC) - if isinstance(converted, (ops.Tensor, ragged_tensor.RaggedTensor)): + if isinstance(converted, (tensor.Tensor, ragged_tensor.RaggedTensor)): self.assertAllEqual(converted, expected) else: self.assertEqual(converted, expected) @@ -321,14 +320,14 @@ def testConvertFields(self): 'y', typing.Tuple[typing.Union[int, bool], ...]), extension_type_field.ExtensionTypeField( 'y', _TUPLE[typing.Union[int, bool], ...]), - extension_type_field.ExtensionTypeField('z', ops.Tensor) + extension_type_field.ExtensionTypeField('z', tensor.Tensor) ] field_values = {'x': 1, 'y': [1, True, 3], 'z': [[1, 2], [3, 4], [5, 6]]} extension_type_field.convert_fields(fields, field_values) self.assertEqual(set(field_values), set(['x', 'y', 'z'])) self.assertEqual(field_values['x'], 1) self.assertEqual(field_values['y'], (1, True, 3)) - self.assertIsInstance(field_values['z'], ops.Tensor) + self.assertIsInstance(field_values['z'], tensor.Tensor) self.assertAllEqual(field_values['z'], [[1, 2], [3, 4], [5, 6]]) def testConvertFieldsForSpec(self): @@ -338,18 +337,18 @@ def testConvertFieldsForSpec(self): 'y', typing.Tuple[typing.Union[int, bool], ...]), extension_type_field.ExtensionTypeField( 'y', _TUPLE[typing.Union[int, bool], ...]), - extension_type_field.ExtensionTypeField('z', ops.Tensor) + extension_type_field.ExtensionTypeField('z', tensor.Tensor) ] field_values = { 'x': 1, 'y': [1, True, 3], - 'z': tensor_spec.TensorSpec([5, 3]) + 'z': tensor.TensorSpec([5, 3]) } extension_type_field.convert_fields_for_spec(fields, field_values) self.assertEqual(set(field_values), set(['x', 'y', 'z'])) self.assertEqual(field_values['x'], 1) self.assertEqual(field_values['y'], (1, True, 3)) - self.assertEqual(field_values['z'], tensor_spec.TensorSpec([5, 3])) + self.assertEqual(field_values['z'], tensor.TensorSpec([5, 3])) if __name__ == '__main__': diff --git a/tensorflow/python/framework/extension_type_test.py b/tensorflow/python/framework/extension_type_test.py index 0e6bccb396c7af..0169690eaf3c33 100644 --- a/tensorflow/python/framework/extension_type_test.py +++ b/tensorflow/python/framework/extension_type_test.py @@ -34,8 +34,8 @@ from tensorflow.python.framework import extension_type_field from tensorflow.python.framework import immutable_dict from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import test_util from tensorflow.python.framework import type_spec from tensorflow.python.framework import type_spec_registry @@ -64,8 +64,8 @@ class MaskedTensorV1(extension_type.ExtensionType): """Example subclass of ExtensionType, used for testing.""" - values: ops.Tensor - mask: ops.Tensor + values: tensor.Tensor + mask: tensor.Tensor class MaskedTensorV2(extension_type.ExtensionType): @@ -78,8 +78,8 @@ class MaskedTensorV2(extension_type.ExtensionType): __name__ = 'tf.test.MaskedTensorV2' - values: ops.Tensor - mask: ops.Tensor + values: tensor.Tensor + mask: tensor.Tensor def __repr__(self): if hasattr(self.values, 'numpy') and hasattr(self.mask, 'numpy'): @@ -117,7 +117,7 @@ def with_default(self, default): class SimpleExtensionType(extension_type.ExtensionType): - x: ops.Tensor + x: tensor.Tensor class Spec: @@ -145,8 +145,8 @@ class MaskedTensorV3(extension_type.BatchableExtensionType): __name__ = 'tf.test.MaskedTensorV3.Spec' - values: typing.Union[ops.Tensor, ragged_tensor.RaggedTensor] - mask: typing.Union[ops.Tensor, ragged_tensor.RaggedTensor] + values: typing.Union[tensor.Tensor, ragged_tensor.RaggedTensor] + mask: typing.Union[tensor.Tensor, ragged_tensor.RaggedTensor] def __init__(self, values, mask): if isinstance(values, ragged_tensor.RaggedTensor): @@ -182,12 +182,12 @@ class ForwardRefA(extension_type.ExtensionType): class ForwardRefB(extension_type.ExtensionType): z: 'ForwardRefB' - n: ops.Tensor + n: tensor.Tensor class ExtensionTypeWithTensorDefault(extension_type.ExtensionType): - x: ops.Tensor = 5 - y: ops.Tensor = ['a', 'b', 'c'] + x: tensor.Tensor = 5 + y: tensor.Tensor = ['a', 'b', 'c'] @test_util.run_all_in_graph_and_eager_modes @@ -198,9 +198,9 @@ def testAttributeAccessors(self): mt2 = extension_type.pack(mt1) for mt in [mt1, mt2]: - self.assertIsInstance(mt.values, ops.Tensor) + self.assertIsInstance(mt.values, tensor.Tensor) self.assertAllEqual(mt.values, [1, 2, 3, 4]) - self.assertIsInstance(mt.mask, ops.Tensor) + self.assertIsInstance(mt.mask, tensor.Tensor) self.assertAllEqual(mt.mask, [True, True, False, True]) def testAttributesAreImmutable(self): @@ -260,14 +260,16 @@ def testAsDict(self): def testConstructorSignature(self): class MyType(extension_type.ExtensionType): - x: ops.Tensor - y: ops.Tensor + x: tensor.Tensor + y: tensor.Tensor z: typing.Tuple[typing.Union[int, str], ...] = [1, 'two', 3] expected_parameters = [ tf_inspect.Parameter('self', POSITIONAL_OR_KEYWORD), - tf_inspect.Parameter('x', POSITIONAL_OR_KEYWORD, annotation=ops.Tensor), - tf_inspect.Parameter('y', POSITIONAL_OR_KEYWORD, annotation=ops.Tensor), + tf_inspect.Parameter( + 'x', POSITIONAL_OR_KEYWORD, annotation=tensor.Tensor), + tf_inspect.Parameter( + 'y', POSITIONAL_OR_KEYWORD, annotation=tensor.Tensor), tf_inspect.Parameter( 'z', POSITIONAL_OR_KEYWORD, @@ -284,7 +286,7 @@ def testConstructorSignatureWithKeywordOnlyArgs(self): class MyType(extension_type.ExtensionType): a: int b: str = 'Hello world' - c: ops.Tensor + c: tensor.Tensor expected_parameters = [ tf_inspect.Parameter('self', POSITIONAL_OR_KEYWORD), @@ -292,7 +294,7 @@ class MyType(extension_type.ExtensionType): tf_inspect.Parameter( 'b', POSITIONAL_OR_KEYWORD, annotation=str, default='Hello world' ), - tf_inspect.Parameter('c', KEYWORD_ONLY, annotation=ops.Tensor), + tf_inspect.Parameter('c', KEYWORD_ONLY, annotation=tensor.Tensor), ] expected_sig = tf_inspect.Signature( expected_parameters, return_annotation=MyType @@ -314,13 +316,14 @@ def testConstructorSignatureWithDefaultForTensorField(self): def testConstructorSignatureWithAnnotatedTensorField(self): class MyType(extension_type.ExtensionType): - a: typing_extensions.Annotated[ops.Tensor, 'metadata'] + a: typing_extensions.Annotated[tensor.Tensor, 'metadata'] b: typing_extensions.Annotated[str, 'metadata'] = 'Hello world' c: typing.Optional[typing_extensions.Annotated[int, 'metadata']] = None expected_parameters = [ tf_inspect.Parameter('self', POSITIONAL_OR_KEYWORD), - tf_inspect.Parameter('a', POSITIONAL_OR_KEYWORD, annotation=ops.Tensor), + tf_inspect.Parameter( + 'a', POSITIONAL_OR_KEYWORD, annotation=tensor.Tensor), tf_inspect.Parameter( 'b', POSITIONAL_OR_KEYWORD, annotation=str, default='Hello world' ), @@ -348,9 +351,9 @@ class EmptyType(extension_type.ExtensionType): def testCustomConstrutor(self): class SummarizedTensor(extension_type.ExtensionType): - values: ops.Tensor - mean: ops.Tensor - max: ops.Tensor + values: tensor.Tensor + mean: tensor.Tensor + max: tensor.Tensor def __init__(self, values): self.values = ops.convert_to_tensor(values) @@ -363,7 +366,7 @@ def __init__(self, values): self.assertAllEqual(x.max, 6) class Node(extension_type.ExtensionType): - x: ops.Tensor + x: tensor.Tensor y: typing.Optional[str] = None children: typing.Tuple['ExtensionTypeTest.Node', ...] = () @@ -402,8 +405,8 @@ def __init__(self, foo): def testCustomValidate(self): class AlignedTensors(extension_type.ExtensionType): - x: ops.Tensor - y: ops.Tensor + x: tensor.Tensor + y: tensor.Tensor def __validate__(self): self.x.shape.assert_is_compatible_with(self.y.shape) @@ -417,8 +420,8 @@ def __validate__(self): def testEquals(self): class MyType(extension_type.ExtensionType): - values: ops.Tensor - score: ops.Tensor + values: tensor.Tensor + score: tensor.Tensor flavor: str x1 = MyType([1, 2], 8, 'blue') @@ -509,8 +512,8 @@ def fn_with_side_effect(mts): def testNestPackUnpack(self): class CandyStore(extension_type.ExtensionType): - name: ops.Tensor - prices: typing.Mapping[str, ops.Tensor] + name: tensor.Tensor + prices: typing.Mapping[str, tensor.Tensor] store = CandyStore('Yum', {'gum': [0.42, 0.48], 'chocolate': [0.83, 1.02]}) components = nest.flatten(store, expand_composites=True) @@ -702,13 +705,14 @@ def body(i, x): self.assertAllEqual(y.mask, [True, False, True, False]) def testNestedFields(self): - PossiblyRaggedTensor = typing.Union[ops.Tensor, ragged_tensor.RaggedTensor] + PossiblyRaggedTensor = typing.Union[ + tensor.Tensor, ragged_tensor.RaggedTensor] ToyFeatures = typing.Mapping[str, PossiblyRaggedTensor] class ToyInfo(extension_type.ExtensionType): version: str - toys: typing.Tuple[typing.Tuple[str, ops.Tensor, ToyFeatures], ...] - boxes: typing.Mapping[str, ops.Tensor] + toys: typing.Tuple[typing.Tuple[str, tensor.Tensor, ToyFeatures], ...] + boxes: typing.Mapping[str, tensor.Tensor] authors = [[b'A', b'Aardvark'], [b'Z', b'Zhook']] toys = [ @@ -720,10 +724,10 @@ class ToyInfo(extension_type.ExtensionType): self.assertEqual(toy_info.version, '1.0 alpha') self.assertEqual(toy_info.toys[0][0], 'car') - self.assertIsInstance(toy_info.toys[0][1], ops.Tensor) + self.assertIsInstance(toy_info.toys[0][1], tensor.Tensor) self.assertAllEqual(toy_info.toys[0][1], 1.0) self.assertEqual(set(toy_info.toys[0][2].keys()), {'size', 'color'}) - self.assertIsInstance(toy_info.toys[0][2]['size'], ops.Tensor) + self.assertIsInstance(toy_info.toys[0][2]['size'], tensor.Tensor) self.assertAllEqual(toy_info.toys[0][2]['size'], [8, 3, 2]) self.assertIsInstance( toy_info.toys[1][2]['authors'], ragged_tensor.RaggedTensor @@ -745,15 +749,15 @@ class ToyInfo(extension_type.ExtensionType): self.assertRegex(repr(toy_info), expected_repr) def testNestedExtensionTypes(self): - PossiblyMaskedTensor = typing.Union[ops.Tensor, MaskedTensorV1] + PossiblyMaskedTensor = typing.Union[tensor.Tensor, MaskedTensorV1] class Toy(extension_type.ExtensionType): name: str - price: ops.Tensor + price: tensor.Tensor features: typing.Mapping[str, PossiblyMaskedTensor] class Box(extension_type.ExtensionType): - contents: ops.Tensor + contents: tensor.Tensor class ToyInfo(extension_type.ExtensionType): version: str @@ -784,7 +788,7 @@ def fn(info): def testNestedCustomConstructor(self): class Toy(extension_type.ExtensionType): name: str - price: ops.Tensor + price: tensor.Tensor def __init__(self, name, price, discount=0): if discount: @@ -834,10 +838,10 @@ def testGetExtensionTypeFields(self): for fields in [fields_1, fields_2]: self.assertLen(fields, 2) self.assertEqual(fields[0].name, 'values') - self.assertEqual(fields[0].value_type, ops.Tensor) + self.assertEqual(fields[0].value_type, tensor.Tensor) self.assertEqual(fields[0].default, fields[0].NO_DEFAULT) self.assertEqual(fields[1].name, 'mask') - self.assertEqual(fields[1].value_type, ops.Tensor) + self.assertEqual(fields[1].value_type, tensor.Tensor) self.assertEqual(fields[1].default, fields[0].NO_DEFAULT) def testHasExtensionTypeField(self): @@ -866,7 +870,7 @@ def testForwardReferences(self): B._tf_extension_type_fields(), ( extension_type_field.ExtensionTypeField('z', B), - extension_type_field.ExtensionTypeField('n', ops.Tensor), + extension_type_field.ExtensionTypeField('n', tensor.Tensor), ), ) @@ -905,7 +909,7 @@ def testUnsupportedAnnotations(self): ): class MyType1(extension_type.ExtensionType): # pylint: disable=unused-variable - values: typing.List[ops.Tensor] + values: typing.List[tensor.Tensor] with self.assertRaisesRegex( TypeError, "In field 'xyz': Unsupported type annotation" @@ -955,8 +959,8 @@ def testExtensionTypeBaseConstructorRaisesException(self): class ExtensionTypeWithName(extension_type.ExtensionType): __name__ = 'tf.__test__.ExtensionTypeWithName' # For SavedModel - x: typing.Tuple[ops.Tensor, int] - y: ops.Tensor + x: typing.Tuple[tensor.Tensor, int] + y: tensor.Tensor def testSavedModelSupport(self): class TestModule(module.Module): @@ -985,16 +989,16 @@ def testPackedEncoding(self): mt2 = extension_type.pack(mt1) self.assertLen(nest.flatten(mt2, expand_composites=True), 1) - self.assertIsInstance(mt2.values, ops.Tensor) + self.assertIsInstance(mt2.values, tensor.Tensor) self.assertAllEqual(mt2.values, [1, 2, 3, 4]) - self.assertIsInstance(mt2.mask, ops.Tensor) + self.assertIsInstance(mt2.mask, tensor.Tensor) self.assertAllEqual(mt2.mask, [True, True, False, True]) mt3 = extension_type.unpack(mt2) self.assertLen(nest.flatten(mt3, expand_composites=True), 2) - self.assertIsInstance(mt3.values, ops.Tensor) + self.assertIsInstance(mt3.values, tensor.Tensor) self.assertAllEqual(mt3.values, [1, 2, 3, 4]) - self.assertIsInstance(mt3.mask, ops.Tensor) + self.assertIsInstance(mt3.mask, tensor.Tensor) self.assertAllEqual(mt3.mask, [True, True, False, True]) nest.assert_same_structure(mt1, mt3, expand_composites=True) @@ -1010,8 +1014,8 @@ def testPackedEncoding(self): def testSubclassing(self): class Instrument(extension_type.ExtensionType): - name: ops.Tensor - weight: ops.Tensor + name: tensor.Tensor + weight: tensor.Tensor needs_case: bool class StringInstrument(Instrument): @@ -1019,7 +1023,7 @@ class StringInstrument(Instrument): needs_case: bool = True # Override default value. class Violin(StringInstrument): - maker: ops.Tensor + maker: tensor.Tensor num_strings: int = 4 # Override default value. name: str = 'violin' # Override field type and default value. @@ -1030,10 +1034,10 @@ class Violin(StringInstrument): [ tf_inspect.Parameter('self', POSITIONAL_OR_KEYWORD), tf_inspect.Parameter( - 'name', POSITIONAL_OR_KEYWORD, annotation=ops.Tensor + 'name', POSITIONAL_OR_KEYWORD, annotation=tensor.Tensor ), tf_inspect.Parameter( - 'weight', POSITIONAL_OR_KEYWORD, annotation=ops.Tensor + 'weight', POSITIONAL_OR_KEYWORD, annotation=tensor.Tensor ), tf_inspect.Parameter( 'needs_case', @@ -1051,14 +1055,16 @@ class Violin(StringInstrument): tf_inspect.Parameter( 'name', POSITIONAL_OR_KEYWORD, annotation=str, default='violin' ), - tf_inspect.Parameter('weight', KEYWORD_ONLY, annotation=ops.Tensor), + tf_inspect.Parameter( + 'weight', KEYWORD_ONLY, annotation=tensor.Tensor), tf_inspect.Parameter( 'needs_case', KEYWORD_ONLY, annotation=bool, default=True ), tf_inspect.Parameter( 'num_strings', KEYWORD_ONLY, annotation=int, default=4 ), - tf_inspect.Parameter('maker', KEYWORD_ONLY, annotation=ops.Tensor), + tf_inspect.Parameter( + 'maker', KEYWORD_ONLY, annotation=tensor.Tensor), ], ) @@ -1131,8 +1137,8 @@ class ExtensionTypeSpecTest( ): def testSpecConstructor(self): - values_spec = tensor_spec.TensorSpec([4], dtypes.float32) - mask_spec = tensor_spec.TensorSpec([4], dtypes.bool) + values_spec = tensor.TensorSpec([4], dtypes.float32) + mask_spec = tensor.TensorSpec([4], dtypes.bool) mt_spec = MaskedTensorV1.Spec(values_spec, mask_spec) self.assertEqual(mt_spec.values, values_spec) self.assertEqual(mt_spec.mask, mask_spec) @@ -1142,8 +1148,8 @@ def testSpecConstructor(self): def testSpecConstructorSignature(self): class MyType(extension_type.ExtensionType): - x: ops.Tensor - y: ops.Tensor + x: tensor.Tensor + y: tensor.Tensor z: typing.Tuple[typing.Union[int, str], ...] = [1, 'two', 3] expected_parameters = [ @@ -1183,19 +1189,19 @@ def testSpecFromValue(self): mt = MaskedTensorV1([1.0, 2.0, 3.0, 4.0], [True, True, False, True]) mt_spec = MaskedTensorV1.Spec.from_value(mt) - expected_values_spec = tensor_spec.TensorSpec([4], dtypes.float32) - expected_mask_spec = tensor_spec.TensorSpec([4], dtypes.bool) + expected_values_spec = tensor.TensorSpec([4], dtypes.float32) + expected_mask_spec = tensor.TensorSpec([4], dtypes.bool) self.assertEqual(mt_spec.values, expected_values_spec) self.assertEqual(mt_spec.mask, expected_mask_spec) def testSpecSerialize(self): class Zoo(extension_type.ExtensionType): zookeepers: typing.Tuple[str, ...] - animals: typing.Mapping[str, typing.Mapping[str, ops.Tensor]] + animals: typing.Mapping[str, typing.Mapping[str, tensor.Tensor]] featurespec = { - 'size': tensor_spec.TensorSpec([3]), - 'weight': tensor_spec.TensorSpec([]), + 'size': tensor.TensorSpec([3]), + 'weight': tensor.TensorSpec([]), } zoo_spec = Zoo.Spec( zookeepers=['Zoey', 'Zack'], @@ -1222,7 +1228,7 @@ class Zoo(extension_type.ExtensionType): def testSpecComponents(self): class Zoo(extension_type.ExtensionType): zookeepers: typing.Tuple[str, ...] - animals: typing.Mapping[str, typing.Mapping[str, ops.Tensor]] + animals: typing.Mapping[str, typing.Mapping[str, tensor.Tensor]] zoo = Zoo( ['Zoey', 'Zack'], @@ -1247,17 +1253,17 @@ class Zoo(extension_type.ExtensionType): self.assertEqual( zoo_spec._component_specs, ( - tensor_spec.TensorSpec([3], dtypes.int32), - tensor_spec.TensorSpec([], dtypes.float32), - tensor_spec.TensorSpec([], dtypes.float32), - tensor_spec.TensorSpec([3], dtypes.int32), - tensor_spec.TensorSpec([], dtypes.float32), + tensor.TensorSpec([3], dtypes.int32), + tensor.TensorSpec([], dtypes.float32), + tensor.TensorSpec([], dtypes.float32), + tensor.TensorSpec([3], dtypes.int32), + tensor.TensorSpec([], dtypes.float32), ), ) def testCopyAndPickle(self): - values_spec = tensor_spec.TensorSpec([4], dtypes.float32) - mask_spec = tensor_spec.TensorSpec([4], dtypes.bool) + values_spec = tensor.TensorSpec([4], dtypes.float32) + mask_spec = tensor.TensorSpec([4], dtypes.bool) mt_spec = MaskedTensorV1.Spec(values_spec, mask_spec) self.assertEqual(copy.copy(mt_spec), mt_spec) self.assertEqual(copy.deepcopy(mt_spec), mt_spec) @@ -1273,8 +1279,8 @@ class WeightedTensor(extension_type.ExtensionType): * Add method (with_shape). """ - values: ops.Tensor - weight: ops.Tensor # scalar + values: tensor.Tensor + weight: tensor.Tensor # scalar shape = property(lambda self: self.shape) dtype = property(lambda self: self.dtype) @@ -1286,8 +1292,8 @@ def __validate__(self): class Spec: def __init__(self, shape, dtype, weight_dtype=dtypes.float32): - self.values = tensor_spec.TensorSpec(shape, dtype) - self.weight = tensor_spec.TensorSpec([], weight_dtype) + self.values = tensor.TensorSpec(shape, dtype) + self.weight = tensor.TensorSpec([], weight_dtype) def __validate__(self): self.weight.shape.assert_has_rank(0) @@ -1376,7 +1382,7 @@ def testAttributeAccessors(self, fields): s = extension_type.AnonymousExtensionType(**fields) for name, value in fields.items(): actual = getattr(s, name) - if isinstance(actual, (ops.Tensor, ragged_tensor.RaggedTensor)): + if isinstance(actual, (tensor.Tensor, ragged_tensor.RaggedTensor)): self.assertAllEqual(actual, value) else: self.assertEqual(actual, value) @@ -1434,7 +1440,7 @@ def testReinterpret(self): lambda: extension_type.AnonymousExtensionType( values=constant_op.constant([1, 2, 3]) ), - ops.Tensor, + tensor.Tensor, ( 'reinterpret expects `new_type` to be a subclass of ' 'tf.ExtensionType; ' @@ -1464,8 +1470,8 @@ def f(x, y): y_mask = y.mask if isinstance(y, MaskedTensorV1) else True return MaskedTensorV1(x_values + y_values, x_mask & y_mask) - t_spec = tensor_spec.TensorSpec(None, dtypes.int32) - b_spec = tensor_spec.TensorSpec(None, dtypes.bool) + t_spec = tensor.TensorSpec(None, dtypes.int32) + b_spec = tensor.TensorSpec(None, dtypes.bool) mt_spec = MaskedTensorV1.Spec(values=t_spec, mask=b_spec) model = module.Module() model.f = def_function.function(f) @@ -1515,8 +1521,8 @@ def testFlatTensorSpecs(self): self.assertEqual( flat_specs, [ - tensor_spec.TensorSpec(shape=(2,), dtype=dtypes.int32, name=None), - tensor_spec.TensorSpec(shape=(2,), dtype=dtypes.bool, name=None), + tensor.TensorSpec(shape=(2,), dtype=dtypes.int32, name=None), + tensor.TensorSpec(shape=(2,), dtype=dtypes.bool, name=None), ], ) @@ -1546,7 +1552,7 @@ def testToLegacyOutputShapeMissing(self): def replace_tensors_with_placeholders(value): def repl(x): - if isinstance(x, ops.Tensor): + if isinstance(x, tensor.Tensor): return array_ops.placeholder_with_default(x, shape=None) else: return x diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py index b09c176e80eb78..632a0022ea2a12 100644 --- a/tensorflow/python/framework/function.py +++ b/tensorflow/python/framework/function.py @@ -29,6 +29,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import graph_to_function_def from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.ops import array_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope as vs @@ -706,7 +707,7 @@ def __call__(self, *args, **kwargs): args = list(args) for (i, x) in enumerate(args): x = ops.convert_to_tensor(x) - if not isinstance(x, ops.Tensor): + if not isinstance(x, tensor_lib.Tensor): raise ValueError(f"Expected a Tensor but got {x} with type {type(x)}.") input_types.append(x.dtype) args[i] = x diff --git a/tensorflow/python/framework/importer.py b/tensorflow/python/framework/importer.py index acef1db0607759..359e9f4f99af08 100644 --- a/tensorflow/python/framework/importer.py +++ b/tensorflow/python/framework/importer.py @@ -24,6 +24,7 @@ from tensorflow.python.framework import function from tensorflow.python.framework import op_def_registry from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.ops import control_flow_util from tensorflow.python.util import compat from tensorflow.python.util.deprecation import deprecated_args @@ -194,7 +195,7 @@ def _ConvertInputMapValues(name, input_map): Raises: ValueError: if input map values cannot be converted due to empty name scope. """ - if not all(isinstance(v, ops.Tensor) for v in input_map.values()): + if not all(isinstance(v, tensor.Tensor) for v in input_map.values()): if name == '': # pylint: disable=g-explicit-bool-comparison raise ValueError( 'tf.import_graph_def() requires a non-empty `name` if `input_map` ' diff --git a/tensorflow/python/framework/op_def_library.py b/tensorflow/python/framework/op_def_library.py index b2ad0cf659dedb..1ac4f9b73460da 100644 --- a/tensorflow/python/framework/op_def_library.py +++ b/tensorflow/python/framework/op_def_library.py @@ -26,6 +26,7 @@ from tensorflow.python.framework import op_def_library_pybind from tensorflow.python.framework import op_def_registry from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import _pywrap_utils @@ -368,7 +369,7 @@ def _CanExtractAttrsFastPath(op_def, keywords): # Check if all inputs are already tf.Tensor for input_arg in op_def.input_arg: value = keywords.get(input_arg.name, None) - if not isinstance(value, ops.Tensor): + if not isinstance(value, tensor.Tensor): return False # Check that attrs are not `func` or `list(func)` type. @@ -452,7 +453,7 @@ def _ExtractInputsAndAttrs(op_type_name, op_def, allowed_list_attr_map, dtype = attrs[input_arg.type_attr] else: for t in values: - if isinstance(t, ops.Tensor): + if isinstance(t, tensor.Tensor): dtype = t.dtype break diff --git a/tensorflow/python/framework/op_def_library_test.py b/tensorflow/python/framework/op_def_library_test.py index 2f835ae47f9cf9..f77c85643524c7 100644 --- a/tensorflow/python/framework/op_def_library_test.py +++ b/tensorflow/python/framework/op_def_library_test.py @@ -25,8 +25,8 @@ from tensorflow.python.framework import op_def_library from tensorflow.python.framework import op_def_library_pybind from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import test_util from tensorflow.python.platform import googletest from tensorflow.python.util import compat @@ -417,7 +417,7 @@ def fn(x): def testAttrFuncWithFuncWithAttrs(self): with ops.Graph().as_default(): @def_function.function( - input_signature=(tensor_spec.TensorSpec(None, dtypes.float32),), + input_signature=(tensor.TensorSpec(None, dtypes.float32),), autograph=False, experimental_attributes={"_implements": 15}) def fn(x): @@ -1334,7 +1334,7 @@ def testStructuredOutputListAndSingle(self): self.assertIsInstance(a, list) self.assertEqual(n_a, len(a)) self.assertTrue(all(x.dtype == dtypes.int32 for x in a)) - self.assertIsInstance(b, ops.Tensor) + self.assertIsInstance(b, tensor.Tensor) self.assertEqual(dtypes.float32, b.dtype) def testStructuredOutputMultipleLists(self): diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py index aec2433043ddf6..70b61f699d768e 100644 --- a/tensorflow/python/framework/ops_test.py +++ b/tensorflow/python/framework/ops_test.py @@ -44,9 +44,9 @@ from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_conversion_registry from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_util from tensorflow.python.framework import test_ops from tensorflow.python.framework import test_util @@ -589,27 +589,27 @@ def testSerialize(self, spec, expected): @parameterized.parameters([ (indexed_slices.IndexedSlicesSpec(dtype=dtypes.string), ( - tensor_spec.TensorSpec(None, dtypes.string), - tensor_spec.TensorSpec([None], dtypes.int64), + tensor_lib.TensorSpec(None, dtypes.string), + tensor_lib.TensorSpec([None], dtypes.int64), )), (indexed_slices.IndexedSlicesSpec( dtype=dtypes.string, dense_shape_dtype=dtypes.int32), ( - tensor_spec.TensorSpec(None, dtypes.string), - tensor_spec.TensorSpec([None], dtypes.int64), - tensor_spec.TensorSpec([None], dtypes.int32), + tensor_lib.TensorSpec(None, dtypes.string), + tensor_lib.TensorSpec([None], dtypes.int64), + tensor_lib.TensorSpec([None], dtypes.int32), )), (indexed_slices.IndexedSlicesSpec( shape=[5, 10, 15], dense_shape_dtype=dtypes.int32), ( - tensor_spec.TensorSpec([None, 10, 15], dtypes.float32), - tensor_spec.TensorSpec([None], dtypes.int64), - tensor_spec.TensorSpec([3], dtypes.int32), + tensor_lib.TensorSpec([None, 10, 15], dtypes.float32), + tensor_lib.TensorSpec([None], dtypes.int64), + tensor_lib.TensorSpec([3], dtypes.int32), )), (indexed_slices.IndexedSlicesSpec( shape=[5, 10, 15], dense_shape_dtype=dtypes.int32, indices_shape=[20]), ( - tensor_spec.TensorSpec([20, 10, 15], dtypes.float32), - tensor_spec.TensorSpec([20], dtypes.int64), - tensor_spec.TensorSpec([3], dtypes.int32), + tensor_lib.TensorSpec([20, 10, 15], dtypes.float32), + tensor_lib.TensorSpec([20], dtypes.int64), + tensor_lib.TensorSpec([3], dtypes.int32), )), ]) def testComponentSpecs(self, spec, expected): @@ -1447,10 +1447,10 @@ def testNodeDefArgs(self): g, "Foo1", [t1, t2[1], t2[0]], [dtypes.float32, dtypes.int32], name="myop3") - self.assertTrue(isinstance(t1, ops.Tensor)) + self.assertTrue(isinstance(t1, tensor_lib.Tensor)) self.assertTrue(isinstance(t2, list)) self.assertTrue(isinstance(t3, list)) - self.assertTrue(isinstance(t3[0], ops.Tensor)) + self.assertTrue(isinstance(t3[0], tensor_lib.Tensor)) self.assertEqual("myop1", t1._as_node_def_input()) self.assertEqual("myop2", t2[0]._as_node_def_input()) self.assertEqual("myop2:1", t2[1]._as_node_def_input()) @@ -2333,8 +2333,8 @@ def testMembershipAllowed(self): g = ops.Graph() t1 = _apply_op(g, "FloatOutput", [], [dtypes.float32], name="myop1") t2 = _apply_op(g, "FloatOutput", [], [dtypes.float32], name="myop2") - self.assertTrue(isinstance(t1, ops.Tensor)) - self.assertTrue(isinstance(t2, ops.Tensor)) + self.assertTrue(isinstance(t1, tensor_lib.Tensor)) + self.assertTrue(isinstance(t2, tensor_lib.Tensor)) self.assertTrue(t1 in [t1]) self.assertTrue(t1 not in [t2]) @@ -3623,7 +3623,7 @@ def testCompositeTensorConversion(self): self.assertIsInstance(y, _TupleTensor) self.assertLen(y, len(x)) for x_, y_ in zip(x, y): - self.assertIsInstance(y_, ops.Tensor) + self.assertIsInstance(y_, tensor_lib.Tensor) self.assertTrue(tensor_util.is_tf_type(y_)) self.assertAllEqual(x_, tensor_util.constant_value(y_)) @@ -3681,7 +3681,7 @@ def setUpInputShapes(self, pre_add_input_shapes): test_tensor_shape = [None, 1, 1, 1] @def_function.function(input_signature=[ - tensor_spec.TensorSpec(shape=test_tensor_shape, dtype=dtypes.float32) + tensor_lib.TensorSpec(shape=test_tensor_shape, dtype=dtypes.float32) ]) def f(x): return array_ops.identity(x, name="output") diff --git a/tensorflow/python/framework/python_api_dispatcher_test.py b/tensorflow/python/framework/python_api_dispatcher_test.py index f179cbc9eceb31..a4ddb620f09aeb 100644 --- a/tensorflow/python/framework/python_api_dispatcher_test.py +++ b/tensorflow/python/framework/python_api_dispatcher_test.py @@ -19,7 +19,7 @@ from tensorflow.python.framework import _pywrap_python_api_dispatcher as dispatch from tensorflow.python.framework import constant_op -from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import test_util from tensorflow.python.ops.ragged import ragged_factory_ops from tensorflow.python.ops.ragged import ragged_tensor @@ -47,7 +47,7 @@ def testInstanceChecker(self): self.assertEqual(repr(int_checker), '') with self.subTest('tensor checker'): - tensor_checker = dispatch.MakeInstanceChecker(ops.Tensor) + tensor_checker = dispatch.MakeInstanceChecker(tensor.Tensor) self.assertEqual(tensor_checker.Check(t), MATCH) self.assertEqual(tensor_checker.Check(3), NO_MATCH) self.assertEqual(tensor_checker.Check(3.0), NO_MATCH) @@ -119,7 +119,7 @@ def testUnionChecker(self): float_checker = dispatch.MakeInstanceChecker(float) str_checker = dispatch.MakeInstanceChecker(str) none_checker = dispatch.MakeInstanceChecker(type(None)) - tensor_checker = dispatch.MakeInstanceChecker(ops.Tensor) + tensor_checker = dispatch.MakeInstanceChecker(tensor.Tensor) ragged_checker = dispatch.MakeInstanceChecker(ragged_tensor.RaggedTensor) t = constant_op.constant([1, 2, 3]) @@ -159,7 +159,7 @@ def testUnionChecker(self): def testListChecker(self): int_checker = dispatch.MakeInstanceChecker(int) - tensor_checker = dispatch.MakeInstanceChecker(ops.Tensor) + tensor_checker = dispatch.MakeInstanceChecker(tensor.Tensor) ragged_checker = dispatch.MakeInstanceChecker(ragged_tensor.RaggedTensor) np_int_checker = dispatch.MakeInstanceChecker(np.integer) @@ -269,7 +269,7 @@ def testSimpleSignature(self): def testUnion(self): rt_checker = dispatch.MakeInstanceChecker(ragged_tensor.RaggedTensor) - tensor_checker = dispatch.MakeInstanceChecker(ops.Tensor) + tensor_checker = dispatch.MakeInstanceChecker(tensor.Tensor) rt_or_tensor = dispatch.MakeUnionChecker([rt_checker, tensor_checker]) checker = dispatch.PySignatureChecker([(0, rt_or_tensor), (1, rt_or_tensor)]) @@ -383,7 +383,7 @@ def testListAndUnionDispatch(self): (None,)) rt_checker = dispatch.MakeInstanceChecker(ragged_tensor.RaggedTensor) - tensor_checker = dispatch.MakeInstanceChecker(ops.Tensor) + tensor_checker = dispatch.MakeInstanceChecker(tensor.Tensor) rt_or_t = dispatch.MakeUnionChecker([rt_checker, tensor_checker]) list_of_rt_or_t = dispatch.MakeListChecker(rt_or_t) f1 = lambda x, ys, name=None: 'f1' diff --git a/tensorflow/python/framework/python_api_parameter_converter_test.py b/tensorflow/python/framework/python_api_parameter_converter_test.py index 9787b6c0c53478..e6a4c705195aa4 100644 --- a/tensorflow/python/framework/python_api_parameter_converter_test.py +++ b/tensorflow/python/framework/python_api_parameter_converter_test.py @@ -23,7 +23,7 @@ from tensorflow.python.framework import _pywrap_python_api_info from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util from tensorflow.python.framework._pywrap_python_api_parameter_converter import Convert @@ -82,7 +82,7 @@ def assertParamsEqual(self, actual_params, expected_params): self.assertParamEqual(actual, expected) def assertParamEqual(self, actual, expected): - if isinstance(actual, ops.Tensor): + if isinstance(actual, tensor.Tensor): self.assertAllEqual(actual, expected) else: self.assertEqual(actual, expected) diff --git a/tensorflow/python/framework/python_tensor_converter_test.py b/tensorflow/python/framework/python_tensor_converter_test.py index 413770b973b7d5..3257b8e7de91f4 100644 --- a/tensorflow/python/framework/python_tensor_converter_test.py +++ b/tensorflow/python/framework/python_tensor_converter_test.py @@ -24,7 +24,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import indexed_slices -from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import test_util from tensorflow.python.platform import googletest @@ -47,7 +47,7 @@ def makePythonTensorConverter(self): def testConvertIntWithInferredDType(self): converter = self.makePythonTensorConverter() result, dtype, used_fallback = converter.Convert(12, types_pb2.DT_INVALID) - self.assertIsInstance(result, ops.Tensor) + self.assertIsInstance(result, tensor.Tensor) self.assertAllEqual(result, 12) self.assertEqual(dtype, types_pb2.DT_INT32) self.assertEqual(used_fallback, not context.executing_eagerly()) @@ -55,7 +55,7 @@ def testConvertIntWithInferredDType(self): def testConvertIntWithExplicitDtype(self): converter = self.makePythonTensorConverter() result, dtype, used_fallback = converter.Convert(12, types_pb2.DT_INT64) - self.assertIsInstance(result, ops.Tensor) + self.assertIsInstance(result, tensor.Tensor) self.assertAllEqual(result, 12) self.assertEqual(dtype, types_pb2.DT_INT64) self.assertEqual(used_fallback, not context.executing_eagerly()) @@ -74,7 +74,7 @@ def testConvertTensorWithInferredDType(self): converter = self.makePythonTensorConverter() result, dtype, used_fallback = converter.Convert( constant_op.constant([1, 2, 3]), types_pb2.DT_INVALID) - self.assertIsInstance(result, ops.Tensor) + self.assertIsInstance(result, tensor.Tensor) self.assertAllEqual(result, [1, 2, 3]) self.assertEqual(dtype, types_pb2.DT_INT32) self.assertFalse(used_fallback) @@ -83,7 +83,7 @@ def testConvertTensorWithExplicitDtype(self): converter = self.makePythonTensorConverter() result, dtype, used_fallback = converter.Convert( constant_op.constant([1, 2, 3], dtypes.int64), types_pb2.DT_INT64) - self.assertIsInstance(result, ops.Tensor) + self.assertIsInstance(result, tensor.Tensor) self.assertAllEqual(result, [1, 2, 3]) self.assertEqual(dtype, types_pb2.DT_INT64) self.assertFalse(used_fallback) @@ -101,7 +101,7 @@ def testConvertListWithInferredDType(self): converter = self.makePythonTensorConverter() result, dtype, used_fallback = converter.Convert([[1, 2, 3], [4, 5, 6]], types_pb2.DT_INVALID) - self.assertIsInstance(result, ops.Tensor) + self.assertIsInstance(result, tensor.Tensor) self.assertAllEqual(result, [[1, 2, 3], [4, 5, 6]]) self.assertEqual(dtype, types_pb2.DT_INT32) self.assertEqual(used_fallback, not context.executing_eagerly()) @@ -110,7 +110,7 @@ def testConvertListWithExplicitDtype(self): converter = self.makePythonTensorConverter() result, dtype, used_fallback = converter.Convert([[1, 2, 3], [4, 5, 6]], types_pb2.DT_INT64) - self.assertIsInstance(result, ops.Tensor) + self.assertIsInstance(result, tensor.Tensor) self.assertAllEqual(result, [[1, 2, 3], [4, 5, 6]]) self.assertEqual(dtype, types_pb2.DT_INT64) self.assertEqual(used_fallback, not context.executing_eagerly()) @@ -137,7 +137,7 @@ def testConvertNumpyArrayWithInferredDType(self): converter = self.makePythonTensorConverter() x = np.array([[1, 2, 3], [4, 5, 6]], np.int32) result, dtype, used_fallback = converter.Convert(x, types_pb2.DT_INVALID) - self.assertIsInstance(result, ops.Tensor) + self.assertIsInstance(result, tensor.Tensor) self.assertAllEqual(result, [[1, 2, 3], [4, 5, 6]]) self.assertEqual(dtype, types_pb2.DT_INT32) self.assertEqual(used_fallback, not context.executing_eagerly()) @@ -146,7 +146,7 @@ def testConvertNumpyArrayWithExplicitDtype(self): converter = self.makePythonTensorConverter() x = np.array([[1, 2, 3], [4, 5, 6]], np.int32) result, dtype, used_fallback = converter.Convert(x, types_pb2.DT_INT64) - self.assertIsInstance(result, ops.Tensor) + self.assertIsInstance(result, tensor.Tensor) self.assertAllEqual(result, [[1, 2, 3], [4, 5, 6]]) self.assertEqual(dtype, types_pb2.DT_INT64) self.assertEqual(used_fallback, not context.executing_eagerly()) @@ -173,7 +173,7 @@ def testConvertIndexedSlicesWithInferredDType(self): constant_op.constant([1], dtypes.int64, name="x_indices"), constant_op.constant([3, 3], dtypes.int64, name="x_shape")) result, dtype, used_fallback = converter.Convert(x, types_pb2.DT_INVALID) - self.assertIsInstance(result, ops.Tensor) + self.assertIsInstance(result, tensor.Tensor) self.assertAllEqual(result, [[0, 0, 0], [1, 2, 3], [0, 0, 0]]) self.assertEqual(dtype, types_pb2.DT_INT32) self.assertTrue(used_fallback) @@ -185,7 +185,7 @@ def testConvertIndexedSlicesWithExplicitDtype(self): constant_op.constant([1], dtypes.int64, name="x_indices"), constant_op.constant([3, 3], dtypes.int64, name="x_shape")) result, dtype, used_fallback = converter.Convert(x, types_pb2.DT_INT32) - self.assertIsInstance(result, ops.Tensor) + self.assertIsInstance(result, tensor.Tensor) self.assertAllEqual(result, [[0, 0, 0], [1, 2, 3], [0, 0, 0]]) self.assertEqual(dtype, types_pb2.DT_INT32) self.assertTrue(used_fallback) diff --git a/tensorflow/python/framework/smart_cond.py b/tensorflow/python/framework/smart_cond.py index 67708b3aece98a..efaee2c1549111 100644 --- a/tensorflow/python/framework/smart_cond.py +++ b/tensorflow/python/framework/smart_cond.py @@ -14,7 +14,7 @@ # ============================================================================== """smart_cond and related utilities.""" -from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_util from tensorflow.python.ops import cond from tensorflow.python.ops import control_flow_case @@ -70,7 +70,7 @@ def smart_constant_value(pred): Raises: TypeError: If `pred` is not a Tensor or bool. """ - if isinstance(pred, ops.Tensor): + if isinstance(pred, tensor.Tensor): pred_value = tensor_util.constant_value(pred) # TODO(skyewm): consider folding this into tensor_util.constant_value. # pylint: disable=protected-access diff --git a/tensorflow/python/framework/sparse_tensor_test.py b/tensorflow/python/framework/sparse_tensor_test.py index deded692a80f58..fbd3fdc881aa29 100644 --- a/tensorflow/python/framework/sparse_tensor_test.py +++ b/tensorflow/python/framework/sparse_tensor_test.py @@ -24,8 +24,8 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import test_util from tensorflow.python.framework.type_utils import fulltypes_for_flat_tensors from tensorflow.python.ops import array_ops @@ -158,7 +158,7 @@ def test_simple(self): sp = sparse_tensor.SparseTensor(indices, values, dense_shape) self.assertIsInstance(sp.shape, tensor_shape.TensorShape) - self.assertIsInstance(sp.dense_shape, ops.Tensor) + self.assertIsInstance(sp.dense_shape, tensor_lib.Tensor) self.assertEqual(sp.shape.as_list(), [5, 5]) def test_unknown_shape(self): @@ -172,7 +172,7 @@ def my_func(dense_shape): return sp my_func.get_concrete_function( - dense_shape=tensor_spec.TensorSpec( + dense_shape=tensor_lib.TensorSpec( dtype=dtypes.int64, shape=[2,])) def test_partial_shape(self): @@ -188,7 +188,7 @@ def my_func(x): return sp my_func.get_concrete_function( - x=tensor_spec.TensorSpec(dtype=dtypes.int64, shape=[])) + x=tensor_lib.TensorSpec(dtype=dtypes.int64, shape=[])) def test_neg_shape(self): indices = [[0, 2]] @@ -211,7 +211,7 @@ def my_func(x): return sp my_func.get_concrete_function( - x=tensor_spec.TensorSpec(dtype=dtypes.int64, shape=[None, None])) + x=tensor_lib.TensorSpec(dtype=dtypes.int64, shape=[None, None])) def test_unknown_rank(self): @@ -224,7 +224,7 @@ def my_func(dense_shape): return sp my_func.get_concrete_function( - dense_shape=tensor_spec.TensorSpec(dtype=dtypes.int64, shape=[None])) + dense_shape=tensor_lib.TensorSpec(dtype=dtypes.int64, shape=[None])) @test_util.run_all_in_graph_and_eager_modes @@ -266,14 +266,14 @@ def testSerialize(self, st_spec, expected): @parameterized.parameters([ (sparse_tensor.SparseTensorSpec(dtype=dtypes.string), [ - tensor_spec.TensorSpec([None, None], dtypes.int64), - tensor_spec.TensorSpec([None], dtypes.string), - tensor_spec.TensorSpec([None], dtypes.int64) + tensor_lib.TensorSpec([None, None], dtypes.int64), + tensor_lib.TensorSpec([None], dtypes.string), + tensor_lib.TensorSpec([None], dtypes.int64) ]), (sparse_tensor.SparseTensorSpec(shape=[5, None, None]), [ - tensor_spec.TensorSpec([None, 3], dtypes.int64), - tensor_spec.TensorSpec([None], dtypes.float32), - tensor_spec.TensorSpec([3], dtypes.int64) + tensor_lib.TensorSpec([None, 3], dtypes.int64), + tensor_lib.TensorSpec([None], dtypes.float32), + tensor_lib.TensorSpec([3], dtypes.int64) ]), ]) def testComponentSpecs(self, st_spec, expected): @@ -331,7 +331,7 @@ def testFromNumpyComponents(self): ]) def testFlatTensorSpecs(self, st_spec): self.assertEqual(st_spec._flat_tensor_specs, - [tensor_spec.TensorSpec(None, dtypes.variant)]) + [tensor_lib.TensorSpec(None, dtypes.variant)]) @parameterized.parameters([ dtypes.float32, diff --git a/tensorflow/python/framework/subscribe.py b/tensorflow/python/framework/subscribe.py index e68412b4982df3..3e48388542930b 100644 --- a/tensorflow/python/framework/subscribe.py +++ b/tensorflow/python/framework/subscribe.py @@ -18,6 +18,7 @@ import re from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.ops import array_ops from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging @@ -42,7 +43,7 @@ def _recursive_apply(tensors, apply_fn): `TypeError` if undefined type in the tensors structure. """ tensors_type = type(tensors) - if isinstance(tensors, ops.Tensor): + if isinstance(tensors, tensor_lib.Tensor): return apply_fn(tensors) elif isinstance(tensors, variables.Variable): return apply_fn(tensors.value()) @@ -171,7 +172,9 @@ def _subscribe_extend(tensor, side_effects): for s in side_effects: outs += s(source_tensor) - out_ops = [out.op if isinstance(out, ops.Tensor) else out for out in outs] + out_ops = [ + out.op if isinstance(out, tensor_lib.Tensor) else out for out in outs + ] tensor.op._add_control_inputs(out_ops) # pylint: disable=protected-access return tensor diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py index f4dc3b4e43a2e3..836c116506f8d2 100644 --- a/tensorflow/python/framework/tensor_util.py +++ b/tensorflow/python/framework/tensor_util.py @@ -933,7 +933,7 @@ def bar(self, partial): or None if it cannot be calculated. Raises: - TypeError: if tensor is not an ops.Tensor. + TypeError: if tensor is not an tensor.Tensor. """ if isinstance(tensor, core.Value): try: diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 1727db9016c890..e9096a925eaf25 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -60,6 +60,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.framework import tfrt_utils @@ -376,7 +377,7 @@ def NHWCToNCHW(input_tensor): """ # tensor dim -> new axis order new_axes = {3: [0, 2, 1], 4: [0, 3, 1, 2], 5: [0, 4, 1, 2, 3]} - if isinstance(input_tensor, ops.Tensor): + if isinstance(input_tensor, tensor_lib.Tensor): ndims = input_tensor.shape.ndims return array_ops.transpose(input_tensor, new_axes[ndims]) else: @@ -401,7 +402,7 @@ def NHWCToNCHW_VECT_C(input_shape_or_tensor): divisible by 4. """ permutations = {5: [0, 3, 1, 2, 4], 6: [0, 4, 1, 2, 3, 5]} - is_tensor = isinstance(input_shape_or_tensor, ops.Tensor) + is_tensor = isinstance(input_shape_or_tensor, tensor_lib.Tensor) temp_shape = ( input_shape_or_tensor.shape.as_list() if is_tensor else input_shape_or_tensor) @@ -435,7 +436,7 @@ def NCHW_VECT_CToNHWC(input_shape_or_tensor): ValueError: if last dimension of `input_shape_or_tensor` is not 4. """ permutations = {5: [0, 2, 3, 1, 4], 6: [0, 2, 3, 4, 1, 5]} - is_tensor = isinstance(input_shape_or_tensor, ops.Tensor) + is_tensor = isinstance(input_shape_or_tensor, tensor_lib.Tensor) input_shape = ( input_shape_or_tensor.shape.as_list() if is_tensor else input_shape_or_tensor) @@ -462,7 +463,7 @@ def NCHWToNHWC(input_tensor): """ # tensor dim -> new axis order new_axes = {4: [0, 2, 3, 1], 5: [0, 2, 3, 4, 1]} - if isinstance(input_tensor, ops.Tensor): + if isinstance(input_tensor, tensor_lib.Tensor): ndims = input_tensor.shape.ndims return array_ops.transpose(input_tensor, new_axes[ndims]) else: @@ -806,7 +807,7 @@ def decorator(self, **kwargs): def _is_tensorflow_object(obj): try: return isinstance(obj, - (ops.Tensor, variables.Variable, + (tensor_lib.Tensor, variables.Variable, tensor_shape.Dimension, tensor_shape.TensorShape)) except (ReferenceError, AttributeError): # If the object no longer exists, we don't care about it. @@ -1545,7 +1546,7 @@ def decorated(*args, **kwds): tensor_args = [] tensor_indices = [] for i, arg in enumerate(args): - if isinstance(arg, (ops.Tensor, variables.Variable)): + if isinstance(arg, (tensor_lib.Tensor, variables.Variable)): tensor_args.append(arg) tensor_indices.append(i) @@ -3583,18 +3584,18 @@ def assertShapeEqual(self, input_a, input_b, msg=None): Raises: TypeError: If the arguments have the wrong type. """ - if not isinstance(input_a, (np.ndarray, np.generic, ops.Tensor)): + if not isinstance(input_a, (np.ndarray, np.generic, tensor_lib.Tensor)): raise TypeError( "input_a must be a Numpy ndarray, Numpy scalar, or a Tensor." f"Instead received {type(input_a)}") - if not isinstance(input_b, (np.ndarray, np.generic, ops.Tensor)): + if not isinstance(input_b, (np.ndarray, np.generic, tensor_lib.Tensor)): raise TypeError( "input_b must be a Numpy ndarray, Numpy scalar, or a Tensor." f"Instead received {type(input_b)}") shape_a = input_a.get_shape().as_list() if isinstance( - input_a, ops.Tensor) else input_a.shape + input_a, tensor_lib.Tensor) else input_a.shape shape_b = input_b.get_shape().as_list() if isinstance( - input_b, ops.Tensor) else input_b.shape + input_b, tensor_lib.Tensor) else input_b.shape self.assertAllEqual(shape_a, shape_b, msg=msg) def assertDeviceEqual(self, device1, device2, msg=None): @@ -3641,7 +3642,7 @@ def _GetPyList(self, a): """Converts `a` to a nested python list.""" if isinstance(a, ragged_tensor.RaggedTensor): return self.evaluate(a).to_list() - elif isinstance(a, ops.Tensor): + elif isinstance(a, tensor_lib.Tensor): a = self.evaluate(a) return a.tolist() if isinstance(a, np.ndarray) else a elif isinstance(a, np.ndarray): diff --git a/tensorflow/python/framework/test_util_test.py b/tensorflow/python/framework/test_util_test.py index b6d28f4a24d7de..5ffd054c19e9c2 100644 --- a/tensorflow/python/framework/test_util_test.py +++ b/tensorflow/python/framework/test_util_test.py @@ -41,6 +41,7 @@ from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed +from tensorflow.python.framework import tensor from tensorflow.python.framework import test_ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops @@ -1191,9 +1192,9 @@ def add_two(x): if run_eagerly: self.assertTrue(isinstance(t, ops.EagerTensor) for t in results) else: - self.assertTrue(isinstance(t, ops.Tensor) for t in results) + self.assertTrue(isinstance(t, tensor.Tensor) for t in results) else: - self.assertTrue(isinstance(t, ops.Tensor) for t in results) + self.assertTrue(isinstance(t, tensor.Tensor) for t in results) class SyncDevicesTest(test_util.TensorFlowTestCase): diff --git a/tensorflow/python/framework/weak_tensor.py b/tensorflow/python/framework/weak_tensor.py index a36d57c58f01f3..98ecce25026539 100644 --- a/tensorflow/python/framework/weak_tensor.py +++ b/tensorflow/python/framework/weak_tensor.py @@ -24,7 +24,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import extension_type -from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib _ALLOWED_WEAK_DTYPES = ( dtypes.int32, @@ -62,7 +62,7 @@ class WeakTensor(extension_type.ExtensionType): # __name__ is required for serialization in SavedModel. __name__ = "tf.WeakTensor" - tensor: ops.Tensor + tensor: tensor_lib.Tensor def __validate__(self): if self.tensor.dtype not in _ALLOWED_WEAK_DTYPES: diff --git a/tensorflow/python/framework/weak_tensor_test.py b/tensorflow/python/framework/weak_tensor_test.py index 9e8229e0541899..8fec36d82e3fe4 100644 --- a/tensorflow/python/framework/weak_tensor_test.py +++ b/tensorflow/python/framework/weak_tensor_test.py @@ -22,7 +22,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_spec +from tensorflow.python.framework import tensor from tensorflow.python.framework import test_util from tensorflow.python.framework import weak_tensor from tensorflow.python.module import module @@ -157,7 +157,7 @@ def f(c, a, b): def test_weak_tensor_in_tf_func_with_spec(self): # Test weak tensor spec with matching input. - weak_tensor_spec = weak_tensor.WeakTensor.Spec(tensor_spec.TensorSpec([2])) + weak_tensor_spec = weak_tensor.WeakTensor.Spec(tensor.TensorSpec([2])) wt = weak_tensor.WeakTensor(constant_op.constant([1.0, 2.0])) @def_function.function(input_signature=[weak_tensor_spec]) @@ -185,8 +185,8 @@ class CustomModule(module.Module): @def_function.function def __call__(self, x): - if isinstance(x, ops.Tensor): - raise TypeError('Weak tensor should not be ops.Tensor type.') + if isinstance(x, tensor.Tensor): + raise TypeError('Weak tensor should not be tensor.Tensor type.') return x m = CustomModule() From 68b6874ac715d880d6bf01f006f5c5bbad66e172 Mon Sep 17 00:00:00 2001 From: Fiona Lang Date: Tue, 11 Jul 2023 10:31:11 -0700 Subject: [PATCH 135/376] Update ops.Tensor references to //third_party/tensorflow/python/framework/tensor.py. PiperOrigin-RevId: 547230345 --- tensorflow/python/ops/BUILD | 54 +++++++++--- tensorflow/python/ops/array_grad.py | 7 +- tensorflow/python/ops/array_ops.py | 32 ++++--- tensorflow/python/ops/batch_ops.py | 6 +- tensorflow/python/ops/bincount_ops.py | 3 +- tensorflow/python/ops/check_ops.py | 9 +- tensorflow/python/ops/composite_tensor_ops.py | 3 +- tensorflow/python/ops/cond.py | 5 +- tensorflow/python/ops/cond_v2.py | 3 +- tensorflow/python/ops/control_flow_case.py | 3 +- tensorflow/python/ops/control_flow_grad.py | 5 +- tensorflow/python/ops/control_flow_ops.py | 34 +++---- .../python/ops/control_flow_switch_case.py | 3 +- tensorflow/python/ops/functional_ops.py | 3 +- tensorflow/python/ops/gradient_checker.py | 3 +- tensorflow/python/ops/gradient_checker_v2.py | 3 +- tensorflow/python/ops/gradients_test.py | 16 ++-- tensorflow/python/ops/gradients_util.py | 88 ++++++++++++------- tensorflow/python/ops/image_ops_impl.py | 3 +- tensorflow/python/ops/io_ops.py | 5 +- tensorflow/python/ops/linalg_ops.py | 3 +- tensorflow/python/ops/linalg_ops_impl.py | 12 ++- tensorflow/python/ops/list_ops.py | 5 +- tensorflow/python/ops/lookup_ops.py | 5 +- tensorflow/python/ops/math_grad.py | 15 ++-- tensorflow/python/ops/math_ops.py | 62 +++++++------ tensorflow/python/ops/math_ops_test.py | 7 +- tensorflow/python/ops/nn_ops.py | 10 ++- tensorflow/python/ops/op_selector.py | 15 ++-- tensorflow/python/ops/parsing_config.py | 3 +- tensorflow/python/ops/random_ops_util.py | 4 +- .../python/ops/resource_variable_ops.py | 16 ++-- tensorflow/python/ops/rnn.py | 3 +- tensorflow/python/ops/rnn_cell_impl.py | 6 +- tensorflow/python/ops/session_ops.py | 3 +- tensorflow/python/ops/sparse_ops.py | 16 ++-- tensorflow/python/ops/special_math_ops.py | 3 +- tensorflow/python/ops/stateful_random_ops.py | 3 +- tensorflow/python/ops/summary_ops_v2.py | 15 ++-- tensorflow/python/ops/tensor_array_ops.py | 31 ++++--- tensorflow/python/ops/variable_scope.py | 5 +- tensorflow/python/ops/variables.py | 30 +++++-- tensorflow/python/ops/weak_tensor_ops.py | 6 +- tensorflow/python/ops/weak_tensor_ops_test.py | 18 ++-- tensorflow/python/ops/while_v2.py | 3 +- .../ops/while_v2_indexed_slices_rewriter.py | 4 +- 46 files changed, 364 insertions(+), 227 deletions(-) diff --git a/tensorflow/python/ops/BUILD b/tensorflow/python/ops/BUILD index 1da5270622505d..3c6556c5506b96 100644 --- a/tensorflow/python/ops/BUILD +++ b/tensorflow/python/ops/BUILD @@ -95,6 +95,7 @@ py_strict_library( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:function", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/util:deprecation", "//tensorflow/python/util:dispatch", @@ -384,7 +385,7 @@ py_strict_library( ":batch_ops_gen", "//tensorflow/python/eager:def_function", "//tensorflow/python/framework:ops", - "//tensorflow/python/framework:tensor_spec", + "//tensorflow/python/framework:tensor", "//tensorflow/python/util:nest", "//tensorflow/python/util:tf_export", ], @@ -747,6 +748,7 @@ py_strict_library( "//tensorflow/python/framework:indexed_slices", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:tensor_util", ], @@ -789,6 +791,7 @@ py_strict_library( "//tensorflow/python/framework:indexed_slices", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_conversion_registry", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:tensor_util", @@ -902,6 +905,7 @@ py_strict_library( "//tensorflow/python/framework:errors", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/util:compat", @@ -1130,6 +1134,7 @@ py_strict_library( "//tensorflow/python/framework:indexed_slices", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", ], ) @@ -1151,8 +1156,8 @@ py_strict_library( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:indexed_slices", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", - "//tensorflow/python/framework:tensor_spec", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/framework:type_spec", "//tensorflow/python/util:compat", @@ -1176,6 +1181,7 @@ py_strict_library( "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/platform:tf_logging", "//tensorflow/python/util:dispatch", "//tensorflow/python/util:tf_export", @@ -1193,6 +1199,7 @@ py_strict_library( ":math_ops", "//tensorflow/python/eager:context", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/util:tf_export", ], ) @@ -1211,6 +1218,7 @@ py_strict_library( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:indexed_slices", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/platform:tf_logging", "//tensorflow/python/types:core", @@ -1372,6 +1380,7 @@ py_strict_library( "//tensorflow/python/framework:func_graph", "//tensorflow/python/framework:indexed_slices", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/framework:type_spec", @@ -1408,6 +1417,7 @@ py_strict_library( "//tensorflow/python/framework:func_graph", "//tensorflow/python/framework:indexed_slices", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:tensor_spec", "//tensorflow/python/framework:tensor_util", @@ -1429,6 +1439,7 @@ py_strict_library( "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_conversion", "//tensorflow/python/util:deprecation", "//tensorflow/python/util:dispatch", @@ -1651,6 +1662,7 @@ py_strict_library( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:indexed_slices", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/platform:tf_logging", "//tensorflow/python/util:compat", @@ -1824,6 +1836,7 @@ py_strict_library( "//tensorflow/python/eager:context", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/lib/io:lib", "//tensorflow/python/util:deprecation", "//tensorflow/python/util:tf_export", @@ -1861,6 +1874,7 @@ py_strict_library( ":math_ops", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/util:deprecation", "//tensorflow/python/util:dispatch", "//tensorflow/python/util:tf_export", @@ -1877,6 +1891,7 @@ py_strict_library( ":math_ops", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/util:compat", "//third_party/py/numpy", ], @@ -1941,6 +1956,7 @@ py_strict_library( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/ops/ragged:ragged_tensor", @@ -1970,6 +1986,7 @@ py_strict_library( "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_util", "//third_party/py/numpy", ], @@ -1981,6 +1998,7 @@ py_strict_library( srcs_version = "PY3", deps = [ "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/util:object_identity", ], ) @@ -2004,6 +2022,7 @@ py_strict_library( "//tensorflow/python/framework:indexed_slices", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_conversion_registry", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:tensor_util", @@ -2068,7 +2087,6 @@ py_strict_library( "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_conversion_registry", "//tensorflow/python/framework:tensor_shape", - "//tensorflow/python/framework:tensor_spec", "//tensorflow/python/saved_model:nested_structure_coder", "//tensorflow/python/trackable:base", "//tensorflow/python/types:core", @@ -2110,6 +2128,7 @@ py_strict_library( "//tensorflow/python/framework:cpp_shape_inference_proto_py", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:tensor_util", "//third_party/py/numpy", @@ -2183,6 +2202,7 @@ py_strict_library( "//tensorflow/python/framework:graph_util", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:random_seed", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/platform:device_context", @@ -2222,6 +2242,7 @@ py_strict_library( "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/ops/ragged:ragged_math_ops", "//tensorflow/python/ops/ragged:ragged_tensor", @@ -2340,6 +2361,7 @@ py_strict_library( "//tensorflow/python/framework:config", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/trackable:autotrackable", "//tensorflow/python/util:nest", "//tensorflow/python/util:tf_export", @@ -2404,6 +2426,7 @@ py_strict_library( "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/util:deprecation", @@ -2485,6 +2508,7 @@ py_strict_library( "//tensorflow/python/framework:device", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/util:compat", "//tensorflow/python/util:tf_export", "//third_party/py/numpy", @@ -2514,10 +2538,9 @@ py_strict_library( ":bitwise_ops", ":math_ops", ":stateless_random_ops_v2_gen", - "//tensorflow/python/eager:context", "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", - "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/util:tf_export", ], ) @@ -2556,6 +2579,7 @@ py_strict_library( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_conversion", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:tensor_util", @@ -2704,6 +2728,7 @@ py_strict_library( ":special_math_ops_gen", "//tensorflow/compiler/tf2xla/ops:gen_xla_ops", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/platform:tf_logging", "//tensorflow/python/util:deprecation", @@ -2905,6 +2930,7 @@ py_strict_library( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:smart_cond", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/platform:tf_logging", "//tensorflow/python/trackable:resource", @@ -2963,8 +2989,8 @@ py_strict_library( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:errors", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", - "//tensorflow/python/framework:tensor_spec", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/framework:type_spec", "//tensorflow/python/framework:type_spec_registry", @@ -2987,6 +3013,7 @@ py_strict_library( "//tensorflow/python/framework:composite_tensor", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/saved_model:nested_structure_coder", "//tensorflow/python/util:nest", ], @@ -3007,6 +3034,7 @@ py_strict_library( "//tensorflow/python/eager:monitoring", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_conversion_registry", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/platform:tf_logging", @@ -3035,6 +3063,7 @@ py_strict_library( "//tensorflow/python/eager:context", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_conversion_registry", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/trackable:base", @@ -3101,6 +3130,7 @@ py_strict_library( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:indexed_slices", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/platform:tf_logging", "//tensorflow/python/util:deprecation", "//tensorflow/python/util:tf_export", @@ -3119,6 +3149,7 @@ py_strict_library( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:indexed_slices", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/platform:tf_logging", "//tensorflow/python/util:tf_export", "//third_party/py/numpy", @@ -3308,7 +3339,7 @@ cuda_py_strict_test( "//tensorflow/python/framework:function", "//tensorflow/python/framework:indexed_slices", "//tensorflow/python/framework:ops", - "//tensorflow/python/framework:tensor_spec", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:test_lib", "//tensorflow/python/framework:test_ops", "//tensorflow/python/ops/ragged:ragged_factory_ops", @@ -3523,6 +3554,7 @@ cuda_py_strict_test( "//tensorflow/python/framework:errors", "//tensorflow/python/framework:indexed_slices", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops/ragged:ragged_factory_ops", "//tensorflow/python/platform:test", @@ -4279,7 +4311,7 @@ py_strict_library( "//tensorflow/python/eager:context", "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", - "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/keras/layers/legacy_rnn:rnn_cell_impl", @@ -4315,6 +4347,7 @@ py_strict_library( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:random_seed", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/util:deprecation", @@ -4360,7 +4393,7 @@ py_strict_library( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:func_graph", "//tensorflow/python/framework:indexed_slices", - "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_conversion", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/util:nest", @@ -4414,7 +4447,7 @@ py_strict_library( srcs = ["weak_tensor_ops.py"], deps = [ ":weak_tensor_ops_list", - "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:weak_tensor", "//tensorflow/python/util:dispatch", ], @@ -4435,6 +4468,7 @@ py_strict_test( "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:test_lib", "//tensorflow/python/framework:weak_tensor", "//tensorflow/python/ops/numpy_ops:np_array_ops", diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py index 4ae6802cc042df..ccbcfe11efa3de 100644 --- a/tensorflow/python/ops/array_grad.py +++ b/tensorflow/python/ops/array_grad.py @@ -22,6 +22,7 @@ from tensorflow.python.framework import indexed_slices as indexed_slices_lib from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops @@ -91,7 +92,7 @@ def _ExtractInputShapes(inputs): for x in inputs: input_shape = array_ops.shape(x) if not isinstance(input_shape, - ops.Tensor) or input_shape.op.type != "Const": + tensor.Tensor) or input_shape.op.type != "Const": fully_known = False break sizes.append(input_shape) @@ -109,7 +110,7 @@ def _ExtractInputShapes(inputs): input_values = op.inputs[start_value_index:end_value_index] out_grads = [] - if isinstance(grad, ops.Tensor): + if isinstance(grad, tensor.Tensor): if context.executing_eagerly() or isinstance(concat_dim, ops.EagerTensor): # Using mod here for convenience since concat_dim is already verified # in concat implementation to be within the allowed [-rank, rank) range. @@ -1206,7 +1207,7 @@ def _BroadcastToGrad(op, grad): input_value = op.inputs[0] broadcast_shape = op.inputs[1] shape_dtype = dtypes.int32 - if isinstance(broadcast_shape, ops.Tensor): + if isinstance(broadcast_shape, tensor.Tensor): shape_dtype = broadcast_shape.dtype input_value_shape = array_ops.shape(input_value, out_type=shape_dtype) diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 2d49e40b5e06fe..753f1b03d6789a 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -29,6 +29,7 @@ from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_conversion_registry from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util @@ -1030,7 +1031,7 @@ def _slice_helper(tensor, slice_spec, var=None): appear in TensorFlow's generated documentation. Args: - tensor: An ops.Tensor object. + tensor: An tensor.Tensor object. slice_spec: The arguments to Tensor.__getitem__. var: In the case of variable slice assignment, the Variable object to slice (i.e. tensor is the read-only view of this variable). @@ -1048,9 +1049,11 @@ def _slice_helper(tensor, slice_spec, var=None): if var is None and ops._numpy_style_slicing: # pylint: disable=protected-access return tensor._numpy_style_getitem(slice_spec) # pylint: disable=protected-access - if isinstance(slice_spec, bool) or \ - (isinstance(slice_spec, ops.Tensor) and slice_spec.dtype == dtypes.bool) or \ - (isinstance(slice_spec, np.ndarray) and slice_spec.dtype == bool): + if (isinstance(slice_spec, bool) + or (isinstance(slice_spec, tensor_lib.Tensor) + and slice_spec.dtype == dtypes.bool) + or (isinstance(slice_spec, np.ndarray) + and slice_spec.dtype == bool)): return boolean_mask(tensor=tensor, mask=slice_spec) if not isinstance(slice_spec, (list, tuple)): @@ -1067,7 +1070,7 @@ def _slice_helper(tensor, slice_spec, var=None): # Finds the best dtype for begin, end, and strides. dtype = None for t in [s.start, s.stop, s.step]: - if t is None or not isinstance(t, ops.Tensor): + if t is None or not isinstance(t, tensor_lib.Tensor): continue if t.dtype == dtypes.int64: dtype = dtypes.int64 @@ -1117,8 +1120,9 @@ def _slice_helper(tensor, slice_spec, var=None): begin.append(s) end.append(s + 1) # TODO(mdan): Investigate why we can't set int32 here. - if isinstance(s, ops.Tensor) and (s.dtype == dtypes.int16 or - s.dtype == dtypes.int64): + if ( + isinstance(s, tensor_lib.Tensor) + and (s.dtype == dtypes.int16 or s.dtype == dtypes.int64)): strides.append(constant_op.constant(1, dtype=s.dtype)) else: strides.append(1) @@ -1413,7 +1417,7 @@ def _SliceHelperVar(var, slice_spec): return _slice_helper(var.value(), slice_spec, var) -ops.Tensor._override_operator("__getitem__", _slice_helper) +tensor_lib.Tensor._override_operator("__getitem__", _slice_helper) @tf_export("parallel_stack") @@ -2887,7 +2891,7 @@ def zeros(shape, dtype=dtypes.float32, name=None, layout=None): else: zero = 0 - if not isinstance(shape, ops.Tensor): + if not isinstance(shape, tensor_lib.Tensor): try: if not context.executing_eagerly(): # Create a constant if it won't be very big. Otherwise, create a fill @@ -3202,7 +3206,7 @@ def ones(shape, dtype=dtypes.float32, name=None, layout=None): one = np.ones([]).astype(dtype.as_numpy_dtype) else: one = 1 - if not isinstance(shape, ops.Tensor): + if not isinstance(shape, tensor_lib.Tensor): try: if not context.executing_eagerly(): # Create a constant if it won't be very big. Otherwise, create a fill @@ -3403,7 +3407,7 @@ def sparse_placeholder(dtype, shape=None, name=None): dense_shape = placeholder(dtypes.int64, shape=[rank], name=shape_name) dense_shape_default = tensor_util.constant_value_as_shape(dense_shape) else: - if isinstance(shape, ops.Tensor): + if isinstance(shape, tensor_lib.Tensor): rank = shape.get_shape()[0] dense_shape_default = tensor_util.constant_value_as_shape(shape) else: @@ -3590,7 +3594,7 @@ def pad(tensor, paddings, mode="CONSTANT", name=None, constant_values=0): # pyl paddings_constant = _get_paddings_constant(paddings) input_shape = ( tensor_shape.TensorShape(tensor.shape) - if isinstance(tensor, ops.Tensor) else result.op.inputs[0].shape) + if isinstance(tensor, tensor_lib.Tensor) else result.op.inputs[0].shape) if (input_shape.ndims is not None and not result.shape.is_fully_defined() and paddings_constant is not None): new_shape = [] @@ -3618,7 +3622,7 @@ def _get_paddings_constant(paddings): A nested list or numbers or `None`, in which `None` indicates unknown padding size. """ - if isinstance(paddings, ops.Tensor): + if isinstance(paddings, tensor_lib.Tensor): return tensor_util.constant_value(paddings, partial=True) elif isinstance(paddings, (list, tuple)): return [_get_paddings_constant(x) for x in paddings] @@ -4402,7 +4406,7 @@ def one_hot(indices, def _all_dimensions(x): """Returns a 1D-tensor listing all dimensions in x.""" # Fast path: avoid creating Rank and Range ops if ndims is known. - if isinstance(x, ops.Tensor) and x.get_shape().ndims is not None: + if isinstance(x, tensor_lib.Tensor) and x.get_shape().ndims is not None: return constant_op.constant( np.arange(x.get_shape().ndims), dtype=dtypes.int32) if (isinstance(x, sparse_tensor.SparseTensor) and diff --git a/tensorflow/python/ops/batch_ops.py b/tensorflow/python/ops/batch_ops.py index dbe17201146d59..0361ea242946b3 100644 --- a/tensorflow/python/ops/batch_ops.py +++ b/tensorflow/python/ops/batch_ops.py @@ -16,7 +16,7 @@ """Operations for automatic batching and unbatching.""" from tensorflow.python.eager import def_function from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_spec +from tensorflow.python.framework import tensor from tensorflow.python.ops import gen_batch_ops # pylint: disable=wildcard-import from tensorflow.python.ops.gen_batch_ops import * @@ -92,14 +92,14 @@ def computation(*computation_args): return fn(*computation_args) computation = computation.get_concrete_function(*[ - tensor_spec.TensorSpec( + tensor.TensorSpec( dtype=x.dtype, shape=x.shape, name="batch_" + str(i)) for i, x in enumerate(args) ]) with ops.name_scope("batch") as name: for a in args: - if not isinstance(a, ops.Tensor): + if not isinstance(a, tensor.Tensor): raise ValueError("All arguments to functions decorated with " "`batch_function` are supposed to be Tensors; " f"found {a!r}.") diff --git a/tensorflow/python/ops/bincount_ops.py b/tensorflow/python/ops/bincount_ops.py index ce63aac1b0c5ba..92290a608844e9 100644 --- a/tensorflow/python/ops/bincount_ops.py +++ b/tensorflow/python/ops/bincount_ops.py @@ -17,6 +17,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_conversion from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_math_ops @@ -221,7 +222,7 @@ def validate_dense_weights(values, weights, dtype=None): return array_ops.constant([], dtype=dtype) return array_ops.constant([], dtype=values.dtype) - if not isinstance(weights, ops.Tensor): + if not isinstance(weights, tensor.Tensor): raise ValueError( "Argument `weights` must be a tf.Tensor if `values` is a tf.Tensor. " f"Received weights={weights} of type: {type(weights).__name__}") diff --git a/tensorflow/python/ops/check_ops.py b/tensorflow/python/ops/check_ops.py index 102e66afb8a827..bc3d6266ad3428 100644 --- a/tensorflow/python/ops/check_ops.py +++ b/tensorflow/python/ops/check_ops.py @@ -24,6 +24,7 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops @@ -71,7 +72,7 @@ def _maybe_constant_value_string(t): - if not isinstance(t, ops.Tensor): + if not isinstance(t, tensor_lib.Tensor): return str(t) const_t = tensor_util.constant_value(t) if const_t is not None: @@ -417,7 +418,7 @@ def _pretty_print(data_item, summarize): Returns: An appropriate string representation of data_item """ - if isinstance(data_item, ops.Tensor): + if isinstance(data_item, tensor_lib.Tensor): arr = data_item.numpy() if np.isscalar(arr): # Tensor.numpy() returns a scalar for zero-dimensional tensors @@ -526,7 +527,7 @@ def assert_proper_iterable(values): `Tensor`, `SparseTensor`, `np.array`, `tf.compat.bytes_or_text_types`. """ unintentional_iterables = ( - (ops.Tensor, sparse_tensor.SparseTensor, np.ndarray) + (tensor_lib.Tensor, sparse_tensor.SparseTensor, np.ndarray) + compat.bytes_or_text_types ) if isinstance(values, unintentional_iterables): @@ -1979,7 +1980,7 @@ def is_numeric_tensor(tensor): Returns `False` if `tensor` is of a non-numeric type or if `tensor` is not a `tf.Tensor` object. """ - return isinstance(tensor, ops.Tensor) and tensor.dtype in NUMERIC_TYPES + return isinstance(tensor, tensor_lib.Tensor) and tensor.dtype in NUMERIC_TYPES @tf_export( diff --git a/tensorflow/python/ops/composite_tensor_ops.py b/tensorflow/python/ops/composite_tensor_ops.py index 5067aa7f6be823..51a44613f6ddb1 100644 --- a/tensorflow/python/ops/composite_tensor_ops.py +++ b/tensorflow/python/ops/composite_tensor_ops.py @@ -18,6 +18,7 @@ from tensorflow.python.framework import composite_tensor from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.ops import gen_composite_tensor_ops from tensorflow.python.saved_model import nested_structure_coder from tensorflow.python.util import nest @@ -74,7 +75,7 @@ def composite_tensor_from_variant(encoded, type_spec, name=None): TypeError: If `encoded` is not a Tensor with dtype=variant. InvalidArgumentError: If `encoded` is not compatible with `type_spec`. """ - if not isinstance(encoded, ops.Tensor): + if not isinstance(encoded, tensor.Tensor): raise TypeError(f"Expected `encoded` to be a Tensor, got {encoded!r}.") if encoded.dtype != dtypes.variant: raise TypeError("Expected `encoded` to have dtype=variant, got " diff --git a/tensorflow/python/ops/cond.py b/tensorflow/python/ops/cond.py index 02cbdbf182a30f..9fae845aaeb469 100644 --- a/tensorflow/python/ops/cond.py +++ b/tensorflow/python/ops/cond.py @@ -19,6 +19,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import cond_v2 @@ -207,7 +208,9 @@ def f2(): return tf.add(y, 23) res_f_flat = nest.flatten(res_f, expand_composites=True) for (x, y) in zip(res_t_flat, res_f_flat): - assert isinstance(x, ops.Tensor) and isinstance(y, ops.Tensor) + assert ( + isinstance(x, tensor_lib.Tensor) + and isinstance(y, tensor_lib.Tensor)) if x.dtype.base_dtype != y.dtype.base_dtype: raise ValueError( "Outputs of 'true_fn' and 'false_fn' must have the same type(s). " diff --git a/tensorflow/python/ops/cond_v2.py b/tensorflow/python/ops/cond_v2.py index f09a280a9438a3..66e44131875dca 100644 --- a/tensorflow/python/ops/cond_v2.py +++ b/tensorflow/python/ops/cond_v2.py @@ -31,6 +31,7 @@ from tensorflow.python.framework import func_graph as func_graph_module from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.framework import type_spec @@ -650,7 +651,7 @@ def _make_output_composite_tensors_match(op_type, branch_graphs): for branch_idx, branch_out in enumerate(branch_outs): if isinstance(branch_out, indexed_slices.IndexedSlices): continue - elif isinstance(branch_out, ops.Tensor): + elif isinstance(branch_out, tensor_lib.Tensor): with branch_graphs[branch_idx].as_default(): branch_outputs[branch_idx][output_idx] = math_ops._as_indexed_slices( branch_out) diff --git a/tensorflow/python/ops/control_flow_case.py b/tensorflow/python/ops/control_flow_case.py index a8d508f358db75..be7beca29fe10a 100644 --- a/tensorflow/python/ops/control_flow_case.py +++ b/tensorflow/python/ops/control_flow_case.py @@ -20,6 +20,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.ops import array_ops_stack from tensorflow.python.ops import cond from tensorflow.python.ops import control_flow_assert @@ -401,7 +402,7 @@ def _case_verify_and_canonicalize_args(pred_fn_pairs, exclusive, name, f"Received {pred_fn_pair}.") pred, fn = pred_fn_pair - if isinstance(pred, ops.Tensor): + if isinstance(pred, tensor.Tensor): if pred.dtype != dtypes.bool: raise TypeError("pred must be Tensor of type bool: %s" % pred.name) elif not allow_python_preds: diff --git a/tensorflow/python/ops/control_flow_grad.py b/tensorflow/python/ops/control_flow_grad.py index 7b8e13c8351673..6fe6d207f9805d 100644 --- a/tensorflow/python/ops/control_flow_grad.py +++ b/tensorflow/python/ops/control_flow_grad.py @@ -19,6 +19,7 @@ from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import math_ops @@ -158,7 +159,7 @@ def _ExitGrad(op, grad): if op_ctxt.grad_state: raise TypeError("Second-order gradient for while loops not supported.") - if isinstance(grad, ops.Tensor): + if isinstance(grad, tensor.Tensor): grad_ctxt.AddName(grad.name) else: if not isinstance( @@ -220,7 +221,7 @@ def _EnterGrad(op, grad): return grad if op.get_attr("is_constant"): # Add a gradient accumulator for each loop invariant. - if isinstance(grad, ops.Tensor): + if isinstance(grad, tensor.Tensor): result = grad_ctxt.AddBackpropAccumulator(op, grad) elif isinstance(grad, indexed_slices.IndexedSlices): result = grad_ctxt.AddBackpropIndexedSlicesAccumulator(op, grad) diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index 80fcc7e51910ca..7ddbfe5e8cc356 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -27,8 +27,8 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_util from tensorflow.python.framework import type_spec from tensorflow.python.ops import array_ops @@ -70,7 +70,7 @@ def _Identity(tensor, name=None): # TODO(b/246438937): Remove this when we expand ResourceVariables into # dt_resource tensors. tensor = variable_utils.convert_variables_to_tensors(tensor) - if isinstance(tensor, ops.Tensor): + if isinstance(tensor, tensor_lib.Tensor): if tensor.dtype._is_ref_dtype: # pylint: disable=protected-access return gen_array_ops.ref_identity(tensor, name=name) else: @@ -84,7 +84,7 @@ def _Identity(tensor, name=None): def _NextIteration(tensor, name=None): tensor = ops.internal_convert_to_tensor_or_composite(tensor, as_ref=True) - if isinstance(tensor, ops.Tensor): + if isinstance(tensor, tensor_lib.Tensor): if tensor.dtype._is_ref_dtype: # pylint: disable=protected-access return ref_next_iteration(tensor, name=name) else: @@ -127,7 +127,7 @@ def _Enter(tensor, than its corresponding shape in `shape_invariant`. """ tensor = ops.internal_convert_to_tensor_or_composite(tensor, as_ref=True) - if isinstance(tensor, ops.Tensor): + if isinstance(tensor, tensor_lib.Tensor): if tensor.dtype._is_ref_dtype and use_ref: # pylint: disable=protected-access result = gen_control_flow_ops.ref_enter( tensor, frame_name, is_constant, parallel_iterations, name=name) @@ -162,7 +162,7 @@ def exit(tensor, name=None): # pylint: disable=redefined-builtin The same tensor as `tensor`. """ tensor = ops.internal_convert_to_tensor_or_composite(tensor, as_ref=True) - if isinstance(tensor, ops.Tensor): + if isinstance(tensor, tensor_lib.Tensor): if tensor.dtype._is_ref_dtype: # pylint: disable=protected-access return gen_control_flow_ops.ref_exit(tensor, name) else: @@ -197,7 +197,7 @@ def switch(data, pred, dtype=None, name=None): data = ops.internal_convert_to_tensor_or_composite( data, dtype=dtype, name="data", as_ref=True) pred = ops.convert_to_tensor(pred, name="pred") - if isinstance(data, ops.Tensor): + if isinstance(data, tensor_lib.Tensor): return gen_control_flow_ops.switch(data, pred, name=name) else: if not isinstance(data, composite_tensor.CompositeTensor): @@ -249,7 +249,7 @@ def _SwitchRefOrTensor(data, pred, name="Switch"): # var and data may be pinned to different devices, so we want to ops # created within ops.colocate_with(data) to ignore the existing stack. with ops.colocate_with(data, ignore_existing=True): - if isinstance(data, ops.Tensor): + if isinstance(data, tensor_lib.Tensor): if data.dtype._is_ref_dtype: # pylint: disable=protected-access return ref_switch(data, pred, name=name) return switch(data, pred, name=name) @@ -287,7 +287,7 @@ def merge(inputs, name=None): ops.internal_convert_to_tensor_or_composite(inp, as_ref=True) for inp in inputs ] - if all(isinstance(v, ops.Tensor) for v in inputs): + if all(isinstance(v, tensor_lib.Tensor) for v in inputs): if all(v.dtype._is_ref_dtype for v in inputs): # pylint: disable=protected-access return gen_control_flow_ops.ref_merge(inputs, name) else: @@ -296,7 +296,7 @@ def merge(inputs, name=None): # If there is a mix of tensors and indexed slices, then convert the # tensors to indexed slices. if all( - isinstance(v, (indexed_slices.IndexedSlices, ops.Tensor)) + isinstance(v, (indexed_slices.IndexedSlices, tensor_lib.Tensor)) for v in inputs): inputs = math_ops._as_indexed_slices_list(inputs, optimize=False) @@ -384,8 +384,8 @@ def _shape_invariant_to_type_spec(var, shape=None): "'shape' must be one of TypeSpec, TensorShape or None. " f"Received: {type(shape)}") - if isinstance(var, ops.Tensor): - return tensor_spec.TensorSpec(shape, var.dtype) + if isinstance(var, tensor_lib.Tensor): + return tensor_lib.TensorSpec(shape, var.dtype) else: try: return var._shape_invariant_to_type_spec(shape) # pylint: disable=protected-access @@ -408,7 +408,7 @@ def _EnforceShapeInvariant(merge_var, next_var): ValueError: If any tensor in `merge_var` has a more specific shape than its corresponding tensor in `next_var`. """ - if isinstance(merge_var, ops.Tensor): + if isinstance(merge_var, tensor_lib.Tensor): m_shape = merge_var.get_shape() n_shape = next_var.get_shape() if not _ShapeLessThanOrEqual(n_shape, m_shape): @@ -427,7 +427,7 @@ def _EnforceShapeInvariant(merge_var, next_var): def _AddNextAndBackEdge(m, v, enforce_shape_invariant=True): """Add NextIteration and back edge from v to m.""" - if isinstance(m, ops.Tensor): + if isinstance(m, tensor_lib.Tensor): v = ops.convert_to_tensor(v) v = _NextIteration(v) if enforce_shape_invariant: @@ -1632,7 +1632,7 @@ def _InitializeValues(self, values): """Makes the values known to this context.""" self._values = set() for x in values: - if isinstance(x, ops.Tensor): + if isinstance(x, tensor_lib.Tensor): self._values.add(x.name) else: raise TypeError("'values' must be a list of Tensors. " @@ -1831,7 +1831,7 @@ def _FixControlInputsAndContext(self, enters): graph = ops.get_default_graph() # pylint: disable=protected-access for e in enters: - if isinstance(e, ops.Tensor): + if isinstance(e, tensor_lib.Tensor): xs = [e] else: raise TypeError("'enters' must be a list of Tensors. " @@ -1888,7 +1888,7 @@ def _AsTensorList(x, p): if isinstance(v, ops.Operation): v = with_dependencies([v], p) v = ops.convert_to_tensor_or_composite(v) - if isinstance(v, ops.Tensor): + if isinstance(v, tensor_lib.Tensor): l.append(array_ops.identity(v)) else: l.append( @@ -2150,7 +2150,7 @@ def tuple(tensors, name=None, control_inputs=None): # pylint: disable=redefined ] if control_inputs: for c in control_inputs: - if isinstance(c, ops.Tensor): + if isinstance(c, tensor_lib.Tensor): c = c.op elif not isinstance(c, ops.Operation): raise TypeError( diff --git a/tensorflow/python/ops/control_flow_switch_case.py b/tensorflow/python/ops/control_flow_switch_case.py index 8cb4fe685ef104..843a088017df36 100644 --- a/tensorflow/python/ops/control_flow_switch_case.py +++ b/tensorflow/python/ops/control_flow_switch_case.py @@ -16,6 +16,7 @@ from tensorflow.python.eager import context from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import cond_v2 from tensorflow.python.ops import control_flow_util as util @@ -45,7 +46,7 @@ def _indexed_case_verify_and_canonicalize_args(branch_fns, default, Returns: branch_fns: validated list of callables for each branch (default last). """ - if not isinstance(branch_index, ops.Tensor): + if not isinstance(branch_index, tensor.Tensor): raise TypeError("'branch_index' must be a Tensor, got {}".format( type(branch_index))) if not branch_index.dtype.is_integer: diff --git a/tensorflow/python/ops/functional_ops.py b/tensorflow/python/ops/functional_ops.py index f394ce2cc77ad3..a50b445aea6394 100644 --- a/tensorflow/python/ops/functional_ops.py +++ b/tensorflow/python/ops/functional_ops.py @@ -20,6 +20,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import function from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_functional_ops @@ -1044,7 +1045,7 @@ def WhileBody(i, n, start, delta, *args): if isinstance(for_result, ops.Operation): for_result = () # Unary functions return a single Tensor value. - elif isinstance(for_result, ops.Tensor): + elif isinstance(for_result, tensor.Tensor): for_result = (for_result,) return (i + 1, n, start, delta) + tuple(for_result) diff --git a/tensorflow/python/ops/gradient_checker.py b/tensorflow/python/ops/gradient_checker.py index 9f00934c8263a6..8ed4aee3d47c6b 100644 --- a/tensorflow/python/ops/gradient_checker.py +++ b/tensorflow/python/ops/gradient_checker.py @@ -24,6 +24,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradients from tensorflow.python.ops import math_ops @@ -108,7 +109,7 @@ def _compute_theoretical_jacobian(x, x_shape, x_data, dy, dy_shape, dx, r_end = r_begin + x_val_size jacobian[r_begin:r_end, col] += v.flat else: - assert isinstance(dx, ops.Tensor), "dx = " + str(dx) + assert isinstance(dx, tensor.Tensor), "dx = " + str(dx) backprop = sess.run( dx, feed_dict=_extra_feeds(extra_feed_dict, {x: x_data, dy: dy_data})) jacobian[:, col] = backprop.ravel().view(jacobian.dtype) diff --git a/tensorflow/python/ops/gradient_checker_v2.py b/tensorflow/python/ops/gradient_checker_v2.py index 3a201b103fafb3..390bdbd20e92cd 100644 --- a/tensorflow/python/ops/gradient_checker_v2.py +++ b/tensorflow/python/ops/gradient_checker_v2.py @@ -24,6 +24,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.ops import array_ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util.tf_export import tf_export @@ -74,7 +75,7 @@ def _to_numpy(a): """ if isinstance(a, ops.EagerTensor): return a.numpy() - if isinstance(a, ops.Tensor): + if isinstance(a, tensor.Tensor): sess = ops.get_default_session() return sess.run(a) if isinstance(a, indexed_slices.IndexedSlicesValue): diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py index 62d5bde1eae3e8..b643278e3f9eb2 100644 --- a/tensorflow/python/ops/gradients_test.py +++ b/tensorflow/python/ops/gradients_test.py @@ -28,7 +28,7 @@ from tensorflow.python.framework import function as framework_function from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_spec +from tensorflow.python.framework import tensor from tensorflow.python.framework import test_ops from tensorflow.python.framework import test_util from tensorflow.python.framework.constant_op import constant @@ -234,9 +234,9 @@ def _TestOpGrad(_, float_grad, string_grad): z = x * 2.0 w = z * 3.0 grads = gradients.gradients(z, [c]) - self.assertIsInstance(grads[0], ops.Tensor) + self.assertIsInstance(grads[0], tensor.Tensor) grads = gradients.gradients(w, [c]) - self.assertIsInstance(grads[0], ops.Tensor) + self.assertIsInstance(grads[0], tensor.Tensor) def testNoGradientForStringOutputsWithOpNamespace(self): with ops.Graph().as_default(): @@ -254,9 +254,9 @@ def _TestOpGrad(_, float_grad, string_grad): z = x * 2.0 w = z * 3.0 grads = gradients.gradients(z, [c]) - self.assertIsInstance(grads[0], ops.Tensor) + self.assertIsInstance(grads[0], tensor.Tensor) grads = gradients.gradients(w, [c]) - self.assertIsInstance(grads[0], ops.Tensor) + self.assertIsInstance(grads[0], tensor.Tensor) def testSingletonIndexedSlices(self): with ops.Graph().as_default(): @@ -1614,7 +1614,7 @@ def F(x): self.assertAllClose(grads_re, grads) f_graph = def_function.function( - F, input_signature=[tensor_spec.TensorSpec(None)]) + F, input_signature=[tensor.TensorSpec(None)]) grads_re = self._grad(custom_gradient.recompute_grad(f_graph))(x) grads = self._grad(f_graph)(x) self.assertAllClose(grads_re, grads) @@ -1633,8 +1633,8 @@ def F(x1, x2): f_graph = def_function.function( F, input_signature=[ - tensor_spec.TensorSpec(None, dtype=dtypes.int32), - tensor_spec.TensorSpec(None, dtype=dtypes.float32), + tensor.TensorSpec(None, dtype=dtypes.int32), + tensor.TensorSpec(None, dtype=dtypes.float32), ]) grads_re = self._grad(custom_gradient.recompute_grad(f_graph))(x1, x2) grads = self._grad(f_graph)(x1, x2) diff --git a/tensorflow/python/ops/gradients_util.py b/tensorflow/python/ops/gradients_util.py index 31c06e091eb7e8..3579e4f937479d 100644 --- a/tensorflow/python/ops/gradients_util.py +++ b/tensorflow/python/ops/gradients_util.py @@ -26,6 +26,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -166,56 +167,77 @@ def _DefaultGradYs(grad_ys, if y.dtype.is_complex: raise TypeError( f"Gradients of complex tensors ({y}) must set grad_ys (y.dtype = " - f"{dtypes.as_dtype(y.dtype).name})") + f"{dtypes.as_dtype(y.dtype).name})" + ) new_grad_ys.append( array_ops.ones( - array_ops.shape(y), dtype=y.dtype, name="grad_ys_%d" % i)) + array_ops.shape(y), dtype=y.dtype, name="grad_ys_%d" % i + ) + ) continue if y.dtype.is_floating or y.dtype.is_integer: if not grad_y.dtype.is_floating and not grad_y.dtype.is_integer: raise TypeError( f"Gradient type {dtypes.as_dtype(grad_y.dtype).name} generated " f"for real or integer-valued tensor {y} with type " - f"{dtypes.as_dtype(y.dtype).name} must be real or integer") + f"{dtypes.as_dtype(y.dtype).name} must be real or integer" + ) elif y.dtype.is_complex: if not grad_y.dtype.is_complex: raise TypeError( f"Gradient type {dtypes.as_dtype(grad_y.dtype).name} generated " f"for complex-valued tensor {y} with type " - f"{dtypes.as_dtype(y.dtype).name} must be real") + f"{dtypes.as_dtype(y.dtype).name} must be real" + ) elif y.dtype == dtypes.variant: if grad_y.dtype != dtypes.variant: raise TypeError( f"Gradient type {dtypes.as_dtype(grad_y.dtype).name} generated " f"for variant tensor {y} with type " - f"{dtypes.as_dtype(y.dtype).name} must be variant") + f"{dtypes.as_dtype(y.dtype).name} must be variant" + ) elif y.dtype == dtypes.resource: # We assume y is the handle of a ResourceVariable. The gradient of a # ResourceVariable should be a numeric value, not another resource. if grad_y.dtype == dtypes.resource: - raise TypeError(f"Input gradient {grad_y} for resource tensor {y} " - "should not be a resource") + raise TypeError( + f"Input gradient {grad_y} for resource tensor {y} " + "should not be a resource" + ) else: raise TypeError( f"Tensor {y} with type {dtypes.as_dtype(y.dtype).name} must be " - "numeric to obtain a default gradient") + "numeric to obtain a default gradient" + ) # Create a grad_y tensor in the name scope of the gradient. # Required for TensorArrays to identify which gradient call a # grad_y value is coming from. if isinstance(grad_y, indexed_slices.IndexedSlices): new_grad_ys.append( indexed_slices.IndexedSlices( - indices=(array_ops.identity( - grad_y.indices, name="grad_ys_%d_indices" % i) - if isinstance(grad_y.indices, ops.Tensor) else - grad_y.indices), - values=(array_ops.identity( - grad_y.values, name="grad_ys_%d_values" % i) if isinstance( - grad_y.values, ops.Tensor) else grad_y.values), - dense_shape=(array_ops.identity( - grad_y.dense_shape, name="grad_ys_%d_shape" % i) - if isinstance(grad_y.dense_shape, ops.Tensor) else - grad_y.dense_shape))) + indices=( + array_ops.identity( + grad_y.indices, name="grad_ys_%d_indices" % i + ) + if isinstance(grad_y.indices, tensor_lib.Tensor) + else grad_y.indices + ), + values=( + array_ops.identity( + grad_y.values, name="grad_ys_%d_values" % i + ) + if isinstance(grad_y.values, tensor_lib.Tensor) + else grad_y.values + ), + dense_shape=( + array_ops.identity( + grad_y.dense_shape, name="grad_ys_%d_shape" % i + ) + if isinstance(grad_y.dense_shape, tensor_lib.Tensor) + else grad_y.dense_shape + ), + ) + ) else: new_grad_ys.append(array_ops.identity(grad_y, name="grad_ys_%d" % i)) @@ -594,10 +616,11 @@ def _GradientsHelper(ys, func_call = None is_partitioned_call = _IsPartitionedCall(op) # pylint: disable=protected-access - is_func_call = ( - src_graph._is_function(op.type) or is_partitioned_call) + is_func_call = src_graph._is_function(op.type) or is_partitioned_call # pylint: enable=protected-access - has_out_grads = any(isinstance(g, ops.Tensor) or g for g in out_grads) + has_out_grads = any( + isinstance(g, tensor_lib.Tensor) or g for g in out_grads + ) if has_out_grads and (op not in stop_ops): try: grad_fn = ops.get_gradient_function(op) @@ -662,9 +685,12 @@ def _GradientsHelper(ys, # output, it means that the cost does not depend on output[i], # therefore dC/doutput[i] is 0. for i, out_grad in enumerate(out_grads): - if (not isinstance(out_grad, ops.Tensor) and not out_grad) and ( + if ( + not isinstance(out_grad, tensor_lib.Tensor) and not out_grad + ) and ( (not grad_fn and is_func_call) - or backprop_util.IsTrainable(op.outputs[i])): + or backprop_util.IsTrainable(op.outputs[i]) + ): # Only trainable outputs or outputs for a function call that # will use SymbolicGradient get a zero gradient. Gradient # functions should ignore the gradient for other outputs. @@ -710,7 +736,7 @@ def _GradientsHelper(ys, # line up with in_grads. for i, (t_in, in_grad) in enumerate(zip(_Inputs(op, xs_set), in_grads)): if in_grad is not None: - if (isinstance(in_grad, ops.Tensor) and + if (isinstance(in_grad, tensor_lib.Tensor) and t_in.dtype != dtypes.resource): try: in_grad.set_shape(t_in.get_shape()) @@ -738,7 +764,7 @@ def _HasAnyNotNoneGrads(grads, op): """Return true iff op has real gradient.""" out_grads = _GetGrads(grads, op) for out_grad in out_grads: - if isinstance(out_grad, (ops.Tensor, indexed_slices.IndexedSlices)): + if isinstance(out_grad, (tensor_lib.Tensor, indexed_slices.IndexedSlices)): return True if out_grad and isinstance(out_grad, collections_abc.Sequence): if any(g is not None for g in out_grad): @@ -842,7 +868,7 @@ def _GetGrads(grads, op): def _AccumulatorShape(inputs): shape = tensor_shape.unknown_shape() for i in inputs: - if isinstance(i, ops.Tensor): + if isinstance(i, tensor_lib.Tensor): shape = shape.merge_with(i.get_shape()) return shape @@ -981,12 +1007,13 @@ def _AggregatedGrads(grads, out_grads = _GetGrads(grads, op) for i, out_grad in enumerate(out_grads): if loop_state: - if isinstance(out_grad, (ops.Tensor, indexed_slices.IndexedSlices)): + if isinstance( + out_grad, (tensor_lib.Tensor, indexed_slices.IndexedSlices)): assert control_flow_util.IsLoopSwitch(op) continue # Grads have to be Tensors or IndexedSlices if (isinstance(out_grad, collections_abc.Sequence) and not all( - isinstance(g, (ops.Tensor, indexed_slices.IndexedSlices)) + isinstance(g, (tensor_lib.Tensor, indexed_slices.IndexedSlices)) for g in out_grad if g is not None)): raise TypeError(f"Invalid gradient {out_grad} [index = {i}]. Gradients " @@ -996,7 +1023,8 @@ def _AggregatedGrads(grads, if len(out_grad) < 2: used = "nop" out_grads[i] = out_grad[0] - elif all(isinstance(g, ops.Tensor) for g in out_grad if g is not None): + elif all( + isinstance(g, tensor_lib.Tensor) for g in out_grad if g is not None): tensor_shape = _AccumulatorShape(out_grad) if aggregation_method in [ AggregationMethod.EXPERIMENTAL_TREE, diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py index aeebb3cded09db..80c2d996de8954 100644 --- a/tensorflow/python/ops/image_ops_impl.py +++ b/tensorflow/python/ops/image_ops_impl.py @@ -24,6 +24,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops @@ -98,7 +99,7 @@ def _is_tensor(x): Returns: `True` if `x` is a `tf.Tensor` or `tf.Variable`, otherwise `False`. """ - return isinstance(x, (ops.Tensor, variables.Variable)) + return isinstance(x, (tensor_lib.Tensor, variables.Variable)) def _ImageDimensions(image, rank): diff --git a/tensorflow/python/ops/io_ops.py b/tensorflow/python/ops/io_ops.py index 45ce92131c915b..64e2d0e5b7feb8 100644 --- a/tensorflow/python/ops/io_ops.py +++ b/tensorflow/python/ops/io_ops.py @@ -23,6 +23,7 @@ from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.lib.io import python_io from tensorflow.python.ops import gen_data_flow_ops from tensorflow.python.ops import gen_io_ops @@ -275,7 +276,7 @@ def read(self, queue, name=None): key: A string scalar Tensor. value: A string scalar Tensor. """ - if isinstance(queue, ops.Tensor): + if isinstance(queue, tensor_lib.Tensor): queue_ref = queue else: queue_ref = queue.queue_ref @@ -307,7 +308,7 @@ def read_up_to(self, queue, num_records, # pylint: disable=invalid-name keys: A 1-D string Tensor. values: A 1-D string Tensor. """ - if isinstance(queue, ops.Tensor): + if isinstance(queue, tensor_lib.Tensor): queue_ref = queue else: queue_ref = queue.queue_ref diff --git a/tensorflow/python/ops/linalg_ops.py b/tensorflow/python/ops/linalg_ops.py index f007c2021254c1..237cbd5a5d4650 100644 --- a/tensorflow/python/ops/linalg_ops.py +++ b/tensorflow/python/ops/linalg_ops.py @@ -18,6 +18,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.ops import array_ops from tensorflow.python.ops import cond from tensorflow.python.ops import gen_array_ops @@ -62,7 +63,7 @@ def _RegularizedGramianCholesky(matrix, l2_regularizer, first_kind): gramian = math_ops.matmul( matrix, matrix, adjoint_a=first_kind, adjoint_b=not first_kind) - if isinstance(l2_regularizer, ops.Tensor) or l2_regularizer != 0: + if isinstance(l2_regularizer, tensor_lib.Tensor) or l2_regularizer != 0: matrix_shape = array_ops.shape(matrix) batch_shape = matrix_shape[:-2] if first_kind: diff --git a/tensorflow/python/ops/linalg_ops_impl.py b/tensorflow/python/ops/linalg_ops_impl.py index 4481bb4e350dcb..45393ccba5349b 100644 --- a/tensorflow/python/ops/linalg_ops_impl.py +++ b/tensorflow/python/ops/linalg_ops_impl.py @@ -18,6 +18,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.util import compat @@ -42,8 +43,8 @@ def eye(num_rows, num_columns = num_rows if num_columns is None else num_columns # We cannot statically infer what the diagonal size should be: - if (isinstance(num_rows, ops.Tensor) or - isinstance(num_columns, ops.Tensor)): + if (isinstance(num_rows, tensor.Tensor) or + isinstance(num_columns, tensor.Tensor)): diag_size = math_ops.minimum(num_rows, num_columns) else: # We can statically infer the diagonal size, and whether it is square. @@ -56,9 +57,12 @@ def eye(num_rows, diag_size = np.minimum(num_rows, num_columns) # We can not statically infer the shape of the tensor. - if isinstance(batch_shape, ops.Tensor) or isinstance(diag_size, ops.Tensor): + if isinstance(batch_shape, tensor.Tensor) or isinstance( + diag_size, tensor.Tensor + ): batch_shape = ops.convert_to_tensor( - batch_shape, name='shape', dtype=dtypes.int32) + batch_shape, name='shape', dtype=dtypes.int32 + ) diag_shape = array_ops.concat((batch_shape, [diag_size]), axis=0) if not is_square: shape = array_ops.concat((batch_shape, [num_rows, num_columns]), axis=0) diff --git a/tensorflow/python/ops/list_ops.py b/tensorflow/python/ops/list_ops.py index 5577b04e5502e6..bc2aeca89b26b4 100644 --- a/tensorflow/python/ops/list_ops.py +++ b/tensorflow/python/ops/list_ops.py @@ -21,6 +21,7 @@ from tensorflow.python.framework import cpp_shape_inference_pb2 from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops @@ -383,7 +384,7 @@ def _build_element_shape(shape): Returns: A None-free shape that can be converted to a tensor. """ - if isinstance(shape, ops.Tensor): + if isinstance(shape, tensor_lib.Tensor): return shape if isinstance(shape, tensor_shape.TensorShape): # `TensorShape.as_list` requires rank to be known. @@ -398,7 +399,7 @@ def _build_element_shape(shape): def convert(val): if val is None: return -1 - if isinstance(val, ops.Tensor): + if isinstance(val, tensor_lib.Tensor): return val if isinstance(val, tensor_shape.Dimension): return val.value if val.value is not None else -1 diff --git a/tensorflow/python/ops/lookup_ops.py b/tensorflow/python/ops/lookup_ops.py index fe06e18ee57232..b4eeedabe55407 100644 --- a/tensorflow/python/ops/lookup_ops.py +++ b/tensorflow/python/ops/lookup_ops.py @@ -24,6 +24,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops @@ -706,7 +707,7 @@ def __init__(self, ValueError: when the filename is empty, or when the table key and value data types do not match the expected data types. """ - if not isinstance(filename, ops.Tensor) and not filename: + if not isinstance(filename, tensor_lib.Tensor) and not filename: raise ValueError("`filename` argument required for tf.lookup.TextFileInitializer") self._filename_arg = filename @@ -1499,7 +1500,7 @@ def index_table_from_file(vocabulary_file=None, num_oov_buckets) if vocab_size is not None and vocab_size < 1: vocab_file_value = vocabulary_file - if isinstance(vocabulary_file, ops.Tensor): + if isinstance(vocabulary_file, tensor_lib.Tensor): vocab_file_value = tensor_util.constant_value(vocabulary_file) or "?" raise ValueError("`vocab_size` must be greater than 0, got %d for " "vocabulary_file: %s." % (vocab_size, vocab_file_value)) diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py index 7ada2203bfeede..22378e6deae1b2 100644 --- a/tensorflow/python/ops/math_grad.py +++ b/tensorflow/python/ops/math_grad.py @@ -20,6 +20,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_array_ops @@ -85,8 +86,8 @@ def SmartBroadcastGradientArgs(x, y, grad): # NOTE: It may be productive to apply these optimizations in the eager case # as well. if context.executing_eagerly() or not ( - isinstance(x, ops.Tensor) and isinstance(y, ops.Tensor) - and isinstance(grad, ops.Tensor)): + isinstance(x, tensor.Tensor) and isinstance(y, tensor.Tensor) + and isinstance(grad, tensor.Tensor)): sx = array_ops.shape(x) sy = array_ops.shape(y) rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy) @@ -1303,7 +1304,7 @@ def _AddGrad(op, grad): # No gradient skipping, so do the full gradient computation pass x = op.inputs[0] - if (isinstance(grad, ops.Tensor) and + if (isinstance(grad, tensor.Tensor) and _ShapesFullySpecifiedAndEqual(x, y, grad)): return grad, grad (sx, rx, must_reduce_x), (sy, ry, must_reduce_y) = ( @@ -1337,7 +1338,7 @@ def _SubGrad(op, grad): # No gradient skipping, so do the full gradient computation pass x = op.inputs[0] - if (isinstance(grad, ops.Tensor) and + if (isinstance(grad, tensor.Tensor) and _ShapesFullySpecifiedAndEqual(x, y, grad)): return grad, -grad (sx, rx, must_reduce_x), (sy, ry, must_reduce_y) = ( @@ -1371,7 +1372,7 @@ def _MulGrad(op, grad): # No gradient skipping, so do the full gradient computation pass x = op.inputs[0] - if (isinstance(grad, ops.Tensor) and + if (isinstance(grad, tensor.Tensor) and _ShapesFullySpecifiedAndEqual(x, y, grad) and grad.dtype in (dtypes.int32, dtypes.float32)): return gen_math_ops.mul(grad, y), gen_math_ops.mul(grad, x) @@ -1403,7 +1404,7 @@ def _MulNoNanGrad(op, grad): """The gradient of scalar multiplication with NaN-suppression.""" x = op.inputs[0] y = op.inputs[1] - if (isinstance(grad, ops.Tensor) and + if (isinstance(grad, tensor.Tensor) and _ShapesFullySpecifiedAndEqual(x, y, grad)): return gen_math_ops.mul_no_nan(grad, y), gen_math_ops.mul_no_nan(x, grad) assert x.dtype.base_dtype == y.dtype.base_dtype, (x.dtype, " vs. ", y.dtype) @@ -1625,7 +1626,7 @@ def _SquaredDifferenceGrad(op, grad): # Tensor (not a number like 2.0) which causes it to convert to Tensor. x_grad = math_ops.scalar_mul(2.0, grad) * (x - y) - if (isinstance(grad, ops.Tensor) and + if (isinstance(grad, tensor.Tensor) and _ShapesFullySpecifiedAndEqual(x, y, grad)): return x_grad, -x_grad diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index 39f8e77a520cd6..c90bde289564df 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -76,6 +76,7 @@ from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_conversion_registry from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util @@ -995,8 +996,8 @@ def cast(x, dtype, name=None): """ base_type = dtypes.as_dtype(dtype).base_dtype - if isinstance(x, - (ops.Tensor, _resource_variable_type)) and base_type == x.dtype: + if isinstance( + x, (tensor_lib.Tensor, _resource_variable_type)) and base_type == x.dtype: return x with ops.name_scope(name, "Cast", [x]) as name: if isinstance(x, sparse_tensor.SparseTensor): @@ -1386,8 +1387,8 @@ def to_complex128(x, name="ToComplex128"): return cast(x, dtypes.complex128, name=name) -ops.Tensor._override_operator("__neg__", gen_math_ops.neg) -ops.Tensor._override_operator("__abs__", abs) +tensor_lib.Tensor._override_operator("__neg__", gen_math_ops.neg) +tensor_lib.Tensor._override_operator("__abs__", abs) def _maybe_get_dtype(x): @@ -1396,7 +1397,7 @@ def _maybe_get_dtype(x): # value (not just dtype) of np.ndarray to decide the result type. if isinstance(x, numbers.Real): return x - if isinstance(x, ops.Tensor): + if isinstance(x, tensor_lib.Tensor): return x.dtype.as_numpy_dtype if isinstance(x, dtypes.DType): return x.as_numpy_dtype @@ -1442,7 +1443,7 @@ def maybe_promote_tensors(*tensors, force_same_dtype=False): result_type = np_dtypes._result_type( *[_maybe_get_dtype(x) for x in nest.flatten(tensors)]) def _promote_or_cast(x): - if isinstance(x, ops.Tensor): + if isinstance(x, tensor_lib.Tensor): x = cast(x, result_type) else: x = ops.convert_to_tensor(x, result_type) @@ -1450,7 +1451,8 @@ def _promote_or_cast(x): return [_promote_or_cast(x) for x in tensors] -def _OverrideBinaryOperatorHelper(func, op_name, clazz_object=ops.Tensor): +def _OverrideBinaryOperatorHelper( + func, op_name, clazz_object=tensor_lib.Tensor): """Register operators with different tensor and scalar versions. If `clazz_object` is `SparseTensor`, assumes `func` takes `(sp_indices, @@ -1516,7 +1518,7 @@ def r_binary_op_wrapper(y, x): r_binary_op_wrapper.__doc__ = doc binary_op_wrapper_sparse.__doc__ = doc - if clazz_object is ops.Tensor: + if clazz_object is tensor_lib.Tensor: clazz_object._override_operator("__%s__" % op_name, binary_op_wrapper) del binary_op_wrapper clazz_object._override_operator("__r%s__" % op_name, r_binary_op_wrapper) @@ -1835,7 +1837,7 @@ def _add_dispatch(x, y, name=None): Returns: The result of the elementwise `+` operation. """ - if not isinstance(y, ops.Tensor) and not isinstance( + if not isinstance(y, tensor_lib.Tensor) and not isinstance( y, sparse_tensor.SparseTensor): y = ops.convert_to_tensor(y, dtype_hint=x.dtype.base_dtype, name="y") if x.dtype == dtypes.string: @@ -1953,7 +1955,7 @@ def invert_(x, name=None): _OverrideBinaryOperatorHelper(and_, "and") _OverrideBinaryOperatorHelper(or_, "or") _OverrideBinaryOperatorHelper(xor_, "xor") -ops.Tensor._override_operator("__invert__", invert_) +tensor_lib.Tensor._override_operator("__invert__", invert_) def _promote_dtypes_decorator(fn): @@ -1963,13 +1965,13 @@ def wrapper(x, y, *args, **kwargs): return tf_decorator.make_decorator(fn, wrapper) -ops.Tensor._override_operator("__lt__", _promote_dtypes_decorator( +tensor_lib.Tensor._override_operator("__lt__", _promote_dtypes_decorator( gen_math_ops.less)) -ops.Tensor._override_operator("__le__", _promote_dtypes_decorator( +tensor_lib.Tensor._override_operator("__le__", _promote_dtypes_decorator( gen_math_ops.less_equal)) -ops.Tensor._override_operator("__gt__", _promote_dtypes_decorator( +tensor_lib.Tensor._override_operator("__gt__", _promote_dtypes_decorator( gen_math_ops.greater)) -ops.Tensor._override_operator("__ge__", _promote_dtypes_decorator( +tensor_lib.Tensor._override_operator("__ge__", _promote_dtypes_decorator( gen_math_ops.greater_equal)) @@ -2077,8 +2079,11 @@ def tensor_equals(self, other): if other is None: return False g = getattr(self, "graph", None) - if (ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions() and - (g is None or g.building_function)): + if ( + tensor_lib.Tensor._USE_EQUALITY + and ops.executing_eagerly_outside_functions() + and (g is None or g.building_function) + ): self, other = maybe_promote_tensors(self, other) return gen_math_ops.equal(self, other, incompatible_shape_error=False) else: @@ -2115,7 +2120,10 @@ def tensor_not_equals(self, other): """ if other is None: return True - if ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions(): + if ( + tensor_lib.Tensor._USE_EQUALITY + and ops.executing_eagerly_outside_functions() + ): self, other = maybe_promote_tensors(self, other) return gen_math_ops.not_equal(self, other, incompatible_shape_error=False) else: @@ -2123,8 +2131,8 @@ def tensor_not_equals(self, other): return self is not other -ops.Tensor._override_operator("__eq__", tensor_equals) -ops.Tensor._override_operator("__ne__", tensor_not_equals) +tensor_lib.Tensor._override_operator("__eq__", tensor_equals) +tensor_lib.Tensor._override_operator("__ne__", tensor_not_equals) @tf_export("range") @@ -2184,11 +2192,11 @@ def range(start, limit=None, delta=1, dtype=None, name="range"): # pylint: disa start, limit = 0, start with ops.name_scope(name, "Range", [start, limit, delta]) as name: - if not isinstance(start, ops.Tensor): + if not isinstance(start, tensor_lib.Tensor): start = ops.convert_to_tensor(start, dtype=dtype, name="start") - if not isinstance(limit, ops.Tensor): + if not isinstance(limit, tensor_lib.Tensor): limit = ops.convert_to_tensor(limit, dtype=dtype, name="limit") - if not isinstance(delta, ops.Tensor): + if not isinstance(delta, tensor_lib.Tensor): delta = ops.convert_to_tensor(delta, dtype=dtype, name="delta") # infer dtype if not explicitly provided @@ -3941,7 +3949,7 @@ def _as_indexed_slices(x, optimize=True): TypeError: If 'x' is not a Tensor or an IndexedSlices object. """ # TODO(touts): op_scope - if not isinstance(x, (ops.Tensor, indexed_slices.IndexedSlices)): + if not isinstance(x, (tensor_lib.Tensor, indexed_slices.IndexedSlices)): raise TypeError(f"Not a Tensor or IndexedSlices: {type(x)}.") if isinstance(x, indexed_slices.IndexedSlices): return x @@ -4109,7 +4117,7 @@ def add_n(inputs, name=None): "Tensor/IndexedSlices with the same dtype and shape.") inputs = indexed_slices.convert_n_to_tensor_or_indexed_slices(inputs) if not all( - isinstance(x, (ops.Tensor, indexed_slices.IndexedSlices)) + isinstance(x, (tensor_lib.Tensor, indexed_slices.IndexedSlices)) for x in inputs): raise ValueError("Inputs must be an iterable of at least one " "Tensor/IndexedSlices with the same dtype and shape.") @@ -4185,7 +4193,7 @@ def _input_error(): if not inputs or not isinstance(inputs, (list, tuple)): raise _input_error() inputs = indexed_slices.convert_n_to_tensor_or_indexed_slices(inputs) - if not all(isinstance(x, ops.Tensor) for x in inputs): + if not all(isinstance(x, tensor_lib.Tensor) for x in inputs): raise _input_error() if not all(x.dtype == inputs[0].dtype for x in inputs): raise _input_error() @@ -4194,7 +4202,7 @@ def _input_error(): else: shape = tensor_shape.unknown_shape() for input_tensor in inputs: - if isinstance(input_tensor, ops.Tensor): + if isinstance(input_tensor, tensor_lib.Tensor): shape = shape.merge_with(input_tensor.get_shape()) # tensor_dtype is for safety only; operator's output type computed in C++ @@ -4542,7 +4550,7 @@ def conj(x, name=None): Equivalent to numpy.conj. @end_compatibility """ - if isinstance(x, ops.Tensor): + if isinstance(x, tensor_lib.Tensor): dt = x.dtype if dt.is_floating or dt.is_integer: return x diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py index 350524ef4aa689..01eefe80f74ba2 100644 --- a/tensorflow/python/ops/math_ops_test.py +++ b/tensorflow/python/ops/math_ops_test.py @@ -27,6 +27,7 @@ from tensorflow.python.framework import errors_impl from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradients @@ -784,9 +785,9 @@ def testConsistent(self): def testWithPythonValue(self): # Test case for https://github.com/tensorflow/tensorflow/issues/39475 x = math_ops.divide(5, 2) - self.assertIsInstance(x, ops.Tensor) + self.assertIsInstance(x, tensor_lib.Tensor) x = math_ops.divide(5, array_ops.constant(2.0)) - self.assertIsInstance(x, ops.Tensor) + self.assertIsInstance(x, tensor_lib.Tensor) def intEdgeTestData(self, dtype): """Edge-case test data for integer types.""" @@ -1206,7 +1207,7 @@ def testEqualityNoDowncast(self, is_equals, float_literal): x = constant_op.constant(4) try: result = op(x, float_literal) - if isinstance(result, ops.Tensor): + if isinstance(result, tensor_lib.Tensor): result = self.evaluate(result) except TypeError: # Throwing a TypeError is OK diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index b258c9c4f61130..8770c7cab75e86 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -177,6 +177,7 @@ from tensorflow.python.framework import graph_util from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops @@ -1239,7 +1240,8 @@ def convolution_internal( not tensor_util.is_tf_type(filters)): with ops.name_scope("convolution_internal", None, [filters, input]): filters = ops.convert_to_tensor(filters, name='filters') - if (not isinstance(input, ops.Tensor) and not tensor_util.is_tf_type(input)): + if (not isinstance(input, tensor_lib.Tensor) and not tensor_util.is_tf_type( + input)): with ops.name_scope("convolution_internal", None, [filters, input]): input = ops.convert_to_tensor(input, name="input") @@ -2236,7 +2238,7 @@ def conv1d_transpose( input = array_ops.expand_dims(input, spatial_start_dim) filters = array_ops.expand_dims(filters, 0) output_shape = list(output_shape) if not isinstance( - output_shape, ops.Tensor) else output_shape + output_shape, tensor_lib.Tensor) else output_shape output_shape = array_ops.concat([output_shape[: spatial_start_dim], [1], output_shape[spatial_start_dim:]], 0) @@ -3820,7 +3822,7 @@ def _swap_axis(input_tensor, dim_index, last_index, name=None): return compute_op(inputs, name=name) dim_val = dim - if isinstance(dim, ops.Tensor): + if isinstance(dim, tensor_lib.Tensor): dim_val = tensor_util.constant_value(dim) if dim_val is not None and not -shape.ndims <= dim_val < shape.ndims: raise errors_impl.InvalidArgumentError( @@ -3834,7 +3836,7 @@ def _swap_axis(input_tensor, dim_index, last_index, name=None): # In case dim is negative (and is not last dimension -1), add shape.ndims ndims = array_ops.rank(inputs) - if not isinstance(dim, ops.Tensor): + if not isinstance(dim, tensor_lib.Tensor): if dim < 0: dim += ndims else: diff --git a/tensorflow/python/ops/op_selector.py b/tensorflow/python/ops/op_selector.py index 77258b8f117726..3ddbe8c8f433ec 100644 --- a/tensorflow/python/ops/op_selector.py +++ b/tensorflow/python/ops/op_selector.py @@ -15,6 +15,7 @@ """Tools for selecting ops in a graph.""" from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.util import object_identity @@ -27,7 +28,7 @@ def is_differentiable(op): def is_iterable(obj): """Return true if the object is iterable.""" - if isinstance(obj, ops.Tensor): + if isinstance(obj, tensor_lib.Tensor): return False try: _ = iter(obj) @@ -94,7 +95,7 @@ def get_unique_graph(tops, check_types=None, none_if_empty=False): if not is_iterable(tops): raise TypeError("{} is not iterable".format(type(tops))) if check_types is None: - check_types = (ops.Operation, ops.Tensor) + check_types = (ops.Operation, tensor_lib.Tensor) elif not is_iterable(check_types): check_types = (check_types,) g = None @@ -153,9 +154,9 @@ def make_list_of_t(ts, check_graph=True, allow_graph=True, ignore_ops=False): if not ts: return [] if check_graph: - check_types = None if ignore_ops else ops.Tensor + check_types = None if ignore_ops else tensor_lib.Tensor get_unique_graph(ts, check_types=check_types) - return [t for t in ts if isinstance(t, ops.Tensor)] + return [t for t in ts if isinstance(t, tensor_lib.Tensor)] def get_generating_ops(ts): @@ -272,7 +273,7 @@ def get_backward_walk_ops(seed_ops, # Empty iterable. return [] - if isinstance(first_seed_op, ops.Tensor): + if isinstance(first_seed_op, tensor_lib.Tensor): ts = make_list_of_t(seed_ops, allow_graph=False) seed_ops = get_generating_ops(ts) else: @@ -318,7 +319,7 @@ class UnliftableError(Exception): def _as_operation(op_or_tensor): - if isinstance(op_or_tensor, ops.Tensor): + if isinstance(op_or_tensor, tensor_lib.Tensor): return op_or_tensor.op return op_or_tensor @@ -338,7 +339,7 @@ def show_path(from_op, tensors, sources): Returns: A python string containing the path, or "??" if none is found. """ - if isinstance(from_op, ops.Tensor): + if isinstance(from_op, tensor_lib.Tensor): from_op = from_op.op if not isinstance(tensors, list): diff --git a/tensorflow/python/ops/parsing_config.py b/tensorflow/python/ops/parsing_config.py index 32ff8190c083f2..8be4c9c79fc988 100644 --- a/tensorflow/python/ops/parsing_config.py +++ b/tensorflow/python/ops/parsing_config.py @@ -20,6 +20,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops @@ -498,7 +499,7 @@ def _make_dense_default(self, key, shape, dtype): else: if default_value is None: default_value = constant_op.constant([], dtype=dtype) - elif not isinstance(default_value, ops.Tensor): + elif not isinstance(default_value, tensor.Tensor): key_name = "key_" + re.sub("[^A-Za-z0-9_.\\-/]", "_", key) default_value = ops.convert_to_tensor( default_value, dtype=dtype, name=key_name) diff --git a/tensorflow/python/ops/random_ops_util.py b/tensorflow/python/ops/random_ops_util.py index 1e81f4691b88df..4f9eefcc920e31 100644 --- a/tensorflow/python/ops/random_ops_util.py +++ b/tensorflow/python/ops/random_ops_util.py @@ -18,7 +18,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops_stack from tensorflow.python.ops import bitwise_ops @@ -69,7 +69,7 @@ def convert_alg_to_int(alg): return alg if isinstance(alg, Algorithm): return alg.value - if isinstance(alg, ops.Tensor): + if isinstance(alg, tensor.Tensor): return alg if isinstance(alg, str): # canonicalized alg diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py index 153ae97b89f1b7..6ad2dd866e371c 100644 --- a/tensorflow/python/ops/resource_variable_ops.py +++ b/tensorflow/python/ops/resource_variable_ops.py @@ -43,7 +43,6 @@ from tensorflow.python.framework import tensor as tensor_module from tensorflow.python.framework import tensor_conversion_registry from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_spec from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import gen_resource_variable_ops @@ -76,7 +75,7 @@ def get_eager_safe_handle_data(handle): """Get the data handle from the Tensor `handle`.""" - assert isinstance(handle, ops.Tensor) + assert isinstance(handle, tensor_module.Tensor) if isinstance(handle, ops.EagerTensor): return handle._handle_data # pylint: disable=protected-access @@ -268,7 +267,7 @@ class EagerResourceDeleter: __slots__ = ["_handle", "_handle_device", "_context"] def __init__(self, handle, handle_device): - if not isinstance(handle, ops.Tensor): + if not isinstance(handle, tensor_module.Tensor): raise ValueError( (f"Passed handle={handle} to EagerResourceDeleter. Was expecting " f"the handle to be a `tf.Tensor`.")) @@ -1933,7 +1932,7 @@ def _init_from_args( "`variable_def`. You provided neither.") init_from_fn = callable(initial_value) - if isinstance(initial_value, ops.Tensor) and hasattr( + if isinstance(initial_value, tensor_module.Tensor) and hasattr( initial_value, "graph") and initial_value.graph.building_function: raise ValueError(f"Argument `initial_value` ({initial_value}) could not " "be lifted out of a `tf.function`. " @@ -2540,7 +2539,7 @@ def __eq__(self, other): return isinstance(other, PList) and self.components == other.components -class VariableSpec(tensor_spec.DenseSpec): +class VariableSpec(tensor_module.DenseSpec): """Describes a tf.Variable. A `VariableSpec` provides metadata describing the `tf.Variable` objects @@ -2626,7 +2625,8 @@ def _from_components(self, components): raise ValueError(f"Components of a ResourceVariable must only contain " f"its resource handle, got f{components} instead.") handle = components[0] - if not isinstance(handle, ops.Tensor) or handle.dtype != dtypes.resource: + if not isinstance( + handle, tensor_module.Tensor) or handle.dtype != dtypes.resource: raise ValueError(f"The handle of a ResourceVariable must be a resource " f"tensor, got {handle} instead.") return ResourceVariable(trainable=self.trainable, @@ -2637,7 +2637,7 @@ def _from_components(self, components): @property def _component_specs(self): return [ - tensor_spec.TensorSpec( + tensor_module.TensorSpec( [], dtypes.DType( dtypes.resource._type_enum, # pylint: disable=protected-access @@ -2696,7 +2696,7 @@ def placeholder_value(self, placeholder_context): # exists in the PlaceholderContext variable = placeholder_context.get_placeholder(self.alias_id) else: - spec = tensor_spec.TensorSpec([], dtypes.resource) + spec = tensor_module.TensorSpec([], dtypes.resource) spec_context = trace_type.InternalPlaceholderContext( context_graph.outer_graph) spec_context.update_naming_scope(name) diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py index a0c3bee073d497..e3c26ff539387f 100644 --- a/tensorflow/python/ops/rnn.py +++ b/tensorflow/python/ops/rnn.py @@ -17,6 +17,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops @@ -125,7 +126,7 @@ def _infer_state_dtype(explicit_dtype, state): def _maybe_tensor_shape_from_tensor(shape): - if isinstance(shape, ops.Tensor): + if isinstance(shape, tensor.Tensor): return tensor_shape.as_shape(tensor_util.constant_value(shape)) else: return shape diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py index e190f4a35cd81c..9cff803335c6a3 100644 --- a/tensorflow/python/ops/rnn_cell_impl.py +++ b/tensorflow/python/ops/rnn_cell_impl.py @@ -23,7 +23,7 @@ from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.keras.layers.legacy_rnn import rnn_cell_impl @@ -84,7 +84,7 @@ def _concat(prefix, suffix, static=False): ValueError: if prefix or suffix was `None` and asked for dynamic Tensors out. """ - if isinstance(prefix, ops.Tensor): + if isinstance(prefix, tensor.Tensor): p = prefix p_static = tensor_util.constant_value(prefix) if p.shape.ndims == 0: @@ -102,7 +102,7 @@ def _concat(prefix, suffix, static=False): if p.is_fully_defined() else None ) - if isinstance(suffix, ops.Tensor): + if isinstance(suffix, tensor.Tensor): s = suffix s_static = tensor_util.constant_value(suffix) if s.shape.ndims == 0: diff --git a/tensorflow/python/ops/session_ops.py b/tensorflow/python/ops/session_ops.py index e6aff616a32ae9..9ef5794fb1a67c 100644 --- a/tensorflow/python/ops/session_ops.py +++ b/tensorflow/python/ops/session_ops.py @@ -23,6 +23,7 @@ from tensorflow.python.framework import device as pydev from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_data_flow_ops from tensorflow.python.util import compat @@ -166,7 +167,7 @@ def get_session_handle(data, name=None): ``` """ - if not isinstance(data, ops.Tensor): + if not isinstance(data, tensor_lib.Tensor): raise TypeError("`data` must be of type Tensor.") # Colocate this operation with data. diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py index 42d26e27d79ed9..b52e299149a874 100644 --- a/tensorflow/python/ops/sparse_ops.py +++ b/tensorflow/python/ops/sparse_ops.py @@ -27,6 +27,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_conversion from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util @@ -98,7 +99,7 @@ def _convert_to_sparse_tensors(sp_inputs): def _make_int64_tensor(value, name): if isinstance(value, compat.integral_types): return ops.convert_to_tensor(value, name=name, dtype=dtypes.int64) - if not isinstance(value, ops.Tensor): + if not isinstance(value, tensor_lib.Tensor): raise TypeError("{} must be an integer value".format(name)) if value.dtype == dtypes.int64: return value @@ -215,7 +216,7 @@ def sparse_expand_dims(sp_input, axis=None, name=None): with ops.name_scope(name, default_name="expand_dims", values=[sp_input]): if isinstance(axis, compat.integral_types): axis = ops.convert_to_tensor(axis, name="axis", dtype=dtypes.int32) - elif not isinstance(axis, ops.Tensor): + elif not isinstance(axis, tensor_lib.Tensor): raise TypeError("axis must be an integer value in range [-rank(sp_input)" " - 1, rank(sp_input)]") @@ -717,7 +718,8 @@ def _sparse_cross_internal_v2(inputs): if not isinstance(inputs, (tuple, list)): raise TypeError("Inputs must be a list") if not all( - isinstance(i, sparse_tensor.SparseTensor) or isinstance(i, ops.Tensor) + isinstance( + i, sparse_tensor.SparseTensor) or isinstance(i, tensor_lib.Tensor) for i in inputs): raise TypeError("All inputs must be Tensor or SparseTensor.") sparse_inputs = [ @@ -747,7 +749,8 @@ def _sparse_cross_internal(inputs, if not isinstance(inputs, (tuple, list)): raise TypeError("Inputs must be a list") if not all( - isinstance(i, sparse_tensor.SparseTensor) or isinstance(i, ops.Tensor) + isinstance( + i, sparse_tensor.SparseTensor) or isinstance(i, tensor_lib.Tensor) for i in inputs): raise TypeError("All inputs must be SparseTensors") @@ -1901,7 +1904,7 @@ def sparse_merge_impl(sp_ids, if isinstance(sp_ids, sparse_tensor.SparseTensorValue) or isinstance( sp_ids, sparse_tensor.SparseTensor): sp_ids = [sp_ids] - if not (isinstance(vocab_size, ops.Tensor) or + if not (isinstance(vocab_size, tensor_lib.Tensor) or isinstance(vocab_size, numbers.Integral)): raise TypeError("vocab_size has to be a Tensor or Python int. Found %s" % type(vocab_size)) @@ -1914,7 +1917,8 @@ def sparse_merge_impl(sp_ids, raise TypeError("vocab_size has to be a list of Tensors or Python ints. " "Found %s" % type(vocab_size)) for dim in vocab_size: - if not (isinstance(dim, ops.Tensor) or isinstance(dim, numbers.Integral)): + if not (isinstance( + dim, tensor_lib.Tensor) or isinstance(dim, numbers.Integral)): raise TypeError( "vocab_size has to be a list of Tensors or Python ints. Found %s" % type(dim)) diff --git a/tensorflow/python/ops/special_math_ops.py b/tensorflow/python/ops/special_math_ops.py index 77a7fc0631f3d7..6ee429a6783e27 100644 --- a/tensorflow/python/ops/special_math_ops.py +++ b/tensorflow/python/ops/special_math_ops.py @@ -28,6 +28,7 @@ from tensorflow.compiler.tf2xla.ops import gen_xla_ops from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -1060,7 +1061,7 @@ def _reshape_if_necessary(tensor, new_shape): new_shape = tuple(-1 if x is None else x for x in new_shape) cur_shape = tuple(x.value for x in tensor.shape.dims) if (len(new_shape) == len(cur_shape) and - all(not isinstance(d1, ops.Tensor) and (d0 == d1 or d1 == -1) + all(not isinstance(d1, tensor_lib.Tensor) and (d0 == d1 or d1 == -1) for d0, d1 in zip(cur_shape, new_shape))): return tensor else: diff --git a/tensorflow/python/ops/stateful_random_ops.py b/tensorflow/python/ops/stateful_random_ops.py index ae396e66dceb61..a33ea0c9f21a2e 100644 --- a/tensorflow/python/ops/stateful_random_ops.py +++ b/tensorflow/python/ops/stateful_random_ops.py @@ -21,6 +21,7 @@ from tensorflow.python.framework import config from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops_stack from tensorflow.python.ops import gen_stateful_random_ops @@ -161,7 +162,7 @@ def _get_state_size(alg): def _check_state_shape(shape, alg): - if isinstance(alg, ops.Tensor) and not context.executing_eagerly(): + if isinstance(alg, tensor.Tensor) and not context.executing_eagerly(): return shape.assert_is_compatible_with([_get_state_size(int(alg))]) diff --git a/tensorflow/python/ops/summary_ops_v2.py b/tensorflow/python/ops/summary_ops_v2.py index 95b6f67879890c..9ef7cced15cf3a 100644 --- a/tensorflow/python/ops/summary_ops_v2.py +++ b/tensorflow/python/ops/summary_ops_v2.py @@ -31,6 +31,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import smart_cond +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -992,10 +993,14 @@ def graph_v1(param, step=None, name=None): Raises: TypeError: If `param` isn't already a `tf.Tensor` in graph mode. """ - if not context.executing_eagerly() and not isinstance(param, ops.Tensor): - raise TypeError("graph() needs a argument `param` to be tf.Tensor " - "(e.g. tf.placeholder) in graph mode, but received " - f"param={param} of type {type(param).__name__}.") + if not context.executing_eagerly() and not isinstance( + param, tensor_lib.Tensor + ): + raise TypeError( + "graph() needs a argument `param` to be tf.Tensor " + "(e.g. tf.placeholder) in graph mode, but received " + f"param={param} of type {type(param).__name__}." + ) writer = _summary_state.writer if writer is None: return control_flow_ops.no_op() @@ -1170,7 +1175,7 @@ def _serialize_graph(arbitrary_graph): def _choose_step(step): if step is None: return training_util.get_or_create_global_step() - if not isinstance(step, ops.Tensor): + if not isinstance(step, tensor_lib.Tensor): return ops.convert_to_tensor(step, dtypes.int64) return step diff --git a/tensorflow/python/ops/tensor_array_ops.py b/tensorflow/python/ops/tensor_array_ops.py index 3dba50011dd5a7..0459bff690c853 100644 --- a/tensorflow/python/ops/tensor_array_ops.py +++ b/tensorflow/python/ops/tensor_array_ops.py @@ -28,8 +28,8 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_util from tensorflow.python.framework import type_spec from tensorflow.python.framework import type_spec_registry @@ -103,7 +103,7 @@ def __init__(self, raise ValueError( "Cannot provide both `handle` and `tensor_array_name` arguments at " "the same time.") - if handle is not None and not isinstance(handle, ops.Tensor): + if handle is not None and not isinstance(handle, tensor_lib.Tensor): raise TypeError( f"Expected `handle` to be a Tensor, but got `{handle}` of type " f"`{type(handle)}` instead.") @@ -452,21 +452,26 @@ def __init__(self, self._dynamic_size = dynamic_size self._size = size - if (flow is not None and - (not isinstance(flow, ops.Tensor) or flow.dtype != dtypes.variant)): + if flow is not None and ( + not isinstance(flow, tensor_lib.Tensor) or flow.dtype != dtypes.variant + ): raise TypeError( - f"Expected `flow` to be a variant tensor, but received `{flow.dtype}` " - f"instead.") + f"Expected `flow` to be a variant tensor, but received `{flow.dtype}`" + " instead." + ) if flow is None and size is None: - raise ValueError("Argument `size` must be provided if argument `flow` " - "is not provided.") + raise ValueError( + "Argument `size` must be provided if argument `flow` is not provided." + ) if flow is not None and size is not None: - raise ValueError("Cannot provide both `flow` and `size` arguments " - "at the same time.") + raise ValueError( + "Cannot provide both `flow` and `size` arguments at the same time." + ) if flow is not None and element_shape is not None: raise ValueError( "Cannot provide both `flow` and `element_shape` arguments" - "at the same time.") + "at the same time." + ) self._dtype = dtypes.as_dtype(dtype).base_dtype @@ -1434,7 +1439,7 @@ def _serialize(self): @property def _component_specs(self): - return [tensor_spec.TensorSpec([], dtypes.variant)] + return [tensor_lib.TensorSpec([], dtypes.variant)] def _to_components(self, value): if not isinstance(value, TensorArray): @@ -1510,7 +1515,7 @@ def placeholder_value(self, placeholder_context): return self._value def _flatten(self): - return [tensor_spec.TensorSpec([], dtypes.variant)] + return [tensor_lib.TensorSpec([], dtypes.variant)] def _from_tensors(self, tensors): return next(tensors) diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py index a86b8c2999f49a..33dd0438fa2f2f 100644 --- a/tensorflow/python/ops/variable_scope.py +++ b/tensorflow/python/ops/variable_scope.py @@ -27,6 +27,7 @@ from tensorflow.python.eager import monitoring from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_conversion_registry from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops @@ -698,7 +699,7 @@ def _get_partitioned_variable(self, sharded variable exists for the given name but with different sharding. """ initializing_from_value = initializer is not None and isinstance( - initializer, ops.Tensor) + initializer, tensor.Tensor) if name in self._vars: raise ValueError( "A partitioner was provided, but an unpartitioned version of the " @@ -780,7 +781,7 @@ def _get_partitioned_variable(self, elif callable(initializer): init = initializer init_shape = var_shape - elif isinstance(initializer, ops.Tensor): + elif isinstance(initializer, tensor.Tensor): init = array_ops.slice(initializer, var_offset, var_shape) # Use the dtype of the given tensor. dtype = init.dtype.base_dtype diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py index 3e6dc5e3a5145e..76d8bb9ad6ecea 100644 --- a/tensorflow/python/ops/variables.py +++ b/tensorflow/python/ops/variables.py @@ -25,6 +25,7 @@ from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_conversion_registry from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops @@ -981,7 +982,7 @@ def _TensorConversionFunction(v, dtype=None, name=None, as_ref=False): # pylint @classmethod def _OverloadAllOperators(cls): # pylint: disable=invalid-name """Register overloads for all operators.""" - for operator in ops.Tensor.OVERLOADABLE_OPERATORS: + for operator in tensor_lib.Tensor.OVERLOADABLE_OPERATORS: cls._OverloadOperator(operator) # For slicing, bind getitem differently than a tensor (use SliceHelperVar # instead) @@ -990,9 +991,10 @@ def _OverloadAllOperators(cls): # pylint: disable=invalid-name @classmethod def _OverloadOperator(cls, operator): # pylint: disable=invalid-name - """Defer an operator overload to `ops.Tensor`. + """Defer an operator overload to `tensor_lib.Tensor`. - We pull the operator out of ops.Tensor dynamically to avoid ordering issues. + We pull the operator out of tensor_lib.Tensor dynamically to avoid ordering + issues. Args: operator: string. The operator name. @@ -1004,7 +1006,7 @@ def _OverloadOperator(cls, operator): # pylint: disable=invalid-name if operator == "__eq__" or operator == "__ne__": return - tensor_oper = getattr(ops.Tensor, operator) + tensor_oper = getattr(tensor_lib.Tensor, operator) def _run_op(a, *args, **kwargs): # pylint: disable=protected-access @@ -1014,17 +1016,24 @@ def _run_op(a, *args, **kwargs): setattr(cls, operator, _run_op) def __hash__(self): - if ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions(): # pylint: disable=protected-access + if ( + tensor_lib.Tensor._USE_EQUALITY + and ops.executing_eagerly_outside_functions() + ): # pylint: disable=protected-access raise TypeError( "Variable is unhashable. " - f"Instead, use variable.ref() as the key. (Variable: {self})") + f"Instead, use variable.ref() as the key. (Variable: {self})" + ) else: return id(self) # TODO(gjn): duplicate of math_ops.tensor_equals, consider removing def __eq__(self, other): """Compares two variables element-wise for equality.""" - if ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions(): # pylint: disable=protected-access + if ( + tensor_lib.Tensor._USE_EQUALITY + and ops.executing_eagerly_outside_functions() + ): # pylint: disable=protected-access return gen_math_ops.equal(self, other, incompatible_shape_error=False) else: # In legacy graph mode, tensor equality is object equality @@ -1033,7 +1042,10 @@ def __eq__(self, other): # TODO(gjn): duplicate of math_ops.tensor_not_equals, consider removing def __ne__(self, other): """Compares two variables element-wise for equality.""" - if ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions(): # pylint: disable=protected-access + if ( + tensor_lib.Tensor._USE_EQUALITY + and ops.executing_eagerly_outside_functions() + ): # pylint: disable=protected-access return gen_math_ops.not_equal(self, other, incompatible_shape_error=False) else: # In legacy graph mode, tensor equality is object equality @@ -1342,7 +1354,7 @@ def _try_guard_against_uninitialized_dependencies(name, initial_value): Raises: TypeError: If `initial_value` is not a `Tensor`. """ - if not isinstance(initial_value, ops.Tensor): + if not isinstance(initial_value, tensor_lib.Tensor): raise TypeError("initial_value needs to be a Tensor: %s" % initial_value) # Don't modify initial_value if it contains any cyclic dependencies. diff --git a/tensorflow/python/ops/weak_tensor_ops.py b/tensorflow/python/ops/weak_tensor_ops.py index 43ac9db14565c3..627768a45a0fff 100644 --- a/tensorflow/python/ops/weak_tensor_ops.py +++ b/tensorflow/python/ops/weak_tensor_ops.py @@ -16,7 +16,7 @@ import inspect -from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework.weak_tensor import WeakTensor from tensorflow.python.ops import weak_tensor_ops_list from tensorflow.python.util import dispatch @@ -24,8 +24,8 @@ # This file must depend on math_ops so that e.g. `__add__` is # added to the Tensor class. -for operator in ops.Tensor.OVERLOADABLE_OPERATORS: - tensor_oper = getattr(ops.Tensor, operator) +for operator in tensor.Tensor.OVERLOADABLE_OPERATORS: + tensor_oper = getattr(tensor.Tensor, operator) setattr(WeakTensor, operator, tensor_oper) # List of unary ops that have support for WeakTensor. diff --git a/tensorflow/python/ops/weak_tensor_ops_test.py b/tensorflow/python/ops/weak_tensor_ops_test.py index db15af9a0b63e4..c7816bc8092f25 100644 --- a/tensorflow/python/ops/weak_tensor_ops_test.py +++ b/tensorflow/python/ops/weak_tensor_ops_test.py @@ -18,6 +18,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import test_util from tensorflow.python.framework.weak_tensor import WeakTensor from tensorflow.python.ops import array_ops @@ -136,7 +137,7 @@ def test_multi_arg_unary_ops_return_weak_tensor(self): def test_unary_ops_return_normal_tensor(self, unary_api_specific_dtype): a = WeakTensor(constant_op.constant([1, 2, 3], dtypes.float32)) res = unary_api_specific_dtype(a) - self.assertIsInstance(res, ops.Tensor) + self.assertIsInstance(res, tensor.Tensor) # Test unary ops with optional dtype arg. def test_elementwise_unary_ops_optional_dtype(self): @@ -148,25 +149,26 @@ def test_elementwise_unary_ops_optional_dtype(self): # dtype specified in the argument. self.assertIsInstance( - array_ops.zeros_like(a, dtype=dtypes.int32), ops.Tensor + array_ops.zeros_like(a, dtype=dtypes.int32), tensor.Tensor ) self.assertIsInstance( - array_ops.ones_like(a, dtype=dtypes.int32), ops.Tensor + array_ops.ones_like(a, dtype=dtypes.int32), tensor.Tensor ) - self.assertIsInstance(array_ops.zeros_like(a, dtypes.int32), ops.Tensor) - self.assertIsInstance(array_ops.ones_like(a, dtypes.int32), ops.Tensor) + self.assertIsInstance(array_ops.zeros_like(a, dtypes.int32), tensor.Tensor) + self.assertIsInstance(array_ops.ones_like(a, dtypes.int32), tensor.Tensor) self.assertIsInstance( np_array_ops.arange( WeakTensor(constant_op.constant(5)), 0, 1, dtypes.float32 ), - ops.Tensor, + tensor.Tensor, ) # Test unary ops that require dtype arg. def test_unary_ops_explicit_dtype_return(self): a = WeakTensor(constant_op.constant([1, 2, 3], dtypes.float32)) - self.assertIsInstance(math_ops.cast(a, dtypes.int32), ops.Tensor) - self.assertIsInstance(math_ops.saturate_cast(a, dtypes.int32), ops.Tensor) + self.assertIsInstance(math_ops.cast(a, dtypes.int32), tensor.Tensor) + self.assertIsInstance( + math_ops.saturate_cast(a, dtypes.int32), tensor.Tensor) def _get_test_input(op): diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py index b0b505df8c0cf4..e702a8f7e8bf2b 100644 --- a/tensorflow/python/ops/while_v2.py +++ b/tensorflow/python/ops/while_v2.py @@ -30,6 +30,7 @@ from tensorflow.python.framework import func_graph as func_graph_module from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_util @@ -805,7 +806,7 @@ def _get_structured_grad_output(outputs, grads, body_grad_graph): dense_shape=outputs[outputs_idx + 2])) outputs_idx += 3 else: - assert isinstance(output, ops.Tensor) + assert isinstance(output, tensor_lib.Tensor) result.append(outputs[outputs_idx]) outputs_idx += 1 diff --git a/tensorflow/python/ops/while_v2_indexed_slices_rewriter.py b/tensorflow/python/ops/while_v2_indexed_slices_rewriter.py index 64f7ed8d40db61..56e352a63c4207 100644 --- a/tensorflow/python/ops/while_v2_indexed_slices_rewriter.py +++ b/tensorflow/python/ops/while_v2_indexed_slices_rewriter.py @@ -18,7 +18,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import func_graph from tensorflow.python.framework import indexed_slices -from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_conversion from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops @@ -176,7 +176,7 @@ def _create_grad_indexed_slices_init(grad_output_slices, forward_input): Zeros IndexedSlices, created in current Graph. """ assert isinstance(grad_output_slices, indexed_slices.IndexedSlices) - assert isinstance(forward_input, ops.Tensor) + assert isinstance(forward_input, tensor.Tensor) values_out = grad_output_slices.values indices_out = grad_output_slices.indices From 2a120f9201cf82c146e1a1936e508e05721aa487 Mon Sep 17 00:00:00 2001 From: Juan Martinez Castellanos Date: Tue, 11 Jul 2023 10:35:52 -0700 Subject: [PATCH 136/376] Make all targets under tensorflow/compiler/xla/python_api/ have strict dependencies. PiperOrigin-RevId: 547231775 --- tensorflow/compiler/xla/python_api/BUILD | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/tensorflow/compiler/xla/python_api/BUILD b/tensorflow/compiler/xla/python_api/BUILD index 1185709fd30075..6dddb9d69858b8 100644 --- a/tensorflow/compiler/xla/python_api/BUILD +++ b/tensorflow/compiler/xla/python_api/BUILD @@ -1,3 +1,5 @@ +load("//tensorflow:strict.default.bzl", "py_strict_library", "py_strict_test") + # Description: # Python API for XLA. load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites") @@ -9,7 +11,7 @@ package( generate_backend_suites() -py_library( +py_strict_library( name = "types", srcs = ["types_.py"], srcs_version = "PY3", @@ -22,7 +24,7 @@ py_library( ], ) -py_library( +py_strict_library( name = "xla_shape", srcs = ["xla_shape.py"], srcs_version = "PY3", @@ -30,10 +32,11 @@ py_library( deps = [ ":types", "//tensorflow/compiler/xla:xla_data_proto_py", + "//third_party/py/numpy", ], ) -py_library( +py_strict_library( name = "xla_literal", srcs = ["xla_literal.py"], srcs_version = "PY3", @@ -42,10 +45,11 @@ py_library( ":types", ":xla_shape", "//tensorflow/compiler/xla:xla_data_proto_py", + "//third_party/py/numpy", ], ) -py_test( +py_strict_test( name = "xla_shape_test", srcs = ["xla_shape_test.py"], python_version = "PY3", @@ -55,12 +59,13 @@ py_test( ], deps = [ ":xla_shape", + "//tensorflow/compiler/xla:xla_data_proto_py", "//third_party/py/numpy", "@absl_py//absl/testing:absltest", ], ) -py_test( +py_strict_test( name = "xla_literal_test", srcs = ["xla_literal_test.py"], python_version = "PY3", @@ -70,6 +75,7 @@ py_test( ], deps = [ ":xla_literal", + "//tensorflow/compiler/xla:xla_data_proto_py", "//third_party/py/numpy", "@absl_py//absl/testing:absltest", ], From 976881d29f9b653f87a2c7047742a1c3208288dd Mon Sep 17 00:00:00 2001 From: Yu Feng Date: Tue, 11 Jul 2023 10:37:49 -0700 Subject: [PATCH 137/376] Add a binding for the copy constructor of Mesh. Such that we can cast from base class to the Python class. PiperOrigin-RevId: 547232385 --- tensorflow/dtensor/python/layout.py | 61 +++++++++++----------- tensorflow/python/pywrap_dtensor_device.cc | 2 + 2 files changed, 33 insertions(+), 30 deletions(-) diff --git a/tensorflow/dtensor/python/layout.py b/tensorflow/dtensor/python/layout.py index c3b21d744c2e20..052e8de2de626e 100644 --- a/tensorflow/dtensor/python/layout.py +++ b/tensorflow/dtensor/python/layout.py @@ -221,6 +221,14 @@ def to_spec(d) -> tf_device.DeviceSpec: use_xla_spmd, ) + @classmethod + def _new_object(cls, *args, **kwargs): + # Need to explicitly invoke the base class __init__ because + # Mesh.__init__ overrode it with a different signature. + self = _pywrap_dtensor_device.Mesh.__new__(cls) + super().__init__(self, *args, **kwargs) + return self + def global_device_ids(self) -> np.ndarray: """Returns a global device list as an array.""" return np.array(super().global_device_ids(), dtype=np.int64).reshape( @@ -255,26 +263,25 @@ def coords(self, device_idx: int) -> tensor.Tensor: @classmethod def from_proto(cls, proto: layout_pb2.MeshProto) -> 'Mesh': """Construct a mesh instance from input `proto`.""" - mesh = _pywrap_dtensor_device.Mesh.__new__(cls) - _pywrap_dtensor_device.Mesh.__init__(mesh, mesh_proto=proto) - return mesh + return cls._new_object(mesh_proto=proto) @classmethod def from_string(cls, mesh_str: str) -> 'Mesh': - mesh = _pywrap_dtensor_device.Mesh.__new__(cls) - _pywrap_dtensor_device.Mesh.__init__(mesh, mesh_str=mesh_str) - return mesh + return cls._new_object(mesh_str=mesh_str) @classmethod def from_device(cls, device: str) -> 'Mesh': """Constructs a single device mesh from a device string.""" - mesh = _pywrap_dtensor_device.Mesh.__new__(cls) - _pywrap_dtensor_device.Mesh.__init__(mesh, single_device=device) - return mesh + return cls._new_object(single_device=device) + + @classmethod + def _from_mesh(cls, mesh: _pywrap_dtensor_device.Mesh): + """Creates a copy from an existing pywrap mesh object.""" + return cls._new_object(mesh=mesh) @functools.cached_property def _host_mesh(self) -> 'Mesh': - return Mesh.from_string(super().host_mesh().to_string()) + return Mesh._from_mesh(super().host_mesh()) def host_mesh(self) -> 'Mesh': """Returns a host mesh.""" @@ -426,6 +433,14 @@ def __init__(self, sharding_specs: List[str], mesh: Mesh): super().__init__(sharding_specs=sharding_specs, mesh=mesh) + @classmethod + def _new_object(cls, *args, **kwargs): + # Need to explicitly invoke the base class __init__ because + # Layout.__init__ overrode it with a different signature. + self = _pywrap_dtensor_device.Layout.__new__(cls) + super().__init__(self, *args, **kwargs) + return self + def __repr__(self) -> str: return f'Layout(sharding_specs={self.sharding_specs}, mesh={self.mesh})' @@ -436,10 +451,9 @@ def __hash__(self): def __reduce__(self): return Layout.from_string, (self.to_string(),) - # TODO(b/242201545): Find a way to return Mesh object from the pywrap module. @property def mesh(self): - return Mesh.from_proto(super().mesh.as_proto()) + return Mesh._from_mesh(mesh=super().mesh) # pylint: disable=protected-access @property def shape(self): @@ -450,16 +464,13 @@ def batch_sharded( cls, mesh: Mesh, batch_dim: str, rank: int, axis: int = 0 ) -> 'Layout': """Returns a layout sharded on batch dimension.""" - layout_obj = _pywrap_dtensor_device.Layout.__new__(cls) - _pywrap_dtensor_device.Layout.__init__( + return cls._new_object( # Watchout for the different ordering. - layout_obj, mesh=mesh, rank=rank, batch_dim=batch_dim, axis=axis, ) - return layout_obj # TODO(b/242201545): Move this to C++ / find the corresponding function there. def delete(self, dims: List[int]) -> 'Layout': @@ -474,18 +485,12 @@ def delete(self, dims: List[int]) -> 'Layout': @classmethod def from_proto(cls, layout_proto: layout_pb2.LayoutProto) -> 'Layout': """Creates an instance from a LayoutProto.""" - layout_obj = _pywrap_dtensor_device.Layout.__new__(cls) - _pywrap_dtensor_device.Layout.__init__( - layout_obj, layout_proto=layout_proto - ) - return layout_obj + return cls._new_object(layout_proto=layout_proto) @classmethod def from_string(cls, layout_str: str) -> 'Layout': """Creates an instance from a human-readable string.""" - layout_obj = _pywrap_dtensor_device.Layout.__new__(cls) - _pywrap_dtensor_device.Layout.__init__(layout_obj, layout_str=layout_str) - return layout_obj + return cls._new_object(layout_str=layout_str) @classmethod def inner_sharded(cls, mesh: Mesh, inner_dim: str, rank: int) -> 'Layout': @@ -495,9 +500,7 @@ def inner_sharded(cls, mesh: Mesh, inner_dim: str, rank: int) -> 'Layout': @classmethod def from_single_device_mesh(cls, mesh: Mesh) -> 'Layout': """Constructs a single device layout from a single device mesh.""" - layout = _pywrap_dtensor_device.Layout.__new__(cls) - _pywrap_dtensor_device.Layout.__init__(layout, mesh=mesh) - return layout + return cls._new_object(mesh=mesh) @classmethod def from_device(cls, device: str) -> 'Layout': @@ -534,6 +537,4 @@ def offset_tuple_to_global_index(self, offset_tuple): @classmethod def replicated(cls, mesh: Mesh, rank: int) -> 'Layout': """Returns a replicated layout of rank `rank`.""" - layout_obj = _pywrap_dtensor_device.Layout.__new__(cls) - _pywrap_dtensor_device.Layout.__init__(layout_obj, mesh=mesh, rank=rank) - return layout_obj + return cls._new_object(mesh=mesh, rank=rank) diff --git a/tensorflow/python/pywrap_dtensor_device.cc b/tensorflow/python/pywrap_dtensor_device.cc index df42d1ec27eea5..842abaf8393f7c 100644 --- a/tensorflow/python/pywrap_dtensor_device.cc +++ b/tensorflow/python/pywrap_dtensor_device.cc @@ -395,6 +395,8 @@ PYBIND11_MODULE(_pywrap_dtensor_device, m) { tensor_handle, element_layouts, device_info, status.get()); }); py::class_(m, "Mesh") + .def(py::init([](Mesh& mesh) { return mesh; }), py::arg("mesh"), + "Create a copy of a mesh.") .def(py::init(&Mesh::CreateMesh)) .def(py::init([](absl::string_view single_device) { auto mesh = Mesh::GetSingleDeviceMesh(single_device); From 3e341088f8d15355a52cf375891af07b2c141c1f Mon Sep 17 00:00:00 2001 From: Fiona Lang Date: Tue, 11 Jul 2023 10:49:50 -0700 Subject: [PATCH 138/376] Update ops.Tensor references to //third_party/tensorflow/python/framework/tensor.py. PiperOrigin-RevId: 547235969 --- tensorflow/core/function/transform/BUILD | 2 +- .../core/function/transform/transform.py | 7 +++--- tensorflow/python/autograph/converters/BUILD | 2 +- .../python/autograph/converters/lists_test.py | 4 ++-- tensorflow/python/autograph/operators/BUILD | 2 +- .../operators/data_structures_test.py | 4 ++-- tensorflow/python/client/BUILD | 1 + tensorflow/python/client/session.py | 7 +++--- tensorflow/python/compiler/tensorrt/BUILD | 1 + .../python/compiler/tensorrt/trt_convert.py | 5 ++-- .../python/compiler/xla/experimental/BUILD | 2 +- .../xla/experimental/xla_sharding_test.py | 6 ++--- tensorflow/python/debug/cli/BUILD | 3 ++- tensorflow/python/debug/cli/cli_shared.py | 5 +++- .../python/distribute/coordinator/BUILD | 1 + .../coordinator/fault_tolerance_test_base.py | 5 ++-- .../python/distribute/parallel_device/BUILD | 1 + .../parallel_device/parallel_device.py | 24 ++++++++++++++----- tensorflow/python/distribute/v1/BUILD | 1 + .../distribute/v1/cross_device_ops_test.py | 5 ++-- tensorflow/python/feature_column/BUILD | 1 + .../feature_column/feature_column_v2.py | 3 ++- .../kernel_tests/tensor_priority_test.py | 3 ++- tensorflow/python/types/core.py | 2 +- .../api/lib/python_object_to_proto_visitor.py | 6 ++--- 25 files changed, 66 insertions(+), 37 deletions(-) diff --git a/tensorflow/core/function/transform/BUILD b/tensorflow/core/function/transform/BUILD index 8bedf764edb48d..91e7302db8229a 100644 --- a/tensorflow/core/function/transform/BUILD +++ b/tensorflow/core/function/transform/BUILD @@ -27,7 +27,7 @@ pytype_strict_library( "//tensorflow/python/framework:func_graph", "//tensorflow/python/framework:function_def_to_graph", "//tensorflow/python/framework:ops", - "//tensorflow/python/framework:tensor_spec", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:custom_gradient", "//tensorflow/python/ops:default_gradient", "//tensorflow/python/ops:handle_data_util", diff --git a/tensorflow/core/function/transform/transform.py b/tensorflow/core/function/transform/transform.py index a4a0ecd77a407f..2d259ddae0d2d0 100644 --- a/tensorflow/core/function/transform/transform.py +++ b/tensorflow/core/function/transform/transform.py @@ -25,13 +25,14 @@ from tensorflow.python.framework import func_graph as func_graph_module from tensorflow.python.framework import function_def_to_graph as function_def_lib from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.ops import custom_gradient as custom_gradient_lib from tensorflow.python.ops import default_gradient from tensorflow.python.ops import handle_data_util from tensorflow.python.platform import tf_logging from tensorflow.python.util import compat -_TensorType = Union[ops.EagerTensor, ops.Tensor] +_TensorType = Union[ops.EagerTensor, tensor.Tensor] _FunctionDefTransformerType = Callable[[function_pb2.FunctionDef], None] @@ -233,8 +234,8 @@ def add(x, y): # Set handle data. for i, output in enumerate(cf.outputs): func_graph_output = func_graph.outputs[i] - if isinstance(output, ops.Tensor) and isinstance( - func_graph_output, ops.Tensor + if isinstance(output, tensor.Tensor) and isinstance( + func_graph_output, tensor.Tensor ): func_graph_output.set_shape(output.shape) handle_data_util.copy_handle_data(output, func_graph_output) diff --git a/tensorflow/python/autograph/converters/BUILD b/tensorflow/python/autograph/converters/BUILD index 2261a63a93c29e..70d4bab1d2be48 100644 --- a/tensorflow/python/autograph/converters/BUILD +++ b/tensorflow/python/autograph/converters/BUILD @@ -376,7 +376,7 @@ py_strict_test( "//tensorflow/python/autograph/lang:directives", "//tensorflow/python/autograph/lang:special_functions", "//tensorflow/python/framework:dtypes", - "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:array_ops_stack", "//tensorflow/python/ops:list_ops", "//tensorflow/python/platform:client_testlib", diff --git a/tensorflow/python/autograph/converters/lists_test.py b/tensorflow/python/autograph/converters/lists_test.py index 43dfa5f48e1622..a6613f47ab7d16 100644 --- a/tensorflow/python/autograph/converters/lists_test.py +++ b/tensorflow/python/autograph/converters/lists_test.py @@ -20,7 +20,7 @@ from tensorflow.python.autograph.lang import directives from tensorflow.python.autograph.lang import special_functions from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.ops import array_ops_stack from tensorflow.python.ops import list_ops from tensorflow.python.platform import test @@ -37,7 +37,7 @@ def f(): tl = tr() # Empty tensor lists cannot be evaluated or stacked. - self.assertIsInstance(tl, ops.Tensor) + self.assertIsInstance(tl, tensor.Tensor) self.assertEqual(tl.dtype, dtypes.variant) def test_initialized_list(self): diff --git a/tensorflow/python/autograph/operators/BUILD b/tensorflow/python/autograph/operators/BUILD index cde1fbf8bf2daf..765ab5fa24b67c 100644 --- a/tensorflow/python/autograph/operators/BUILD +++ b/tensorflow/python/autograph/operators/BUILD @@ -154,7 +154,7 @@ py_strict_test( ":data_structures", "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", - "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:list_ops", "//tensorflow/python/ops:tensor_array_ops", diff --git a/tensorflow/python/autograph/operators/data_structures_test.py b/tensorflow/python/autograph/operators/data_structures_test.py index 599d0a21e10ef5..707406b9651cda 100644 --- a/tensorflow/python/autograph/operators/data_structures_test.py +++ b/tensorflow/python/autograph/operators/data_structures_test.py @@ -17,7 +17,7 @@ from tensorflow.python.autograph.operators import data_structures from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import list_ops from tensorflow.python.ops import tensor_array_ops @@ -30,7 +30,7 @@ def test_new_list_empty(self): l = data_structures.new_list() # Can't evaluate an empty list. # TODO(mdan): sess.run should allow tf.variant maybe? - self.assertTrue(isinstance(l, ops.Tensor)) + self.assertTrue(isinstance(l, tensor.Tensor)) def test_new_list_tensor(self): l = data_structures.new_list([3, 4, 5]) diff --git a/tensorflow/python/client/BUILD b/tensorflow/python/client/BUILD index 1f26969b32b394..0a9fedc17725e0 100644 --- a/tensorflow/python/client/BUILD +++ b/tensorflow/python/client/BUILD @@ -273,6 +273,7 @@ py_strict_library( "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", "//tensorflow/python/framework:stack", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:session_ops", "//tensorflow/python/platform:tf_logging", "//tensorflow/python/training/experimental:mixed_precision_global_state", diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py index 6c68d0c17595f0..8c0fec1591bf79 100644 --- a/tensorflow/python/client/session.py +++ b/tensorflow/python/client/session.py @@ -35,6 +35,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import stack +from tensorflow.python.framework import tensor from tensorflow.python.ops import session_ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training.experimental import mixed_precision_global_state @@ -502,7 +503,7 @@ def __init__(self, graph, fetches, feeds, feed_handles=None): self._fetches.append(fetch) self._ops.append(False) # Remember the fetch if it is for a tensor handle. - if (isinstance(fetch, ops.Tensor) and + if (isinstance(fetch, tensor.Tensor) and (fetch.op.type == 'GetSessionHandle' or fetch.op.type == 'GetSessionHandleV2')): self._fetch_handles[fetch.ref()] = fetch.op.inputs[0].dtype @@ -1158,7 +1159,7 @@ def _feed_fn(feed, feed_val): raise TypeError( f'Cannot interpret feed_dict key as Tensor: {e.args[0]}') - if isinstance(subfeed_val, ops.Tensor): + if isinstance(subfeed_val, tensor.Tensor): raise TypeError( 'The value of a feed cannot be a tf.Tensor object. Acceptable ' 'feed values include Python scalars, strings, lists, numpy ' @@ -1322,7 +1323,7 @@ def _single_operation_run(): self._call_tf_sessionrun(None, {}, [], target_list, None) return _single_operation_run - elif isinstance(fetches, ops.Tensor): + elif isinstance(fetches, tensor.Tensor): # Special case for fetching a single tensor, because the # function can return the result of `TF_Run()` directly. assert len(fetch_list) == 1 diff --git a/tensorflow/python/compiler/tensorrt/BUILD b/tensorflow/python/compiler/tensorrt/BUILD index 89ee69d2637dc1..f3fd845ff53b10 100644 --- a/tensorflow/python/compiler/tensorrt/BUILD +++ b/tensorflow/python/compiler/tensorrt/BUILD @@ -49,6 +49,7 @@ py_strict_library( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:errors", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/grappler:tf_optimizer", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:resource_variable_ops_gen", diff --git a/tensorflow/python/compiler/tensorrt/trt_convert.py b/tensorflow/python/compiler/tensorrt/trt_convert.py index 392869eee6d0f5..2f80952e303852 100644 --- a/tensorflow/python/compiler/tensorrt/trt_convert.py +++ b/tensorflow/python/compiler/tensorrt/trt_convert.py @@ -37,6 +37,7 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import importer from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.grappler import tf_optimizer from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_resource_variable_ops @@ -562,7 +563,7 @@ def _add_nodes_denylist(self): collection_def = self._grappler_meta_graph_def.collection_def["train_op"] denylist = collection_def.node_list.value for i in self._nodes_denylist: - if isinstance(i, ops.Tensor): + if isinstance(i, tensor.Tensor): denylist.append(_to_bytes(i.name)) else: denylist.append(_to_bytes(i)) @@ -692,7 +693,7 @@ def calibrate(self, for k, v in input_map_fn().items(): if not isinstance(k, str): raise ValueError("Keys of input_map_fn must be of type str") - if not isinstance(v, ops.Tensor): + if not isinstance(v, tensor.Tensor): raise ValueError("Values of input_map_fn must be of type tf.Tensor") self._calibration_graph = ops.Graph() diff --git a/tensorflow/python/compiler/xla/experimental/BUILD b/tensorflow/python/compiler/xla/experimental/BUILD index 595897108d19c8..018e4daedb1292 100644 --- a/tensorflow/python/compiler/xla/experimental/BUILD +++ b/tensorflow/python/compiler/xla/experimental/BUILD @@ -29,7 +29,7 @@ py_strict_test( "//tensorflow/compiler/xla:xla_data_proto_py", "//tensorflow/python/eager:def_function", "//tensorflow/python/framework:dtypes", - "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:array_ops", "//third_party/py/numpy", diff --git a/tensorflow/python/compiler/xla/experimental/xla_sharding_test.py b/tensorflow/python/compiler/xla/experimental/xla_sharding_test.py index 2f0281e99b21de..924f737c81fb48 100644 --- a/tensorflow/python/compiler/xla/experimental/xla_sharding_test.py +++ b/tensorflow/python/compiler/xla/experimental/xla_sharding_test.py @@ -22,7 +22,7 @@ from tensorflow.python.compiler.xla.experimental import xla_sharding from tensorflow.python.eager import def_function from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops @@ -90,7 +90,7 @@ def test_tile_annotates_tensor_correctly(self): def tile_helper(tensor): self.assertIsNone(xla_sharding.get_tensor_sharding(tensor)) tiled_tensor = xla_sharding.tile(tensor, np.array([2, 1, 6])) - self.assertIsInstance(tiled_tensor, ops.Tensor) + self.assertIsInstance(tiled_tensor, tensor_lib.Tensor) tiled_sharding = xla_sharding.get_tensor_sharding(tiled_tensor) tile_shape = xla_sharding.get_sharding_tile_shape(tiled_sharding) # This is the shape of the tile assignment [2, 1, 6] @@ -108,7 +108,7 @@ def test_split_annotates_tensor_correctly(self): def split_helper(tensor): self.assertIsNone(xla_sharding.get_tensor_sharding(tensor)) split_tensor = xla_sharding.split(tensor, 2, 3) - self.assertIsInstance(split_tensor, ops.Tensor) + self.assertIsInstance(split_tensor, tensor_lib.Tensor) split_sharding = xla_sharding.get_tensor_sharding(split_tensor) split_shape = xla_sharding.get_sharding_tile_shape(split_sharding) expected_shape = [1, 1, 3] diff --git a/tensorflow/python/debug/cli/BUILD b/tensorflow/python/debug/cli/BUILD index 98f8badc98f18a..f3b8a53513de72 100644 --- a/tensorflow/python/debug/cli/BUILD +++ b/tensorflow/python/debug/cli/BUILD @@ -78,7 +78,8 @@ py_strict_library( ":debugger_cli_common", ":tensor_format", "//tensorflow/python/debug/lib:common", - "//tensorflow/python/framework:for_generated_wrappers", + "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:variables", "//tensorflow/python/platform:gfile", "//third_party/py/numpy", diff --git a/tensorflow/python/debug/cli/cli_shared.py b/tensorflow/python/debug/cli/cli_shared.py index 3c5e21a2ff0bac..69466d446c1469 100644 --- a/tensorflow/python/debug/cli/cli_shared.py +++ b/tensorflow/python/debug/cli/cli_shared.py @@ -22,6 +22,7 @@ from tensorflow.python.debug.cli import tensor_format from tensorflow.python.debug.lib import common from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.ops import variables from tensorflow.python.platform import gfile @@ -402,7 +403,9 @@ def get_run_short_description(run_call_count, description = "run #%d: " % run_call_count - if isinstance(fetches, (ops.Tensor, ops.Operation, variables.Variable)): + if isinstance( + fetches, (tensor_lib.Tensor, ops.Operation, variables.Variable) + ): description += "1 fetch (%s); " % common.get_graph_element_name(fetches) else: # Could be (nested) list, tuple, dict or namedtuple. diff --git a/tensorflow/python/distribute/coordinator/BUILD b/tensorflow/python/distribute/coordinator/BUILD index aa9f71c740449f..b3c9f3f836ae80 100644 --- a/tensorflow/python/distribute/coordinator/BUILD +++ b/tensorflow/python/distribute/coordinator/BUILD @@ -152,6 +152,7 @@ py_strict_library( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:errors", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:check_ops", "//tensorflow/python/ops:math_ops", diff --git a/tensorflow/python/distribute/coordinator/fault_tolerance_test_base.py b/tensorflow/python/distribute/coordinator/fault_tolerance_test_base.py index 510f04f29bbb31..ca68282ca0ba3f 100644 --- a/tensorflow/python/distribute/coordinator/fault_tolerance_test_base.py +++ b/tensorflow/python/distribute/coordinator/fault_tolerance_test_base.py @@ -31,6 +31,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import math_ops @@ -540,14 +541,14 @@ def worker_fn(): # Attempt to fetch before killing worker task should succeed. fetched = remote_value.get()[0] - self.assertIsInstance(fetched, ops.Tensor) + self.assertIsInstance(fetched, tensor.Tensor) self.assertEqual(fetched.device, "/job:chief/replica:0/task:0/device:CPU:0") self.assertEqual((1, -1), remote_value.get()) remote_value.get()[0].numpy() # As well as the remote tensors that point to worker0 or worker1. values = remote_value._values[0] - self.assertIsInstance(values, ops.Tensor) + self.assertIsInstance(values, tensor.Tensor) self.assertRegex(values.device, "/job:worker/replica:0/task:[0-1]/device:CPU:0") self.assertEqual((1, -1), remote_value._values) diff --git a/tensorflow/python/distribute/parallel_device/BUILD b/tensorflow/python/distribute/parallel_device/BUILD index 3341e9557669f8..de6f9a3220a08d 100644 --- a/tensorflow/python/distribute/parallel_device/BUILD +++ b/tensorflow/python/distribute/parallel_device/BUILD @@ -27,6 +27,7 @@ py_strict_library( "//tensorflow/python/framework:composite_tensor", "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:variables", "//tensorflow/python/tpu/ops", diff --git a/tensorflow/python/distribute/parallel_device/parallel_device.py b/tensorflow/python/distribute/parallel_device/parallel_device.py index c3255c57aba9b2..771925efa2372c 100644 --- a/tensorflow/python/distribute/parallel_device/parallel_device.py +++ b/tensorflow/python/distribute/parallel_device/parallel_device.py @@ -23,6 +23,7 @@ from tensorflow.python.framework import composite_tensor from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.ops import array_ops from tensorflow.python.ops import variables from tensorflow.python.tpu.ops import tpu_ops @@ -83,8 +84,14 @@ def __init__(self, components): def _pack_tensor(self, *tensors): """Helper to pack plain-old-tensors, not structures or composites.""" for tensor in tensors: - if not isinstance(tensor, (ops.Tensor, composite_tensor.CompositeTensor, - variables.Variable)): + if not isinstance( + tensor, + ( + tensor_lib.Tensor, + composite_tensor.CompositeTensor, + variables.Variable, + ), + ): raise ValueError( ("Every component to pack onto the ParallelDevice must already be " "a tensor, got {}. Consider running `tf.constant` or " @@ -129,10 +136,15 @@ def pack(self, tensors): def _unpack_tensor(self, parallel_tensor): """Helper to unpack a single tensor.""" - if not isinstance(parallel_tensor, ( - ops.Tensor, composite_tensor.CompositeTensor, variables.Variable)): - raise ValueError( - "Expected a tensor, got {}.".format(parallel_tensor)) + if not isinstance( + parallel_tensor, + ( + tensor_lib.Tensor, + composite_tensor.CompositeTensor, + variables.Variable, + ), + ): + raise ValueError("Expected a tensor, got {}.".format(parallel_tensor)) with ops.device(self._name): return tpu_ops.tpu_replicated_output( parallel_tensor, num_replicas=len(self.components)) diff --git a/tensorflow/python/distribute/v1/BUILD b/tensorflow/python/distribute/v1/BUILD index e88086280e7677..59f19db8b4e11c 100644 --- a/tensorflow/python/distribute/v1/BUILD +++ b/tensorflow/python/distribute/v1/BUILD @@ -37,6 +37,7 @@ cuda_py_strict_test( "//tensorflow/python/framework:indexed_slices", "//tensorflow/python/framework:kernels", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:collective_ops", "//tensorflow/python/ops:math_ops", diff --git a/tensorflow/python/distribute/v1/cross_device_ops_test.py b/tensorflow/python/distribute/v1/cross_device_ops_test.py index fa59aba7f52c79..360cec0bd5ce71 100644 --- a/tensorflow/python/distribute/v1/cross_device_ops_test.py +++ b/tensorflow/python/distribute/v1/cross_device_ops_test.py @@ -41,6 +41,7 @@ from tensorflow.python.framework import indexed_slices as indexed_slices_lib from tensorflow.python.framework import kernels from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.ops import array_ops from tensorflow.python.ops import collective_ops from tensorflow.python.ops import math_ops @@ -52,7 +53,7 @@ def _get_devices(devices): return tuple(device_util.resolve(d) for d in devices) elif isinstance(devices, value_lib.DistributedValues): return devices._devices - elif isinstance(devices, ops.Tensor): + elif isinstance(devices, tensor_lib.Tensor): return (device_util.resolve(devices.device),) return (device_util.resolve(devices),) @@ -422,7 +423,7 @@ def testReduceDistributedVariable(self, distribution, else: result = cross_device_ops_instance.reduce(reduce_util.ReduceOp.MEAN, v, v) for v in result.values: - self.assertIsInstance(v, ops.Tensor) + self.assertIsInstance(v, tensor_lib.Tensor) self.evaluate(variables.global_variables_initializer()) self.assertAllEqual(self.evaluate(result.values), [1.0, 1.0]) diff --git a/tensorflow/python/feature_column/BUILD b/tensorflow/python/feature_column/BUILD index b1b54e34e068b3..9bb28913c72198 100644 --- a/tensorflow/python/feature_column/BUILD +++ b/tensorflow/python/feature_column/BUILD @@ -76,6 +76,7 @@ py_strict_library( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:array_ops_stack", diff --git a/tensorflow/python/feature_column/feature_column_v2.py b/tensorflow/python/feature_column/feature_column_v2.py index 4613bf386228d9..c1e5e22867cf82 100644 --- a/tensorflow/python/feature_column/feature_column_v2.py +++ b/tensorflow/python/feature_column/feature_column_v2.py @@ -141,6 +141,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops_stack @@ -1502,7 +1503,7 @@ def categorical_column_with_vocabulary_file_v2(key, 'in the vocabulary_file %s.', vocabulary_size, key, vocabulary_file) # `vocabulary_size` isn't required for lookup, but it is for `_num_buckets`. - if not isinstance(vocabulary_size, ops.Tensor) and vocabulary_size < 1: + if not isinstance(vocabulary_size, tensor_lib.Tensor) and vocabulary_size < 1: raise ValueError('Invalid vocabulary_size in {}.'.format(key)) if num_oov_buckets: if default_value is not None: diff --git a/tensorflow/python/kernel_tests/tensor_priority_test.py b/tensorflow/python/kernel_tests/tensor_priority_test.py index bb779f26eff30c..4111fa080f592a 100644 --- a/tensorflow/python/kernel_tests/tensor_priority_test.py +++ b/tensorflow/python/kernel_tests/tensor_priority_test.py @@ -16,6 +16,7 @@ import numpy as np from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_conversion_registry from tensorflow.python.platform import test as test_lib @@ -34,7 +35,7 @@ class NumpyArraySubclass(np.ndarray): for rhs in supported_rhs_without_delegation: tensor = ops.convert_to_tensor([[10.0, 20.0]]) res = tensor + rhs - self.assertIsInstance(res, ops.Tensor) + self.assertIsInstance(res, tensor_lib.Tensor) def testUnsupportedRhsWithoutDelegation(self): diff --git a/tensorflow/python/types/core.py b/tensorflow/python/types/core.py index f6d6158a12f22a..83b67ef71ba3fd 100644 --- a/tensorflow/python/types/core.py +++ b/tensorflow/python/types/core.py @@ -56,7 +56,7 @@ def shape(self): pass -# `ops.EagerTensor` subclasses `Symbol` by way of subclassing `ops.Tensor`; +# `ops.EagerTensor` subclasses `Symbol` by way of subclassing `tensor.Tensor`; # care should be taken when performing `isinstance` checks on `Value`, e.g.: # # ``` diff --git a/tensorflow/tools/api/lib/python_object_to_proto_visitor.py b/tensorflow/tools/api/lib/python_object_to_proto_visitor.py index 3066c0e597f3f0..0212b07d7600cf 100644 --- a/tensorflow/tools/api/lib/python_object_to_proto_visitor.py +++ b/tensorflow/tools/api/lib/python_object_to_proto_visitor.py @@ -89,9 +89,9 @@ def _SkipMember(cls, member): # pylint: disable=unused-argument # Differences created by typing implementations. -_NORMALIZE_TYPE[( - 'tensorflow.python.framework.ops.Tensor')] = ( - "") +_NORMALIZE_TYPE[ + 'tensorflow.python.framework.tensor.Tensor' +] = "" _NORMALIZE_TYPE['typing.Generic'] = "" # TODO(b/203104448): Remove once the golden files are generated in Python 3.7. _NORMALIZE_TYPE[""] = 'typing.Union' From 4d4cc6d06e0aa55c15dfdea61c8b57e5f8a50cf3 Mon Sep 17 00:00:00 2001 From: Kuangyuan Chen Date: Tue, 11 Jul 2023 10:56:17 -0700 Subject: [PATCH 139/376] Propagate cancellation manager through BatchFunction op PiperOrigin-RevId: 547238051 --- .../kernel/kernel_fallback_compat_request_state.cc | 4 +++- .../kernel/kernel_fallback_compat_request_state.h | 11 ++++++----- .../runtime/runtime_fallback_batch_tf_opkernels.cc | 6 +++++- .../core/runtime_fallback/util/fallback_test_util.cc | 6 +++++- tensorflow/core/tfrt/mlrt/kernel/batch_kernel.cc | 3 +++ 5 files changed, 22 insertions(+), 8 deletions(-) diff --git a/tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.cc b/tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.cc index 363f02a89358c0..f9a404d560e7b3 100644 --- a/tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.cc +++ b/tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.cc @@ -157,7 +157,8 @@ Status SetUpKernelFallbackCompatRequestContext( const absl::optional& model_metadata, std::function)>* runner, tfrt_stub::CostRecorder* cost_recorder, - tfrt::ResourceContext* client_graph_resource_context) { + tfrt::ResourceContext* client_graph_resource_context, + tensorflow::CancellationManager* cancellation_manager) { DCHECK(builder); DCHECK(device_manager); DCHECK(pflr); @@ -173,6 +174,7 @@ Status SetUpKernelFallbackCompatRequestContext( fallback_request_state.set_cost_recorder(cost_recorder); fallback_request_state.set_client_graph_resource_context( client_graph_resource_context); + fallback_request_state.set_cancellation_manager(cancellation_manager); return OkStatus(); } diff --git a/tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h b/tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h index 94df91109c06bd..a37c772b978c43 100644 --- a/tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h +++ b/tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h @@ -235,11 +235,12 @@ Status SetUpKernelFallbackCompatRequestContext( const tensorflow::ProcessFunctionLibraryRuntime* pflr, tfrt_stub::OpKernelRunnerTable* runner_table, FallbackResourceArray* resource_array, - tensorflow::thread::ThreadPoolInterface* user_intra_op_threadpool = nullptr, - const std::optional& model_metadata = std::nullopt, - std::function)>* runner = nullptr, - tfrt_stub::CostRecorder* cost_recorder = nullptr, - tfrt::ResourceContext* client_graph_resource_context = nullptr); + tensorflow::thread::ThreadPoolInterface* user_intra_op_threadpool, + const std::optional& model_metadata, + std::function)>* runner, + tfrt_stub::CostRecorder* cost_recorder, + tfrt::ResourceContext* client_graph_resource_context, + tensorflow::CancellationManager* cancellation_manager); } // namespace tfd } // namespace tensorflow diff --git a/tensorflow/core/runtime_fallback/runtime/runtime_fallback_batch_tf_opkernels.cc b/tensorflow/core/runtime_fallback/runtime/runtime_fallback_batch_tf_opkernels.cc index 426f189a503224..1facf07d621831 100644 --- a/tensorflow/core/runtime_fallback/runtime/runtime_fallback_batch_tf_opkernels.cc +++ b/tensorflow/core/runtime_fallback/runtime/runtime_fallback_batch_tf_opkernels.cc @@ -267,7 +267,11 @@ Status SetUpKernelFallbackCompatRequestContextForBatch( return SetUpKernelFallbackCompatRequestContext( builder, device_manager, pflr, runner_table, resource_array, - intra_op_threadpool, session_metadata, /*runner=*/nullptr); + intra_op_threadpool, session_metadata, + src_fallback_request_state->runner(), + src_fallback_request_state->cost_recorder(), + src_fallback_request_state->client_graph_resource_context(), + src_fallback_request_state->cancellation_manager()); } StatusOr> SetUpRequestContext( diff --git a/tensorflow/core/runtime_fallback/util/fallback_test_util.cc b/tensorflow/core/runtime_fallback/util/fallback_test_util.cc index 3e451b3199408d..9d5029d747aa5c 100644 --- a/tensorflow/core/runtime_fallback/util/fallback_test_util.cc +++ b/tensorflow/core/runtime_fallback/util/fallback_test_util.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/runtime_fallback/util/fallback_test_util.h" #include +#include #include #include "tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_request_context.h" @@ -69,7 +70,10 @@ tfrt::ExecutionContext CreateFallbackTestExecutionContext( status = SetUpKernelFallbackCompatRequestContext( &request_context_builder, eager_context->local_device_mgr(), eager_context->pflr(), runner_table, resource_array, - user_intra_op_threadpool); + user_intra_op_threadpool, /*model_metadata=*/std::nullopt, + /*runner=*/nullptr, /*cost_recorder=*/nullptr, + /*client_graph_resource_context=*/resource_context, + /*cancellation_manager=*/nullptr); TF_DCHECK_OK(status); status = SetUpTfJitRtRequestContext(&request_context_builder); diff --git a/tensorflow/core/tfrt/mlrt/kernel/batch_kernel.cc b/tensorflow/core/tfrt/mlrt/kernel/batch_kernel.cc index 0d274cab8ada57..796a484dfa4665 100644 --- a/tensorflow/core/tfrt/mlrt/kernel/batch_kernel.cc +++ b/tensorflow/core/tfrt/mlrt/kernel/batch_kernel.cc @@ -322,6 +322,9 @@ void MlrtBatchResource::ProcessFuncBatchImpl( fallback_request_state.set_client_graph_resource_context( caller_fallback_request_state.client_graph_resource_context()); + fallback_request_state.set_cancellation_manager( + caller_fallback_request_state.cancellation_manager()); + tensorflow::profiler::TraceMeProducer activity( // To TraceMeConsumers in WorkQueue. [step_id] { From 5c0549ccd18d28d6d328166255eff8ded402415f Mon Sep 17 00:00:00 2001 From: Fiona Lang Date: Tue, 11 Jul 2023 11:23:38 -0700 Subject: [PATCH 140/376] Move dlpack imports from python/__init__.py to python/modules_with_exports.py. PiperOrigin-RevId: 547246489 --- tensorflow/python/BUILD | 1 + tensorflow/python/__init__.py | 4 ---- tensorflow/python/modules_with_exports.py | 4 ++++ 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index f23e2de226a2f9..7aeba32095772a 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -352,6 +352,7 @@ py_library( "//tensorflow/python/distribute:sharded_variable", "//tensorflow/python/distribute/failure_handling:failure_handling_lib", "//tensorflow/python/distribute/failure_handling:preemption_watcher", + "//tensorflow/python/dlpack", "//tensorflow/python/eager:context", "//tensorflow/python/eager:def_function", "//tensorflow/python/eager:monitoring", diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py index 411e32a440f901..174befc5b6c522 100644 --- a/tensorflow/python/__init__.py +++ b/tensorflow/python/__init__.py @@ -77,10 +77,6 @@ from tensorflow.python.debug.lib import dumping_callback from tensorflow.python.ops import gen_debug_ops -# DLPack -from tensorflow.python.dlpack.dlpack import from_dlpack -from tensorflow.python.dlpack.dlpack import to_dlpack - # XLA JIT compiler APIs. from tensorflow.python.compiler.xla import jit from tensorflow.python.compiler.xla import xla diff --git a/tensorflow/python/modules_with_exports.py b/tensorflow/python/modules_with_exports.py index 824692d450283c..b65d583d8d1971 100644 --- a/tensorflow/python/modules_with_exports.py +++ b/tensorflow/python/modules_with_exports.py @@ -37,6 +37,10 @@ # Distribute from tensorflow.python import distribute +# DLPack +from tensorflow.python.dlpack.dlpack import from_dlpack +from tensorflow.python.dlpack.dlpack import to_dlpack + # Eager from tensorflow.python.eager import context from tensorflow.python.eager import def_function From c3ce6627361e744850ab82701229a530a6a0d99d Mon Sep 17 00:00:00 2001 From: Alan Liu Date: Tue, 11 Jul 2023 11:46:00 -0700 Subject: [PATCH 141/376] Update Partition() (graph_partition.cc) to handle debug_info in the GraphDef. When partitioning a graph, make it build the GraphDebugInfo stack traces for the nodes in each partitioned graph into the GraphDef for that partitioned graph. If the Graph has no stack traces, then no work is done beyond allocating an empty hash map and doing some failing lookups in it. To enable testing, added a new parameter to ToGraphDef() and related utility methods converting Graph to GraphDef, include_debug_info. When set to true, these methods populate the GraphDef's debug_info field. PiperOrigin-RevId: 547253176 --- tensorflow/cc/framework/scope.cc | 4 +- tensorflow/cc/framework/scope.h | 6 +- tensorflow/core/graph/graph.cc | 12 +- tensorflow/core/graph/graph.h | 16 ++- tensorflow/core/graph/graph_partition.cc | 27 +++++ tensorflow/core/graph/graph_partition_test.cc | 111 +++++++++++++++++- 6 files changed, 164 insertions(+), 12 deletions(-) diff --git a/tensorflow/cc/framework/scope.cc b/tensorflow/cc/framework/scope.cc index d9ed820e15e24e..6667b6919d52e6 100644 --- a/tensorflow/cc/framework/scope.cc +++ b/tensorflow/cc/framework/scope.cc @@ -306,11 +306,11 @@ void Scope::UpdateStatus(const Status& s) const { } } -Status Scope::ToGraphDef(GraphDef* gdef) const { +Status Scope::ToGraphDef(GraphDef* gdef, bool include_debug_info) const { if (!ok()) { return *impl()->status_; } - graph()->ToGraphDef(gdef); + graph()->ToGraphDef(gdef, /*include_flib_def=*/true, include_debug_info); return OkStatus(); } diff --git a/tensorflow/cc/framework/scope.h b/tensorflow/cc/framework/scope.h index 777d0ed6c01e39..771fdaa11688c9 100644 --- a/tensorflow/cc/framework/scope.h +++ b/tensorflow/cc/framework/scope.h @@ -200,8 +200,10 @@ class Scope { /// If status() is ok, convert the Graph object stored in this scope /// to a GraphDef proto and return an ok Status. Otherwise, return the error - /// status as is without performing GraphDef conversion. - Status ToGraphDef(GraphDef* gdef) const; + /// status as is without performing GraphDef conversion. If + /// `include_debug_info` is true, populate the `debug_info` field of the + /// GraphDef from stack traces in this Graph. + Status ToGraphDef(GraphDef* gdef, bool include_debug_info = false) const; // START_SKIP_DOXYGEN diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc index c54f73e6e7a074..2477c3cb863de3 100644 --- a/tensorflow/core/graph/graph.cc +++ b/tensorflow/core/graph/graph.cc @@ -826,8 +826,10 @@ void AddInput(NodeDef* dst, StringPiece src_name, int src_slot) { } // namespace -void Graph::ToGraphDef(GraphDef* graph_def, bool include_flib_def) const { - ToGraphDefSubRange(graph_def, /*from_node_id=*/0, include_flib_def); +void Graph::ToGraphDef(GraphDef* graph_def, bool include_flib_def, + bool include_debug_info) const { + ToGraphDefSubRange(graph_def, /*from_node_id=*/0, include_flib_def, + include_debug_info); } GraphDef Graph::ToGraphDefDebug() const { @@ -837,13 +839,17 @@ GraphDef Graph::ToGraphDefDebug() const { } void Graph::ToGraphDefSubRange(GraphDef* graph_def, int from_node_id, - bool include_flib_def) const { + bool include_flib_def, + bool include_debug_info) const { graph_def->Clear(); *graph_def->mutable_versions() = versions(); if (include_flib_def) { *graph_def->mutable_library() = ops_.ToProto(); } + if (include_debug_info) { + *graph_def->mutable_debug_info() = BuildDebugInfo(); + } graph_def->mutable_node()->Reserve(std::max(1, num_nodes() - from_node_id)); diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h index 617f071a6333a3..0c83580b15134e 100644 --- a/tensorflow/core/graph/graph.h +++ b/tensorflow/core/graph/graph.h @@ -677,8 +677,14 @@ class Graph { // contain references to functions whose definition is not included. It can // make sense to do this in cases where the caller already has a copy of the // function library. + // If `include_debug_info` is true, the `debug_info` field of the GraphDef + // will be populated with stack traces from the nodes and the function + // library. Note that if `include_debug_info` is true and `include_flib_def` + // is false, then `debug_info` will contain stack traces for nodes in the + // function library, which will not itself be included in the GraphDef. void ToGraphDefSubRange(GraphDef* graph_def, int from_node_id, - bool include_flib_def = true) const; + bool include_flib_def = true, + bool include_debug_info = false) const; // Serialize to a GraphDef. `include_flib_def` indicates whether the function // library will be populated in the `graph_def`. `include_flib_def` should be @@ -687,7 +693,13 @@ class Graph { // `graph_def` is incomplete and may contain references to functions whose // definition is not included. It can make sense to do this in cases where the // caller already has a copy of the function library. - void ToGraphDef(GraphDef* graph_def, bool include_flib_def = true) const; + // If `include_debug_info` is true, the `debug_info` field of the GraphDef + // will be populated with stack traces from the nodes and the function + // library. Note that if `include_debug_info` is true and `include_flib_def` + // is false, then `debug_info` will contain stack traces for nodes in the + // function library, which will not itself be included in the GraphDef. + void ToGraphDef(GraphDef* graph_def, bool include_flib_def = true, + bool include_debug_info = false) const; // This version can be called from debugger to inspect the graph content. // Use the previous version outside debug context for efficiency reasons. diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc index a4f09383c63b5d..1a3f5216c6b26b 100644 --- a/tensorflow/core/graph/graph_partition.cc +++ b/tensorflow/core/graph/graph_partition.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/graph/graph_partition.h" #include +#include #include #include #include @@ -32,6 +33,7 @@ limitations under the License. #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/graph/costmodel.h" +#include "tensorflow/core/graph/graph_debug_info_builder.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/graph/tensor_id.h" @@ -980,7 +982,10 @@ void SetIncarnation(const PartitionOptions& opts, GraphDef* gdef) { Status Partition(const PartitionOptions& opts, Graph* g, std::unordered_map* partitions) { + // TODO(b/290689453) Refactor this into smaller functions Status status; + absl::flat_hash_map> + debug_info_builders; partitions->clear(); GraphInfo g_info; @@ -1219,6 +1224,19 @@ Status Partition(const PartitionOptions& opts, Graph* g, Graph::kControlSlot); } } + + // For each partition, lazily create a GraphDebugInfoBuilder. Gather stack + // traces for the nodes in that partition into the builder. + const std::shared_ptr& stack_trace = + dst->GetStackTrace(); + if (stack_trace != nullptr) { + std::unique_ptr& builder = + debug_info_builders[dstp]; + if (!builder) { + builder = std::make_unique(); + } + builder->AccumulateStackTrace(*stack_trace, dst->name()); + } } const FunctionLibraryDefinition* flib_def = opts.flib_def; @@ -1250,6 +1268,15 @@ Status Partition(const PartitionOptions& opts, Graph* g, VLOG(1) << "Added send/recv: controls=" << num_control << ", data=" << num_data; + // For each partition, build the GraphDebugInfo for all of its nodes' stack + // traces, and add it to the GraphDef for that partition. + for (auto& it : *partitions) { + const auto& builder_iter = debug_info_builders.find(it.first); + if (builder_iter != debug_info_builders.end()) { + GraphDef& gdef = it.second; + *gdef.mutable_debug_info() = builder_iter->second->Build(); + } + } if (VLOG_IS_ON(2)) { for (auto& it : *partitions) { GraphDef* gdef = &it.second; diff --git a/tensorflow/core/graph/graph_partition_test.cc b/tensorflow/core/graph/graph_partition_test.cc index e62d20456edfa0..51f59c02897a28 100644 --- a/tensorflow/core/graph/graph_partition_test.cc +++ b/tensorflow/core/graph/graph_partition_test.cc @@ -15,9 +15,14 @@ limitations under the License. #include "tensorflow/core/graph/graph_partition.h" +#include +#include #include #include +#include +#include +#include "absl/strings/str_cat.h" #include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/const_op.h" #include "tensorflow/cc/ops/control_flow_ops.h" @@ -58,9 +63,33 @@ using ops::Const; using ops::Identity; using ops::LoopCond; using ops::NextIteration; +using ::testing::Eq; +using ::testing::Ne; const char gpu_device[] = "/job:a/replica:0/task:0/device:GPU:0"; +class TestStackTrace : public AbstractStackTrace { + public: + explicit TestStackTrace(const std::vector frames) + : frames_(std::move(frames)) {} + + absl::Span ToFrames() const override { return frames_; } + + std::vector GetUserFrames(int limit) const override { + return frames_; + } + + StackFrame LastUserFrame() const override { return frames_.back(); } + + std::string ToString(const TracePrintingOptions& opts) const override { + auto frame = LastUserFrame(); + return absl::StrCat(frame.file_name, ":", frame.line_number, ":", + frame.function_name); + } + + std::vector frames_; +}; + string SplitByDevice(const Node* node) { return node->assigned_device_name(); } string DeviceName(const Node* node) { @@ -194,6 +223,18 @@ Output Combine(const Scope& scope, Input a, Input b) { return ConstructOp(scope, "Combine", {std::move(a), std::move(b)}); } +std::string FormatStackTrace(const GraphDebugInfo::StackTrace& stack_trace, + const GraphDebugInfo& debug_info) { + std::string result; + for (const GraphDebugInfo::FileLineCol& file_line_col : + stack_trace.file_line_cols()) { + const std::string& file = debug_info.files(file_line_col.file_index()); + absl::StrAppend(&result, file_line_col.func(), "@", file, ":", + file_line_col.line(), ".", file_line_col.col(), "\n"); + } + return result; +} + class GraphPartitionTest : public ::testing::Test { protected: GraphPartitionTest() @@ -203,8 +244,8 @@ class GraphPartitionTest : public ::testing::Test { scope_b_(Scope::NewRootScope().ExitOnError().WithDevice( "/job:a/replica:0/task:0/cpu:1")) {} - const GraphDef& ToGraphDef() { - TF_EXPECT_OK(in_.ToGraphDef(&in_graph_def_)); + const GraphDef& ToGraphDef(bool include_debug_info = false) { + TF_EXPECT_OK(in_.ToGraphDef(&in_graph_def_, include_debug_info)); return in_graph_def_; } @@ -465,7 +506,6 @@ TEST_F(GraphPartitionTest, Functions) { *fdef_lib.add_function() = test::function::XTimesFour(); TF_ASSERT_OK(in_.graph()->AddFunctionLibrary(fdef_lib)); - using namespace ::tensorflow::ops; // NOLINT(build/namespaces) auto a1 = FloatInput(in_.WithOpName("A1")); auto b1 = FloatInput(in_.WithOpName("B1")); ConstructOp(in_.WithOpName("A2"), "XTimesTwo", {a1}); @@ -523,6 +563,71 @@ TEST_F(GraphPartitionTest, SetIncarnation) { } } +TEST_F(GraphPartitionTest, GraphDebugInfo) { + GraphDef graph_def; + Output a1 = FloatInput(in_.WithOpName("A1")); + Output b1 = FloatInput(in_.WithOpName("B1")); + Combine(in_.WithOpName("B2"), a1, b1); + + Node *a1_node = nullptr, *b1_node = nullptr, *b2_node = nullptr; + for (Node* node : in_.graph()->op_nodes()) { + if (node->name() == "A1") { + a1_node = node; + } else if (node->name() == "B1") { + b1_node = node; + } else if (node->name() == "B2") { + b2_node = node; + } + } + EXPECT_NE(a1_node, nullptr); + EXPECT_NE(b1_node, nullptr); + EXPECT_NE(b2_node, nullptr); + + TestStackTrace a1_stack_trace( + std::vector{{"main.cc", 20, "x"}, {"alpha.cc", 30, "a1"}}); + TestStackTrace b1_stack_trace( + std::vector{{"window.cc", 21, "y"}, {"beta.cc", 35, "b1"}}); + TestStackTrace b2_stack_trace( + std::vector{{"cache.cc", 22, "bar"}, {"beta.cc", 39, "b2"}}); + a1_node->SetStackTrace(std::make_shared(a1_stack_trace)); + b1_node->SetStackTrace(std::make_shared(b1_stack_trace)); + b2_node->SetStackTrace(std::make_shared(b2_stack_trace)); + + TF_EXPECT_OK(in_.ToGraphDef(&graph_def, /*include_debug_info=*/true)); + + // `Partition()` uses the first letter of the op name ('A' or 'B') to choose a + // device for each node. It calls the function under test, also named + // `Partition()`, to do the actual partitioning. + Partition(ToGraphDef(/*include_debug_info=*/true), &partitions_); + EXPECT_EQ(2, partitions_.size()); + + // Expect each partitioned graph to contain the stack traces for its nodes. + // A stack trace for A1 should be in the A partition (".../cpu:0"). + string a = "/job:a/replica:0/task:0/cpu:0"; + const GraphDebugInfo& a_debug_info = partitions_[a].debug_info(); + const auto& a_it = a_debug_info.traces().find("A1"); + EXPECT_EQ(1, a_debug_info.traces().size()); + EXPECT_THAT(a_it, Ne(a_debug_info.traces().end())); + EXPECT_THAT(FormatStackTrace(a_it->second, a_debug_info), + Eq("x@main.cc:20.0\n" + "a1@alpha.cc:30.0\n")); + + // Stack traces for B1 and B2 should be in the B partition (".../cpu:1"). + string b = "/job:a/replica:0/task:0/cpu:1"; + const GraphDebugInfo& b_debug_info = partitions_[b].debug_info(); + const auto& b1_it = b_debug_info.traces().find("B1"); + const auto& b2_it = b_debug_info.traces().find("B2"); + EXPECT_EQ(2, b_debug_info.traces().size()); + EXPECT_THAT(b1_it, Ne(b_debug_info.traces().end())); + EXPECT_THAT(b2_it, Ne(b_debug_info.traces().end())); + EXPECT_THAT(FormatStackTrace(b1_it->second, b_debug_info), + Eq("y@window.cc:21.0\n" + "b1@beta.cc:35.0\n")); + EXPECT_THAT(FormatStackTrace(b2_it->second, b_debug_info), + Eq("bar@cache.cc:22.0\n" + "b2@beta.cc:39.0\n")); +} + TEST(TopologicalSortNodesWithTimePriorityTest, NoDependencies) { // Create placeholders, shuffle them so the order in the graph is not strictly // increasing. From 9bc114a4958050452a593a4123259a0cc64cde45 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 11 Jul 2023 12:27:05 -0700 Subject: [PATCH 142/376] [xla:gpu] Disable cuda graphs when running auto tuning PiperOrigin-RevId: 547264135 --- tensorflow/compiler/xla/service/gpu/autotuner_compile_util.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.cc b/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.cc index e5ee60c7f73860..d799987126e681 100644 --- a/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.cc +++ b/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.cc @@ -212,6 +212,8 @@ StatusOr> AutotunerCompileUtil::RunBackend( options.set_xla_gpu_dump_autotune_results_to(""); options.set_xla_gpu_load_autotune_results_from(""); options.set_xla_gpu_dump_llvmir(false); + // Avoid using Gpu graphs as we don't want to measure graph construction time. + options.set_xla_gpu_cuda_graph_level(0); // Avoid using another thread pool. options.set_xla_gpu_force_compilation_parallelism(1); module->config().set_debug_options(options); From ad14d68d5b146de6539b209915afb5c9204d850f Mon Sep 17 00:00:00 2001 From: Richard Levasseur Date: Tue, 11 Jul 2023 12:29:33 -0700 Subject: [PATCH 143/376] Internal Code Change PiperOrigin-RevId: 547264730 --- tensorflow/tensorflow.bzl | 58 +++++++++++++++++++++------------------ 1 file changed, 32 insertions(+), 26 deletions(-) diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index faa9b490b6d324..efda8f31ff6542 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -1314,32 +1314,6 @@ generate_op_reg_offsets = rule( implementation = _generate_op_reg_offsets_impl, ) -# Generates a Python library target wrapping the ops registered in "deps". -# -# Args: -# name: used as the name of the generated target and as a name component of -# the intermediate files. -# out: name of the python file created by this rule. If None, then -# "ops/gen_{name}.py" is used. -# hidden: Optional list of ops names to make private in the Python module. -# It is invalid to specify both "hidden" and "op_allowlist". -# visibility: passed to py_library. -# deps: list of dependencies for the intermediate tool used to generate the -# python target. NOTE these `deps` are not applied to the final python -# library target itself. -# require_shape_functions: Unused. Leave this as False. -# hidden_file: optional file that contains a list of op names to make private -# in the generated Python module. Each op name should be on a line by -# itself. Lines that start with characters that are invalid op name -# starting characters are treated as comments and ignored. -# generated_target_name: name of the generated target (overrides the -# "name" arg) -# op_whitelist: [DEPRECATED] an older spelling for "op_allowlist" -# op_allowlist: if not empty, only op names in this list will be wrapped. It -# is invalid to specify both "hidden" and "op_allowlist". -# cc_linkopts: Optional linkopts to be added to tf_cc_binary that contains the -# specified ops. - def tf_gen_op_wrapper_py( name, out = None, @@ -1358,6 +1332,38 @@ def tf_gen_op_wrapper_py( copts = [], extra_py_deps = None, py_lib_rule = native.py_library): + """Generates a Python library target wrapping the ops registered in "deps". + + Args: + name: used as the name of the generated target and as a name component of + the intermediate files. + out: name of the python file created by this rule. If None, then + "ops/gen_{name}.py" is used. + hidden: Optional list of ops names to make private in the Python module. + It is invalid to specify both "hidden" and "op_allowlist". + visibility: passed to py_library. + deps: list of dependencies for the intermediate tool used to generate the + python target. NOTE these `deps` are not applied to the final python + library target itself. + require_shape_functions: Unused. Leave this as False. + hidden_file: optional file that contains a list of op names to make private + in the generated Python module. Each op name should be on a line by + itself. Lines that start with characters that are invalid op name + starting characters are treated as comments and ignored. + generated_target_name: name of the generated target (overrides the + "name" arg) + op_whitelist: [DEPRECATED] an older spelling for "op_allowlist" + op_allowlist: if not empty, only op names in this list will be wrapped. It + is invalid to specify both "hidden" and "op_allowlist". + cc_linkopts: Optional linkopts to be added to tf_cc_binary that contains the + specified ops. + api_def_srcs: undocumented. + compatible_with: undocumented. + testonly: undocumented. + copts: undocumented. + extra_py_deps: undocumented. + py_lib_rule: undocumented. + """ _ = require_shape_functions # Unused. if op_whitelist and op_allowlist: fail("op_whitelist is deprecated. Only use op_allowlist.") From 138866e95e429b2a1f0cf81716c7884834093d98 Mon Sep 17 00:00:00 2001 From: Adam Cogdell Date: Tue, 11 Jul 2023 12:57:32 -0700 Subject: [PATCH 144/376] Ensure scatter_dims_to_operand_dims is not an out of range index. PiperOrigin-RevId: 547272211 --- .../xla/service/gpu/ir_emitter_unnested.cc | 4 +++ .../compiler/xla/service/hlo_verifier.cc | 13 +++++++++ .../compiler/xla/service/hlo_verifier_test.cc | 27 +++++++++++++++++++ 3 files changed, 44 insertions(+) diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index a399816dc67739..a8524de7f3bc46 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -2637,6 +2637,10 @@ Status IrEmitterUnnested::EmitScatter( index.GetType()); int64_t operand_dim = desc.dim_numbers.getScatterDimsToOperandDims()[i]; + if (operand_dim > rank) { + return absl::OutOfRangeError( + "The provided scatter_dims_to_operand_dims was out of range."); + } TF_ASSIGN_OR_RETURN( llvm::Value* const loaded_scatter_index, desc.scatter_indices_gen(raw_scatter_index_index.SourceIndexOfReshape( diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 0f0a98bbb33f4e..e8fb8526164f12 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -2682,6 +2682,19 @@ class InstructionVerifier : public DfsHloVisitorWithDefault { return OkStatus(); } + Status HandleScatter(HloInstruction* scatter) override { + int64_t rank = scatter->operand(0)->shape().rank(); + for (int64_t operand_dim : + scatter->scatter_dimension_numbers().scatter_dims_to_operand_dims()) { + if (operand_dim > rank) { + return absl::OutOfRangeError(absl::StrCat( + "The provided scatter_dims_to_operand_dim was out of range.", + " (operand_dim: ", operand_dim, ", rank: ", rank, ")")); + } + } + return OkStatus(); + } + Status Preprocess(HloInstruction* instruction) override { auto [it, inserted] = instructions_by_name_.emplace(instruction->name(), instruction); diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc index f351eaed4ca022..1fa4f1a4a3a430 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc @@ -2364,6 +2364,33 @@ TEST_F(HloVerifierTest, ReduceScatterNonUniformGroups) { HasSubstr("Replica groups expected to be of uniform size")); } +TEST_F(HloVerifierTest, ScatterInvalidScatterDim) { + const char* const hlo_string = R"( + HloModule Module + add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) + } + + ENTRY CRS { + Arg_0 = s8[11,6]{1,0} parameter(0) + constant = s32[] constant(1) + broadcast = s32[1,7,9,2,16,2]{5,4,3,2,1,0} broadcast(constant), dimensions={} + Arg_1 = s8[1,7,9,2,9,4,16]{6,5,4,3,2,1,0} parameter(1) + scatter = s8[11,6]{1,0} scatter(Arg_0, broadcast, Arg_1), update_window_dims={4,5}, inserted_window_dims={}, scatter_dims_to_operand_dims={1094795585,1}, index_vector_dim=5, to_apply=add + abs = s8[11,6]{1,0} abs(scatter) + ROOT tuple = (s8[11,6]{1,0}) tuple(abs) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(hlo_string)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.message(), + HasSubstr("Invalid scatter_dims_to_operand_dims mapping")); +} + TEST_F(HloVerifierTest, VerifyBroadcastDimensionsOrder) { const char* const hlo = R"( HloModule module From 2b1cf8a58bb614307938ea95984c77b3cea59ff9 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 11 Jul 2023 13:08:49 -0700 Subject: [PATCH 145/376] [xla:gpu] Add an option to instantiate all CUDA graphs before running executable Setting `xla_gpu_cuda_graph_num_runs_to_instantiate` to a negative value will instantiate all CUDA graphs before executing the main function PiperOrigin-RevId: 547275419 --- tensorflow/compiler/xla/runtime/executable.cc | 5 + tensorflow/compiler/xla/runtime/executable.h | 4 + .../compiler/xla/service/gpu/runtime/BUILD | 1 + .../xla/service/gpu/runtime/executable.cc | 15 +- .../xla/service/gpu/runtime/graph_launch.cc | 220 ++++++++++++------ .../xla/service/gpu/runtime/graph_launch.h | 18 +- .../xla/stream_executor/cuda/cuda_graph.cc | 2 +- 7 files changed, 187 insertions(+), 78 deletions(-) diff --git a/tensorflow/compiler/xla/runtime/executable.cc b/tensorflow/compiler/xla/runtime/executable.cc index 357d3ce9b03a75..c2e07c5bcffda1 100644 --- a/tensorflow/compiler/xla/runtime/executable.cc +++ b/tensorflow/compiler/xla/runtime/executable.cc @@ -470,6 +470,11 @@ bool Executable::IsAsync(unsigned ordinal) const { return functions_[ordinal].results_memory_layout.has_async_results; } +std::string_view Executable::function_name(unsigned ordinal) const { + assert(ordinal < functions_.size() && "function ordinal out of bounds"); + return functions_[ordinal].name; +} + unsigned Executable::num_results(unsigned ordinal) const { assert(ordinal < functions_.size() && "function ordinal out of bounds"); return functions_[ordinal].runtime_signature.num_results(); diff --git a/tensorflow/compiler/xla/runtime/executable.h b/tensorflow/compiler/xla/runtime/executable.h index b6f76800340afd..44f6332300f473 100644 --- a/tensorflow/compiler/xla/runtime/executable.h +++ b/tensorflow/compiler/xla/runtime/executable.h @@ -162,6 +162,10 @@ class Executable { bool IsAsync(unsigned ordinal) const; bool IsAsync() const { return IsAsync(0); } + // Returns the name of the exported function with the given ordinal. + std::string_view function_name(unsigned ordinal) const; + std::string_view function_name() const { return function_name(0); } + // Returns the number of results of the exported function with given ordinal. unsigned num_results(unsigned ordinal) const; unsigned num_results() const { return num_results(0); } diff --git a/tensorflow/compiler/xla/service/gpu/runtime/BUILD b/tensorflow/compiler/xla/service/gpu/runtime/BUILD index 7082ff8996ce1c..ebaba4c2b36b2d 100644 --- a/tensorflow/compiler/xla/service/gpu/runtime/BUILD +++ b/tensorflow/compiler/xla/service/gpu/runtime/BUILD @@ -349,6 +349,7 @@ cc_library( "//tensorflow/tsl/profiler/lib:scoped_annotation_stack", "//tensorflow/tsl/profiler/lib:traceme", "//tensorflow/tsl/profiler/lib:traceme_encode", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", diff --git a/tensorflow/compiler/xla/service/gpu/runtime/executable.cc b/tensorflow/compiler/xla/service/gpu/runtime/executable.cc index 3d90510b7d4582..1c11b72cd5e06e 100644 --- a/tensorflow/compiler/xla/service/gpu/runtime/executable.cc +++ b/tensorflow/compiler/xla/service/gpu/runtime/executable.cc @@ -35,6 +35,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/runtime/custom_call.h" #include "tensorflow/compiler/xla/service/gpu/runtime/fft.h" #include "tensorflow/compiler/xla/service/gpu/runtime/gemm.h" +#include "tensorflow/compiler/xla/service/gpu/runtime/graph_launch.h" #include "tensorflow/compiler/xla/service/gpu/runtime/io_feed.h" #include "tensorflow/compiler/xla/service/gpu/runtime/memcpy.h" #include "tensorflow/compiler/xla/service/gpu/runtime/memset.h" @@ -169,7 +170,7 @@ GpuRuntimeExecutable::GpuRuntimeExecutable( /*static*/ StatusOr> GpuRuntimeExecutable::Create(std::unique_ptr program) { - // Options for the default XLA Runtim compilation pipeline. + // Options for the default XLA Runtime compilation pipeline. runtime::CompilationPipelineOptions copts; // Populate mapping from XLA (SE) enums/structs type id to symbol names. @@ -413,6 +414,18 @@ Status GpuRuntimeExecutable::Execute( return InternalError("Failed to initialize runtime modules state: %s", state_ref.status().message()); +#if GOOGLE_CUDA + // Instantiate all CUDA graphs before executing the main function. + if (debug_options_.xla_gpu_cuda_graph_num_runs_to_instantiate() < 0) { + if (auto instantiated = graph_instances_.InstantiateAllGraphs( + run_options, executable, user_data, temp_buffer.opaque()); + !instantiated.ok()) { + return InternalError("Failed to instantiate CUDA graphs: %s", + instantiated.message()); + } + } +#endif // GOOGLE_CUDA + // Collect all emitted diagnostic messages. std::string diagnostic; runtime::DiagnosticEngine diagnostic_engine; diff --git a/tensorflow/compiler/xla/service/gpu/runtime/graph_launch.cc b/tensorflow/compiler/xla/service/gpu/runtime/graph_launch.cc index 5b18bf5ca6120c..359011ab83a32a 100644 --- a/tensorflow/compiler/xla/service/gpu/runtime/graph_launch.cc +++ b/tensorflow/compiler/xla/service/gpu/runtime/graph_launch.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/runtime/graph_launch.h" +#include #include #include #include @@ -52,10 +53,22 @@ using xla::runtime::Arguments; using xla::runtime::AsyncTaskRunner; using xla::runtime::CustomCall; using xla::runtime::Executable; +using xla::runtime::FunctionRef; +using xla::runtime::FunctionType; using xla::runtime::MemrefDesc; -using xla::runtime::ScalarArg; +using xla::runtime::MemrefType; using xla::runtime::StridedMemrefView; +#if GOOGLE_CUDA +using se::gpu::OwnedCudaGraph; + +// Captures Gpu graph by running given function in capture mode. +static absl::StatusOr CaptureGraph( + const ServiceExecutableRunOptions* run_options, + runtime::FunctionRef function_ref, Arguments& args, + CustomCall::UserData user_data); +#endif // GOOGLE_CUDA + //===----------------------------------------------------------------------===// // CUDA graphs caching. //===----------------------------------------------------------------------===// @@ -72,6 +85,86 @@ CapturedFunctionExecutionCount* CapturedFunctionExecutionCounts::operator()( return &counts_[executor]; } +Status GraphInstances::InstantiateAllGraphs( + const ServiceExecutableRunOptions* run_options, + const Executable& executable, const CustomCall::UserData& user_data, + void* ptr) { + absl::MutexLock lock(&mutex_); + + se::StreamExecutor* executor = run_options->stream()->parent(); + + // All Gpu graphs are already instantiated for a given executor. + if (instantiated_.contains(executor)) return OkStatus(); + + VLOG(3) << "Instantate all Gpu graphs in executable " << executable.name(); + + TraceMe trace("cuda.graph.instantiate_all"); + + // Initialize graph instances snapshot for a given executor. + StreamExecutorGraphInstances::Snapshot instances = + graphs_[executor].snapshot(); + + // Instantiate all Gpu graphs by calling graph capture functions with fake + // arguments. Once we'll execute them first time for real, they'll be updated + // with correct pointers. + for (unsigned ordinal = 1; ordinal < executable.num_functions(); ++ordinal) { + if (!absl::StartsWith(executable.function_name(ordinal), + "xla.gpu.cuda.graph.capture")) + continue; + + VLOG(3) << "Instantiate Gpu graph defined by capture function @" + << executable.function_name(ordinal) << " (ordinal = " << ordinal + << ")"; + + TraceMe trace_instantiation([&] { + return TraceMeEncode("cuda.graph.instantiate", {{"ordinal", ordinal}}); + }); + + FunctionRef function_ref = executable.function_ref(ordinal); + + const FunctionType& signature = executable.signature(ordinal); + assert(signature.num_results() == 0 && "unexpected number of results"); + Arguments args(signature.num_operands()); + + // Prepare arguments for the graph capture function. + for (size_t j = 0; j < signature.num_operands(); ++j) { + auto* memref = llvm::dyn_cast(signature.operand(j)); + + if (!memref) + return absl::InternalError(absl::StrFormat( + "Unsupported capture function argument type #%d", j)); + + if (memref->sizes().size() != 1) + return absl::InternalError( + absl::StrFormat("Unsupported capture function memref rank #%d: %d", + j, memref->sizes().size())); + + std::array sizes = {memref->size(0)}; + std::array strides = {1}; + + args.emplace_back(memref->element_type(), ptr, + /*offset=*/0, sizes, strides); + } + +#if GOOGLE_CUDA + // Instantiate a Gpu graph with fake arguments. + auto instantiate = [&]() -> absl::StatusOr { + TF_ASSIGN_OR_RETURN( + auto g, CaptureGraph(run_options, function_ref, args, user_data)); + TF_ASSIGN_OR_RETURN(auto e, se::gpu::InstantiateCudaGraph(std::move(g))); + return GraphInstance(0, std::move(e)); + }; + + TF_ASSIGN_OR_RETURN(GraphInstance * instance, + instances.GetOrCreate(ordinal, instantiate)); + (void)instance; +#endif // GOOGLE_CUDA + } + + instantiated_.insert(executor); + return OkStatus(); +} + //===----------------------------------------------------------------------===// // Helper structure to hash the remaining arguments' memref pointers. //===----------------------------------------------------------------------===// @@ -99,8 +192,6 @@ H AbslHashValue(H h, const RemainingArgsPtrs& m) { #if GOOGLE_CUDA -using se::gpu::OwnedCudaGraph; - static bool InDebugMode() { #ifdef NDEBUG return false; @@ -108,9 +199,26 @@ static bool InDebugMode() { return true; } +// Forwards custom call arguments to an arguments container that can be passed +// to an executable function. +static absl::Status ForwardArguments(CustomCall::RemainingArgs fwd_args, + Arguments& args) { + for (size_t i = 0; i < fwd_args.size(); ++i) { + if (auto memref = fwd_args.get(i); succeeded(memref)) { + args.emplace_back(memref->dtype, memref->data, /*offset=*/0, + memref->sizes, memref->strides); + continue; + } + + return absl::InvalidArgumentError("Unsupported argument type"); + } + + return OkStatus(); +} + static absl::StatusOr CaptureGraph( const ServiceExecutableRunOptions* run_options, - runtime::FunctionRef function_ref, CustomCall::RemainingArgs fwd_args, + runtime::FunctionRef function_ref, Arguments& args, CustomCall::UserData user_data) { // We capture graph on a borrowed stream because we do not want to // accidentally record any concurrent kernel launches from other XLA @@ -162,29 +270,6 @@ static absl::StatusOr CaptureGraph( // Graph capture function should not launch any async tasks. opts.async_task_runner = reinterpret_cast(0XDEADBEEF); - // Graph capture functions can only have index arguments for launch - // dimensions, or memrefs for passing buffers. We need to re-package custom - // call arguments into a container that can be passed to an executable - // function. - Arguments args(fwd_args.size()); - - for (size_t i = 0; i < fwd_args.size(); ++i) { - // `index` argument passed as int64_t. - if (auto idx = fwd_args.get(i); succeeded(idx)) { - args.emplace_back(*idx); - continue; - } - - // Pass `memref` argument as a MemrefDesc. - if (auto memref = fwd_args.get(i); succeeded(memref)) { - args.emplace_back(memref->dtype, memref->data, /*offset=*/0, - memref->sizes, memref->strides); - continue; - } - - return absl::InvalidArgumentError("Unsupported argument type"); - } - // Create a graph from running the graph capture function. auto captured = se::gpu::CaptureCudaGraph(capture_stream->get(), [&]() { return function_ref(args, runtime::NoResultConverter{}, opts, @@ -223,32 +308,15 @@ static absl::Status RunGraphWithoutCapture( // Graph capture function should not launch any async tasks. opts.async_task_runner = reinterpret_cast(0XDEADBEEF); - Arguments args(fwd_args.size()); + Arguments args(fwd_args.size()); + TF_RETURN_IF_ERROR(ForwardArguments(fwd_args, args)); - for (size_t i = 0; i < fwd_args.size(); ++i) { - // `index` argument passed as int64_t. - if (auto idx = fwd_args.get(i); succeeded(idx)) { - args.emplace_back(*idx); - continue; - } - - // Pass `memref` argument as a MemrefDesc. - if (auto memref = fwd_args.get(i); succeeded(memref)) { - args.emplace_back(memref->dtype, memref->data, /*offset=*/0, - memref->sizes, memref->strides); - continue; - } - - return absl::InvalidArgumentError("Unsupported argument type"); - } - - auto status = - function_ref(args, runtime::NoResultConverter{}, opts, InDebugMode()) - .status(); - if (!status.ok()) { + auto executed = + function_ref(args, runtime::NoResultConverter{}, opts, InDebugMode()); + if (!executed.ok()) { return InternalError("RunGraphWithoutCapture failed (%s): %s", diagnostic.empty() ? "" : diagnostic, - status.ToString()); + executed.status().ToString()); } return absl::OkStatus(); } @@ -272,7 +340,7 @@ static absl::Status LaunchGraph( ConcurrentRegionStatus* region_status, CustomCall::RemainingArgs fwd_args, CustomCall::FunctionOrdinal capture) { #if GOOGLE_CUDA - VLOG(1) << "Launch Cuda Graph: capture=" << capture.ordinal; + VLOG(1) << "Launch Cuda Graph: ordinal = " << capture.ordinal; // Get a reference to exported function that captures the cuda graph. runtime::FunctionRef function_ref = executable->function_ref(capture.ordinal); @@ -287,15 +355,13 @@ static absl::Status LaunchGraph( gemm_config, gpu_lock, region_status); }; - TF_ASSIGN_OR_RETURN( - std::unique_ptr> * get_count, - counts->GetOrCreate( - capture.ordinal, - []() -> absl::StatusOr>> { - return std::make_unique>(0); - })); - uint64_t count = (*get_count)->fetch_add(1); - uint64_t instantiation_threshold = + TF_ASSIGN_OR_RETURN(std::unique_ptr> * get_count, + counts->GetOrCreate(capture.ordinal, [] { + return std::make_unique>(0); + })); + + int64_t count = (*get_count)->fetch_add(1); + int64_t num_runs_to_instantiate = debug_options->xla_gpu_cuda_graph_num_runs_to_instantiate(); // TODO(ezhulenev): Cupti tracing leads to deadlocks in CUDA 11. Always fall @@ -306,23 +372,27 @@ static absl::Status LaunchGraph( bool is_profiling = tsl::profiler::ScopedAnnotationStack::IsEnabled(); #endif - if (count < instantiation_threshold || is_profiling) { - // Run captured graph directly. + if (count < num_runs_to_instantiate || is_profiling) { + VLOG(3) << "Run gpu graph in op-by-op mode: ordinal = " << capture.ordinal; return RunGraphWithoutCapture(run_options, function_ref, fwd_args, user_data()); } - TF_ASSIGN_OR_RETURN( - GraphInstance * instance, - instances->GetOrCreate( - capture.ordinal, [&]() -> absl::StatusOr { - TF_ASSIGN_OR_RETURN(auto g, CaptureGraph(run_options, function_ref, - fwd_args, user_data())); + // Instantiate Gpu graph by running graph capture function. + auto instantiate = [&]() -> absl::StatusOr { + Arguments args(fwd_args.size()); + TF_RETURN_IF_ERROR(ForwardArguments(fwd_args, args)); + + TF_ASSIGN_OR_RETURN( + auto g, CaptureGraph(run_options, function_ref, args, user_data())); - TF_ASSIGN_OR_RETURN(auto e, - se::gpu::InstantiateCudaGraph(std::move(g))); - return GraphInstance(ptrs_hash, std::move(e)); - })); + TF_ASSIGN_OR_RETURN(auto e, se::gpu::InstantiateCudaGraph(std::move(g))); + + return GraphInstance(ptrs_hash, std::move(e)); + }; + + TF_ASSIGN_OR_RETURN(GraphInstance * instance, + instances->GetOrCreate(capture.ordinal, instantiate)); { // Lock graph instance for read only access. If we'll have to update the @@ -343,9 +413,13 @@ static absl::Status LaunchGraph( // Otherwise we have to re-capture the graph and update the graph instance. VLOG(3) << "Update cached graph instance"; + + Arguments args(fwd_args.size()); + TF_RETURN_IF_ERROR(ForwardArguments(fwd_args, args)); + // Capture CUDA graph by running capture function. TF_ASSIGN_OR_RETURN( - auto g, CaptureGraph(run_options, function_ref, fwd_args, user_data())); + auto g, CaptureGraph(run_options, function_ref, args, user_data())); // At this point we have to grab a writer lock, because we might potentially // have concurrent execution of the cached graph instance. diff --git a/tensorflow/compiler/xla/service/gpu/runtime/graph_launch.h b/tensorflow/compiler/xla/service/gpu/runtime/graph_launch.h index 7171862da99a23..96dfd6a1a3d3ef 100644 --- a/tensorflow/compiler/xla/service/gpu/runtime/graph_launch.h +++ b/tensorflow/compiler/xla/service/gpu/runtime/graph_launch.h @@ -18,14 +18,15 @@ limitations under the License. #include #include -#include -#include -#include #include +#include "absl/container/flat_hash_set.h" #include "absl/container/node_hash_map.h" #include "tensorflow/compiler/xla/runtime/custom_call_registry.h" +#include "tensorflow/compiler/xla/runtime/executable.h" +#include "tensorflow/compiler/xla/service/service_executable_run_options.h" #include "tensorflow/compiler/xla/stream_executor/stream_executor.h" +#include "tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h" #if GOOGLE_CUDA #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_graph.h" @@ -83,10 +84,21 @@ class GraphInstances { public: StreamExecutorGraphInstances* operator()(se::StreamExecutor* executor); + // Instantiates all Gpu graphs defined by the given executable using user + // provided run options. This guarantees that once we start execution, all Gpu + // graphs are ready, and will only require cheap update operation and will not + // require allocating new resources (we avoid non deterministic OOM errors). + Status InstantiateAllGraphs(const ServiceExecutableRunOptions* run_options, + const runtime::Executable& executable, + const runtime::CustomCall::UserData& user_data, + void* ptr); + private: mutable absl::Mutex mutex_; absl::node_hash_map graphs_ ABSL_GUARDED_BY(mutex_); + absl::flat_hash_set instantiated_ + ABSL_GUARDED_BY(mutex_); }; // Xla executable keeps a mapping from stream executors to execution counts. diff --git a/tensorflow/compiler/xla/stream_executor/cuda/cuda_graph.cc b/tensorflow/compiler/xla/stream_executor/cuda/cuda_graph.cc index 23de42ff1763ae..b4e9e3729090bc 100644 --- a/tensorflow/compiler/xla/stream_executor/cuda/cuda_graph.cc +++ b/tensorflow/compiler/xla/stream_executor/cuda/cuda_graph.cc @@ -70,7 +70,7 @@ void CudaGraphSupport::DestroyGraphExec::operator()(cudaGraphExec_t instance) { tsl::Status OwnedCudaGraphExec::Update(OwnedCudaGraph graph) { VLOG(3) << "Update CUDA graph exec with a new graph after " << num_launches_ - << " launches since last update " + << " launches since last update" << " #" << num_updates_++; num_launches_ = 0; From 9af07b52871bdb711beaf5b058edef64b592f27c Mon Sep 17 00:00:00 2001 From: Fiona Lang Date: Tue, 11 Jul 2023 13:17:23 -0700 Subject: [PATCH 146/376] Move compiler imports from python/__init__.py to python/modules_with_exports.py. PiperOrigin-RevId: 547277744 --- tensorflow/python/BUILD | 3 +++ tensorflow/python/__init__.py | 7 ------- tensorflow/python/modules_with_exports.py | 5 +++++ 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 7aeba32095772a..5d37854d700fc7 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -344,6 +344,9 @@ py_library( ":no_contrib", ":tf2", "//tensorflow/core/function/trace_type", + "//tensorflow/python/compiler/mlir", + "//tensorflow/python/compiler/xla", + "//tensorflow/python/compiler/xla:compiler_py", "//tensorflow/python/data", "//tensorflow/python/distribute", "//tensorflow/python/distribute:merge_call_interim", diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py index 174befc5b6c522..4e1807e79f9905 100644 --- a/tensorflow/python/__init__.py +++ b/tensorflow/python/__init__.py @@ -77,13 +77,6 @@ from tensorflow.python.debug.lib import dumping_callback from tensorflow.python.ops import gen_debug_ops -# XLA JIT compiler APIs. -from tensorflow.python.compiler.xla import jit -from tensorflow.python.compiler.xla import xla - -# MLIR APIs. -from tensorflow.python.compiler.mlir import mlir - # Update dispatch decorator docstrings to contain lists of registered APIs. # (This should come after any imports that register APIs.) from tensorflow.python.util import dispatch diff --git a/tensorflow/python/modules_with_exports.py b/tensorflow/python/modules_with_exports.py index b65d583d8d1971..2edd9ae88d5ba0 100644 --- a/tensorflow/python/modules_with_exports.py +++ b/tensorflow/python/modules_with_exports.py @@ -31,6 +31,11 @@ from tensorflow.core.protobuf.config_pb2 import * from tensorflow.core.util.event_pb2 import * +# Compiler +from tensorflow.python.compiler.xla import jit +from tensorflow.python.compiler.xla import xla +from tensorflow.python.compiler.mlir import mlir + # Data from tensorflow.python import data From bfe68b5e78d6430eb06c254f379dfa5f55b85b10 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 11 Jul 2023 13:40:16 -0700 Subject: [PATCH 147/376] Integrate LLVM at llvm/llvm-project@be29fe2f987b Updates LLVM usage to match [be29fe2f987b](https://github.com/llvm/llvm-project/commit/be29fe2f987b) PiperOrigin-RevId: 547284303 --- .../transforms/vectorization/vectorize_for_cpu.cc | 3 ++- third_party/llvm/generated.patch | 11 +++++++++++ third_party/llvm/workspace.bzl | 4 ++-- 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/vectorization/vectorize_for_cpu.cc b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/vectorization/vectorize_for_cpu.cc index c07760ff679544..5c7496c4193d17 100644 --- a/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/vectorization/vectorize_for_cpu.cc +++ b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/vectorization/vectorize_for_cpu.cc @@ -30,6 +30,7 @@ limitations under the License. #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tensor/Transforms/Transforms.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/OpDefinition.h" @@ -386,7 +387,7 @@ struct VectorizeForCPUPass ThloReverseVectorizationPattern, TransferReadOfOneDimExpandShape>(ctx); tensor::CastOp::getCanonicalizationPatterns(patterns, ctx); - vector::populateVectorTransferTensorSliceTransforms(patterns); + tensor::populateFoldTensorSubsetIntoVectorTransferPatterns(patterns); if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) return signalPassFailure(); } diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch index 509398da979e83..a7f99f08514996 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch @@ -1 +1,12 @@ Auto generated patch. Do not edit or delete it, even if empty. +diff -ruN --strip-trailing-cr a/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyMCCodeEmitter.cpp b/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyMCCodeEmitter.cpp +--- a/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyMCCodeEmitter.cpp ++++ b/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyMCCodeEmitter.cpp +@@ -127,6 +127,7 @@ + Ctx.reportError( + SMLoc(), + Twine("Wasm globals should only be accessed symbolically!")); ++ break; + default: + encodeULEB128(uint64_t(MO.getImm()), OS); + } diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index d544405fd05954..f3bd820fe1217d 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "cf410b181f8c546b9ae4cd65a82d08e65bacec82" - LLVM_SHA256 = "b46fea00b4d661444425f4dcd39f5eb12f6a5d8c4964e8e0f3c8e0e601490476" + LLVM_COMMIT = "be29fe2f987b5bf58d7f6aa77c06e58d9402064a" + LLVM_SHA256 = "96b8dbd215400b2434823ae57a5dd53f84cb2162001a31f2ea65fdfe3c06e9ab" tf_http_archive( name = name, From 89cae4a671a3fb76d38bef8e6fbc38ef7d5b55e6 Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Tue, 11 Jul 2023 13:51:44 -0700 Subject: [PATCH 148/376] [NFC] Change uses of get_compatible_with_cloud to get_compatible_with_portable. PiperOrigin-RevId: 547287511 --- .../xla/mlir/backends/cpu/transforms/BUILD | 4 +- .../xla/mlir/backends/gpu/transforms/BUILD | 4 +- .../compiler/xla/mlir/framework/ir/BUILD | 6 +- .../xla/mlir/framework/transforms/BUILD | 4 +- .../compiler/xla/mlir/math/transforms/BUILD | 6 +- .../compiler/xla/mlir/memref/transforms/BUILD | 6 +- tensorflow/compiler/xla/mlir/runtime/BUILD | 4 +- tensorflow/compiler/xla/mlir/runtime/ir/BUILD | 10 +-- .../xla/mlir/runtime/transforms/BUILD | 28 ++++---- .../compiler/xla/mlir/runtime/utils/BUILD | 12 ++-- tensorflow/compiler/xla/mlir/utils/BUILD | 4 +- tensorflow/compiler/xla/mlir/xla_cpu/ir/BUILD | 10 +-- tensorflow/compiler/xla/mlir_hlo/BUILD | 72 +++++++++---------- .../xla/python/profiler/internal/BUILD | 4 +- tensorflow/compiler/xla/runtime/BUILD | 52 +++++++------- tensorflow/compiler/xla/runtime/ffi/BUILD | 8 +-- tensorflow/compiler/xla/service/gpu/BUILD | 24 +++---- .../compiler/xla/translate/mhlo_to_hlo/BUILD | 4 +- 18 files changed, 131 insertions(+), 131 deletions(-) diff --git a/tensorflow/compiler/xla/mlir/backends/cpu/transforms/BUILD b/tensorflow/compiler/xla/mlir/backends/cpu/transforms/BUILD index 55fa13e9372023..eee5a6e926911f 100644 --- a/tensorflow/compiler/xla/mlir/backends/cpu/transforms/BUILD +++ b/tensorflow/compiler/xla/mlir/backends/cpu/transforms/BUILD @@ -1,5 +1,5 @@ load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") -load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_cloud") +load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_portable") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") package( @@ -10,7 +10,7 @@ package( gentbl_cc_library( name = "passes_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( [ diff --git a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/BUILD b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/BUILD index 1f192d5507e424..4b43a94b860c4f 100644 --- a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/BUILD +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/BUILD @@ -1,5 +1,5 @@ load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") -load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_cloud") +load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_portable") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") package( @@ -10,7 +10,7 @@ package( gentbl_cc_library( name = "passes_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( [ diff --git a/tensorflow/compiler/xla/mlir/framework/ir/BUILD b/tensorflow/compiler/xla/mlir/framework/ir/BUILD index 07f6d6404661ea..d1d021e7ba79e7 100644 --- a/tensorflow/compiler/xla/mlir/framework/ir/BUILD +++ b/tensorflow/compiler/xla/mlir/framework/ir/BUILD @@ -1,6 +1,6 @@ load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") -load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_cloud") +load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_portable") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -13,7 +13,7 @@ td_library( srcs = [ "xla_framework_ops.td", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "@llvm-project//mlir:ControlFlowInterfacesTdFiles", "@llvm-project//mlir:OpBaseTdFiles", @@ -23,7 +23,7 @@ td_library( gentbl_cc_library( name = "xla_framework_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-op-decls"], diff --git a/tensorflow/compiler/xla/mlir/framework/transforms/BUILD b/tensorflow/compiler/xla/mlir/framework/transforms/BUILD index 5a1fddaa0a9278..52b3ad191bec19 100644 --- a/tensorflow/compiler/xla/mlir/framework/transforms/BUILD +++ b/tensorflow/compiler/xla/mlir/framework/transforms/BUILD @@ -1,6 +1,6 @@ load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") -load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_cloud") +load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_portable") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -10,7 +10,7 @@ package( gentbl_cc_library( name = "passes_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( [ diff --git a/tensorflow/compiler/xla/mlir/math/transforms/BUILD b/tensorflow/compiler/xla/mlir/math/transforms/BUILD index 271cc519266e67..a3a44b926f1393 100644 --- a/tensorflow/compiler/xla/mlir/math/transforms/BUILD +++ b/tensorflow/compiler/xla/mlir/math/transforms/BUILD @@ -1,5 +1,5 @@ load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") -load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_cloud") +load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_portable") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") package( @@ -10,7 +10,7 @@ package( gentbl_cc_library( name = "passes_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( [ @@ -32,7 +32,7 @@ cc_library( "math_optimization.cc", ], hdrs = ["passes.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":passes_inc_gen", "@llvm-project//mlir:ArithDialect", diff --git a/tensorflow/compiler/xla/mlir/memref/transforms/BUILD b/tensorflow/compiler/xla/mlir/memref/transforms/BUILD index cd77dc34eea949..d1372b1feb59e1 100644 --- a/tensorflow/compiler/xla/mlir/memref/transforms/BUILD +++ b/tensorflow/compiler/xla/mlir/memref/transforms/BUILD @@ -1,5 +1,5 @@ load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") -load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_cloud") +load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_portable") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") package( @@ -10,7 +10,7 @@ package( gentbl_cc_library( name = "passes_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( [ @@ -29,7 +29,7 @@ cc_library( name = "passes", srcs = ["aligned_allocations.cc"], hdrs = ["passes.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":passes_inc_gen", "@llvm-project//llvm:Support", diff --git a/tensorflow/compiler/xla/mlir/runtime/BUILD b/tensorflow/compiler/xla/mlir/runtime/BUILD index 6c2504899bfd01..e36272acd6e0cd 100644 --- a/tensorflow/compiler/xla/mlir/runtime/BUILD +++ b/tensorflow/compiler/xla/mlir/runtime/BUILD @@ -1,5 +1,5 @@ load("//tensorflow/compiler/xla:xla.bzl", "xla_cc_binary") -load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_cloud") +load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_portable") load("@bazel_skylib//rules:build_test.bzl", "build_test") package_group( @@ -38,7 +38,7 @@ build_test( xla_cc_binary( name = "xla-runtime-opt", srcs = ["xla-runtime-opt.cc"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/xla/mlir/math/transforms:passes", "//tensorflow/compiler/xla/mlir/memref/transforms:passes", diff --git a/tensorflow/compiler/xla/mlir/runtime/ir/BUILD b/tensorflow/compiler/xla/mlir/runtime/ir/BUILD index 2f7875bb9116ff..49b659fe9dcc5d 100644 --- a/tensorflow/compiler/xla/mlir/runtime/ir/BUILD +++ b/tensorflow/compiler/xla/mlir/runtime/ir/BUILD @@ -1,5 +1,5 @@ load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") -load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_cloud") +load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_portable") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") package( @@ -15,7 +15,7 @@ td_library( "rt_interfaces.td", "rt_ops.td", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), includes = ["include"], visibility = ["//visibility:private"], deps = [ @@ -27,7 +27,7 @@ td_library( gentbl_cc_library( name = "rt_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-dialect-decls"], @@ -69,7 +69,7 @@ gentbl_cc_library( gentbl_cc_library( name = "rt_interfaces_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-attr-interface-decls"], @@ -97,7 +97,7 @@ cc_library( "rt_interfaces.h", "rt_ops.h", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":rt_inc_gen", ":rt_interfaces_inc_gen", diff --git a/tensorflow/compiler/xla/mlir/runtime/transforms/BUILD b/tensorflow/compiler/xla/mlir/runtime/transforms/BUILD index 1fe391993265dd..2401fdd984d85b 100644 --- a/tensorflow/compiler/xla/mlir/runtime/transforms/BUILD +++ b/tensorflow/compiler/xla/mlir/runtime/transforms/BUILD @@ -1,6 +1,6 @@ load("//tensorflow/compiler/xla:xla.bzl", "xla_cc_test") load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") -load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_cloud") +load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_portable") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") load("//tensorflow/tsl/platform:build_config.bzl", "if_llvm_system_z_available") @@ -12,7 +12,7 @@ package( gentbl_cc_library( name = "passes_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( [ @@ -37,7 +37,7 @@ cc_library( "rt_to_llvm.cc", ], hdrs = ["passes.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":custom_call_encoding", ":passes_inc_gen", @@ -66,7 +66,7 @@ cc_library( name = "calling_convention", srcs = ["calling_convention.cc"], hdrs = ["calling_convention.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/xla/mlir/runtime/ir:rt", "@llvm-project//mlir:FuncDialect", @@ -78,7 +78,7 @@ cc_library( xla_cc_test( name = "calling_convention_test", srcs = ["calling_convention_test.cc"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":calling_convention", "//tensorflow/compiler/xla/mlir/runtime/ir:rt", @@ -95,7 +95,7 @@ cc_library( name = "compilation_pipeline_cpu", srcs = ["compilation_pipeline_cpu.cc"], hdrs = ["compilation_pipeline_cpu.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), visibility = ["//visibility:public"], deps = [ ":compilation_pipeline_options", @@ -150,7 +150,7 @@ cc_library( name = "compilation_pipeline_gpu", srcs = ["compilation_pipeline_gpu.cc"], hdrs = ["compilation_pipeline_gpu.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), visibility = ["//visibility:public"], deps = [ ":compilation_pipeline_options", @@ -184,7 +184,7 @@ cc_library( cc_library( name = "compilation_pipeline_options", hdrs = ["compilation_pipeline_options.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":custom_call_encoding", "//tensorflow/compiler/xla/runtime:type_id", @@ -196,7 +196,7 @@ cc_library( name = "custom_call_encoding", srcs = ["custom_call_encoding.cc"], hdrs = ["custom_call_encoding.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/mlir/runtime/ir:rt", @@ -219,7 +219,7 @@ cc_library( name = "jit_compiler", srcs = ["jit_compiler.cc"], hdrs = ["jit_compiler.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":calling_convention", ":compiler", @@ -267,7 +267,7 @@ cc_library( name = "specialization", srcs = ["specialization.cc"], hdrs = ["specialization.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":type_converter", "//tensorflow/compiler/xla/mlir/runtime/utils:constraints", @@ -292,7 +292,7 @@ cc_library( name = "type_converter", srcs = ["type_converter.cc"], hdrs = ["type_converter.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/mlir/runtime/ir:rt", @@ -309,7 +309,7 @@ cc_library( xla_cc_test( name = "type_converter_test", srcs = ["type_converter_test.cc"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":type_converter", "//tensorflow/compiler/xla/runtime:types", @@ -322,7 +322,7 @@ xla_cc_test( cc_library( name = "compiler", hdrs = ["compiler.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", diff --git a/tensorflow/compiler/xla/mlir/runtime/utils/BUILD b/tensorflow/compiler/xla/mlir/runtime/utils/BUILD index 0b552cd7e5c2aa..0bec5a090be1d3 100644 --- a/tensorflow/compiler/xla/mlir/runtime/utils/BUILD +++ b/tensorflow/compiler/xla/mlir/runtime/utils/BUILD @@ -1,4 +1,4 @@ -load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_cloud") +load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_portable") load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") package( @@ -11,7 +11,7 @@ cc_library( name = "async_runtime_api", srcs = ["async_runtime_api.cc"], hdrs = ["async_runtime_api.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/xla/runtime:async_runtime", "//tensorflow/tsl/platform:platform_port", @@ -26,7 +26,7 @@ cc_library( cc_library( name = "c_runner_utils", hdrs = ["c_runner_utils.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "@llvm-project//llvm:OrcJIT", "@llvm-project//mlir:mlir_c_runner_utils", @@ -37,7 +37,7 @@ cc_library( name = "constraints", srcs = ["constraints.cc"], hdrs = ["constraints.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/xla/runtime:constraints", "//tensorflow/compiler/xla/runtime:errors", @@ -55,7 +55,7 @@ cc_library( name = "custom_calls", srcs = ["custom_calls.cc"], hdrs = ["custom_calls.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", @@ -66,7 +66,7 @@ cc_library( cc_library( name = "float_16bits", hdrs = ["float_16bits.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "@llvm-project//llvm:OrcJIT", "@llvm-project//mlir:mlir_float16_utils", diff --git a/tensorflow/compiler/xla/mlir/utils/BUILD b/tensorflow/compiler/xla/mlir/utils/BUILD index 9e9a6f974d91a3..9d8e0449bf34fb 100644 --- a/tensorflow/compiler/xla/mlir/utils/BUILD +++ b/tensorflow/compiler/xla/mlir/utils/BUILD @@ -1,4 +1,4 @@ -load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_cloud") +load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_portable") load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") package( @@ -13,7 +13,7 @@ cc_library( name = "error_util", srcs = ["error_util.cc"], hdrs = ["error_util.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/tsl/platform:errors", "@com_google_absl//absl/status", diff --git a/tensorflow/compiler/xla/mlir/xla_cpu/ir/BUILD b/tensorflow/compiler/xla/mlir/xla_cpu/ir/BUILD index 33922080c1dcd8..2737d23adda775 100644 --- a/tensorflow/compiler/xla/mlir/xla_cpu/ir/BUILD +++ b/tensorflow/compiler/xla/mlir/xla_cpu/ir/BUILD @@ -1,6 +1,6 @@ load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") -load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_cloud") +load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_portable") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -14,7 +14,7 @@ td_library( "xla_cpu_enums.td", "xla_cpu_ops.td", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/xla/mlir_hlo:hlo_ops_td_files", "@llvm-project//mlir:BufferizableOpInterfaceTdFiles", @@ -25,7 +25,7 @@ td_library( gentbl_cc_library( name = "xla_cpu_dialect_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-dialect-decls"], @@ -43,7 +43,7 @@ gentbl_cc_library( gentbl_cc_library( name = "xla_cpu_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-op-decls"], @@ -61,7 +61,7 @@ gentbl_cc_library( gentbl_cc_library( name = "xla_cpu_enums_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-enum-decls"], diff --git a/tensorflow/compiler/xla/mlir_hlo/BUILD b/tensorflow/compiler/xla/mlir_hlo/BUILD index fd47f29995196e..8cc3a88e4fe81a 100644 --- a/tensorflow/compiler/xla/mlir_hlo/BUILD +++ b/tensorflow/compiler/xla/mlir_hlo/BUILD @@ -1,5 +1,5 @@ load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") -load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_cloud") +load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_portable") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "gentbl_filegroup", "td_library") load("@bazel_skylib//rules:build_test.bzl", "build_test") @@ -25,7 +25,7 @@ filegroup( td_library( name = "hlo_ops_td_files", srcs = glob(["mhlo/IR/*.td"]), - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), includes = ["."], deps = [ "@llvm-project//mlir:BuiltinDialectTdFiles", @@ -46,7 +46,7 @@ td_library( gentbl_cc_library( name = "mhlo_pass_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", tbl_outs = [ ( @@ -64,7 +64,7 @@ gentbl_cc_library( gentbl_cc_library( name = "lmhlo_pass_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", tbl_outs = [ ( @@ -82,7 +82,7 @@ gentbl_cc_library( gentbl_cc_library( name = "hlo_ops_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", tbl_outs = [ ( @@ -101,7 +101,7 @@ gentbl_cc_library( gentbl_cc_library( name = "hlo_ops_attrs_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", tbl_outs = [ ( @@ -120,7 +120,7 @@ gentbl_cc_library( gentbl_cc_library( name = "hlo_ops_enums_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", tbl_outs = [ ( @@ -139,7 +139,7 @@ gentbl_cc_library( gentbl_cc_library( name = "hlo_ops_typedefs_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", tbl_outs = [ ( @@ -164,7 +164,7 @@ gentbl_cc_library( gentbl_cc_library( name = "hlo_ops_pattern_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), strip_include_prefix = "mhlo/IR/", tbl_outs = [ ( @@ -183,7 +183,7 @@ gentbl_cc_library( gentbl_cc_library( name = "lhlo_ops_structs_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", tbl_outs = [ ( @@ -202,7 +202,7 @@ gentbl_cc_library( gentbl_cc_library( name = "lhlo_ops_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", tbl_outs = [ ( @@ -221,7 +221,7 @@ gentbl_cc_library( gentbl_cc_library( name = "lhlo_gpu_ops_enums_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", tbl_outs = [ ( @@ -240,7 +240,7 @@ gentbl_cc_library( gentbl_cc_library( name = "lhlo_gpu_ops_dialect_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", tbl_outs = [ ( @@ -259,7 +259,7 @@ gentbl_cc_library( gentbl_cc_library( name = "lhlo_gpu_ops_attrdefs_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", tbl_outs = [ ( @@ -278,7 +278,7 @@ gentbl_cc_library( gentbl_filegroup( name = "hlo_ops_doc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( [ @@ -295,7 +295,7 @@ gentbl_filegroup( gentbl_filegroup( name = "lhlo_ops_doc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( [ @@ -324,7 +324,7 @@ cc_library( td_library( name = "lhlo_gpu_ops_td_files", srcs = glob(["lhlo_gpu/IR/*.td"]), - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), includes = ["."], deps = [ ":hlo_ops_td_files", @@ -335,7 +335,7 @@ td_library( gentbl_cc_library( name = "lhlo_gpu_ops_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", tbl_outs = [ ( @@ -355,7 +355,7 @@ gentbl_cc_library( #TODO(aminim): revisit the naming and grouping of these rules post-move. gentbl_cc_library( name = "canonicalize_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", tbl_outs = [ ( @@ -371,7 +371,7 @@ gentbl_cc_library( td_library( name = "deallocation_ops_td_files", srcs = glob(["deallocation/IR/*.td"]), - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), includes = ["."], deps = [ "@llvm-project//mlir:OpBaseTdFiles", @@ -381,7 +381,7 @@ td_library( gentbl_cc_library( name = "deallocation_ops_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", tbl_outs = [ ( @@ -455,7 +455,7 @@ cc_library( gentbl_cc_library( name = "deallocation_passes_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", tbl_outs = [ ( @@ -505,7 +505,7 @@ cc_library( td_library( name = "lhlo_ops_td_files", srcs = glob(["lhlo/IR/*.td"]), - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), includes = ["."], deps = [ ":hlo_ops_td_files", @@ -523,7 +523,7 @@ td_library( gentbl_cc_library( name = "lhlo_structured_interface_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", tbl_outs = [ ( @@ -1060,7 +1060,7 @@ cc_library( gentbl_cc_library( name = "legalize_to_standard_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), strip_include_prefix = "mhlo/transforms/", tbl_outs = [ ( @@ -1080,7 +1080,7 @@ gentbl_cc_library( gentbl_cc_library( name = "lower_complex_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), strip_include_prefix = "mhlo/transforms/", tbl_outs = [ ( @@ -1137,7 +1137,7 @@ cc_library( gentbl_cc_library( name = "chlo_legalize_to_hlo_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), strip_include_prefix = "mhlo/transforms", tbl_outs = [ ( @@ -1393,7 +1393,7 @@ cc_library( gentbl_cc_library( name = "gml_st_test_passes_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", tbl_outs = [ ( @@ -1434,7 +1434,7 @@ cc_library( gentbl_cc_library( name = "transforms_passes_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", tbl_outs = [ ( @@ -1452,7 +1452,7 @@ gentbl_cc_library( gentbl_cc_library( name = "gpu_transforms_passes_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", tbl_outs = [ ( @@ -1718,7 +1718,7 @@ filegroup( td_library( name = "gml_st_ops_td_files", srcs = glob(["gml_st/IR/*.td"]), - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), includes = ["."], deps = [ "@llvm-project//mlir:ControlFlowInterfacesTdFiles", @@ -1731,7 +1731,7 @@ td_library( gentbl_cc_library( name = "gml_st_ops_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", tbl_outs = [ ( @@ -1798,7 +1798,7 @@ cc_library( gentbl_cc_library( name = "gml_st_passes_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", tbl_outs = [ ( @@ -1839,7 +1839,7 @@ cc_library( td_library( name = "thlo_ops_td_files", srcs = glob(["thlo/IR/*.td"]), - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), includes = ["."], deps = [ "@llvm-project//mlir:ControlFlowInterfacesTdFiles", @@ -1851,7 +1851,7 @@ td_library( gentbl_cc_library( name = "thlo_ops_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", tbl_outs = [ ( @@ -1922,7 +1922,7 @@ cc_library( gentbl_cc_library( name = "thlo_passes_inc_gen", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", tbl_outs = [ ( diff --git a/tensorflow/compiler/xla/python/profiler/internal/BUILD b/tensorflow/compiler/xla/python/profiler/internal/BUILD index 8f4f0a413e8b5f..5ffef7c8e5e284 100644 --- a/tensorflow/compiler/xla/python/profiler/internal/BUILD +++ b/tensorflow/compiler/xla/python/profiler/internal/BUILD @@ -1,4 +1,4 @@ -load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_cloud") +load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_portable") load("//tensorflow/tsl/profiler/builds:build_config.bzl", "tf_profiler_copts") load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") @@ -12,7 +12,7 @@ cc_library( name = "python_hooks", srcs = ["python_hooks.cc"], hdrs = ["python_hooks.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), copts = tf_profiler_copts() + ["-fexceptions"], features = ["-use_header_modules"], # Incompatible with -fexceptions. visibility = [ diff --git a/tensorflow/compiler/xla/runtime/BUILD b/tensorflow/compiler/xla/runtime/BUILD index de4875741b977c..719ea5e796e715 100644 --- a/tensorflow/compiler/xla/runtime/BUILD +++ b/tensorflow/compiler/xla/runtime/BUILD @@ -1,5 +1,5 @@ load("//tensorflow/compiler/xla:xla.bzl", "xla_cc_test") -load("//tensorflow/tsl:tsl.default.bzl", "filegroup", "get_compatible_with_cloud") +load("//tensorflow/tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable") load("//tensorflow/tsl/platform:build_config.bzl", "if_llvm_system_z_available", "tf_platform_deps") load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") @@ -16,7 +16,7 @@ cc_library( name = "arguments", srcs = ["arguments.cc"], hdrs = ["arguments.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":async_runtime", ":types", @@ -42,7 +42,7 @@ cc_library( name = "async_runtime", srcs = ["async_runtime.cc"], hdrs = ["async_runtime.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/tsl/platform:env", "//tensorflow/tsl/platform:platform_port", @@ -66,7 +66,7 @@ xla_cc_test( cc_library( name = "async_values_cache", hdrs = ["async_values_cache.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/tsl/platform", ] + tf_platform_deps( @@ -79,7 +79,7 @@ cc_library( name = "constraints", srcs = ["constraints.cc"], hdrs = ["constraints.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -123,7 +123,7 @@ cc_library( name = "custom_call", srcs = ["custom_call.cc"], hdrs = ["custom_call.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":async_runtime", ":diagnostics", @@ -170,7 +170,7 @@ cc_library( name = "custom_call_registry", srcs = ["custom_call_registry.cc"], hdrs = ["custom_call_registry.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":custom_call", "@llvm-project//llvm:Support", @@ -181,7 +181,7 @@ cc_library( name = "diagnostics", srcs = ["diagnostics.cc"], hdrs = ["diagnostics.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":logical_result", "//tensorflow/tsl/platform:logging", @@ -204,7 +204,7 @@ xla_cc_test( cc_library( name = "errors", hdrs = ["errors.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "@com_google_absl//absl/status", "@com_google_absl//absl/strings:str_format", @@ -215,7 +215,7 @@ cc_library( name = "executable", srcs = ["executable.cc"], hdrs = ["executable.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":arguments", ":async_runtime", @@ -269,7 +269,7 @@ cc_library( name = "execution_engine", srcs = ["execution_engine.cc"], hdrs = ["execution_engine.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":errors", "@com_google_absl//absl/status", @@ -340,7 +340,7 @@ cc_library( name = "jit_executable", srcs = ["jit_executable.cc"], hdrs = ["jit_executable.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":async_values_cache", ":constraints", @@ -359,14 +359,14 @@ cc_library( cc_library( name = "logical_result", hdrs = ["logical_result.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = ["@llvm-project//mlir:Support"], ) cc_library( name = "map_by_type", hdrs = ["map_by_type.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":type_id", "@llvm-project//llvm:Support", @@ -388,7 +388,7 @@ cc_library( name = "memory_mapper", srcs = ["memory_mapper.cc"], hdrs = ["memory_mapper.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/tsl/platform", "@llvm-project//llvm:ExecutionEngine", @@ -402,7 +402,7 @@ cc_library( cc_library( name = "memref_view", hdrs = ["memref_view.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/xla:xla_data_proto_cc", "@com_google_absl//absl/types:span", @@ -412,7 +412,7 @@ cc_library( cc_library( name = "module", hdrs = ["module.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":custom_call_registry", "@com_google_absl//absl/status", @@ -424,7 +424,7 @@ cc_library( name = "module_registry", srcs = ["module_registry.cc"], hdrs = ["module_registry.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":module", ], @@ -444,7 +444,7 @@ xla_cc_test( cc_library( name = "results", hdrs = ["results.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":logical_result", ":types", @@ -467,13 +467,13 @@ xla_cc_test( cc_library( name = "runtime", hdrs = ["runtime.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), ) cc_library( name = "state", hdrs = ["state.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", @@ -495,7 +495,7 @@ cc_library( name = "symbolic_shape", srcs = ["symbolic_shape.cc"], hdrs = ["symbolic_shape.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":arguments", ":constraints", @@ -526,7 +526,7 @@ cc_library( name = "types", srcs = ["types.cc"], hdrs = ["types.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto_cc", @@ -540,7 +540,7 @@ cc_library( cc_library( name = "tracing", hdrs = ["tracing.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":custom_call", ":type_id", @@ -551,7 +551,7 @@ cc_library( name = "type_id", srcs = ["type_id.cc"], hdrs = ["type_id.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "@com_google_absl//absl/container:flat_hash_map", "@llvm-project//mlir:Support", @@ -561,7 +561,7 @@ cc_library( cc_library( name = "compiler", hdrs = ["compiler.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), ) cc_library( diff --git a/tensorflow/compiler/xla/runtime/ffi/BUILD b/tensorflow/compiler/xla/runtime/ffi/BUILD index c94568b48c9008..c96dcf221ce7d4 100644 --- a/tensorflow/compiler/xla/runtime/ffi/BUILD +++ b/tensorflow/compiler/xla/runtime/ffi/BUILD @@ -1,5 +1,5 @@ load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") -load("//tensorflow/tsl:tsl.default.bzl", "filegroup", "get_compatible_with_cloud") +load("//tensorflow/tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -18,13 +18,13 @@ filegroup( cc_library( name = "ffi_abi", hdrs = ["ffi_abi.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), ) cc_library( name = "ffi_api", hdrs = ["ffi_api.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":ffi_abi", ":ffi_c_api_hdrs", @@ -34,5 +34,5 @@ cc_library( cc_library( name = "ffi_c_api_hdrs", hdrs = ["ffi_c_api.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), ) diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 741baed5ca1e85..53ee518bb2edbc 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -26,7 +26,7 @@ load( "//tensorflow/tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured", ) -load("//tensorflow/tsl:tsl.default.bzl", "filegroup", "get_compatible_with_cloud") +load("//tensorflow/tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -78,7 +78,7 @@ cc_library( name = "gpu_executable_run_options", srcs = ["gpu_executable_run_options.cc"], hdrs = ["gpu_executable_run_options.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), visibility = ["//visibility:public"], deps = [ "//tensorflow/compiler/xla:status_macros", @@ -117,7 +117,7 @@ cc_library( hdrs = [ "launch_dimensions.h", ], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":gpu_device_info", "//tensorflow/compiler/xla:shape_util", @@ -213,7 +213,7 @@ cc_library( name = "target_util", srcs = ["target_util.cc"], hdrs = ["target_util.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status", @@ -245,7 +245,7 @@ cc_library( name = "gpu_device_info", srcs = ["gpu_device_info.cc"], hdrs = ["gpu_device_info.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/xla/stream_executor:device_description_proto_cc", "//tensorflow/compiler/xla/stream_executor:stream_executor_headers", @@ -257,7 +257,7 @@ cc_library( testonly = 1, srcs = ["gpu_device_info_for_tests.cc"], hdrs = ["gpu_device_info_for_tests.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":gpu_device_info", ], @@ -634,7 +634,7 @@ cc_library( name = "parallel_loop_emitter", srcs = ["parallel_loop_emitter.cc"], hdrs = ["parallel_loop_emitter.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":launch_dimensions", ":target_util", @@ -986,7 +986,7 @@ cc_library( name = "ir_emission_utils", srcs = ["ir_emission_utils.cc"], hdrs = ["ir_emission_utils.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":gpu_asm_opts_util", ":target_util", @@ -1027,7 +1027,7 @@ cc_library( name = "cublas_cudnn", srcs = ["cublas_cudnn.cc"], hdrs = ["cublas_cudnn.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/xla/hlo/ir:hlo", "//tensorflow/tsl/platform:statusor", @@ -1373,7 +1373,7 @@ cc_library( name = "matmul_utils", srcs = ["matmul_utils.cc"], hdrs = ["matmul_utils.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), deps = [ ":backend_configs_cc", @@ -2944,7 +2944,7 @@ cc_library( name = "gpu_asm_opts_util", srcs = ["gpu_asm_opts_util.cc"], hdrs = ["gpu_asm_opts_util.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), copts = tsl_copts(), deps = [ "//tensorflow/compiler/xla:xla_proto_cc", @@ -2957,7 +2957,7 @@ cc_library( name = "gpu_hlo_cost_analysis", srcs = ["gpu_hlo_cost_analysis.cc"], hdrs = ["gpu_hlo_cost_analysis.h"], - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), deps = [ ":backend_configs_cc", ":cublas_cudnn", diff --git a/tensorflow/compiler/xla/translate/mhlo_to_hlo/BUILD b/tensorflow/compiler/xla/translate/mhlo_to_hlo/BUILD index dedfb49b68f29b..b80ab89e28e199 100644 --- a/tensorflow/compiler/xla/translate/mhlo_to_hlo/BUILD +++ b/tensorflow/compiler/xla/translate/mhlo_to_hlo/BUILD @@ -1,7 +1,7 @@ load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_binary", "cc_library") load("//tensorflow/compiler/xla:xla.bzl", "xla_cc_test") -load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_cloud") +load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_portable") load("@bazel_skylib//rules:build_test.bzl", "build_test") package( @@ -122,7 +122,7 @@ cc_binary( gentbl_cc_library( name = "operator_writer_inc", - compatible_with = get_compatible_with_cloud(), + compatible_with = get_compatible_with_portable(), tbl_outs = [([], "operator_writers.inc")], tblgen = ":operator_writer_gen", td_file = "//tensorflow/compiler/xla/mlir_hlo:mhlo/IR/hlo_ops.td", From 4ca4a8dea9c5ebe8b29ea3dc06dd85d3d346ea60 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 11 Jul 2023 13:53:23 -0700 Subject: [PATCH 149/376] This is an initial commit to introduce a MHLO custom op that leverages cusparseLt. It performs C=C+A*B and assumes the input A,B, C and the output are dense arrays on the host. Pruning and compression will be done after the data are transfer ed to the device. PiperOrigin-RevId: 547288029 --- .../cpu/transforms/sparse_rewrite_passes.cc | 49 +++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/tensorflow/compiler/xla/mlir/backends/cpu/transforms/sparse_rewrite_passes.cc b/tensorflow/compiler/xla/mlir/backends/cpu/transforms/sparse_rewrite_passes.cc index 0b473a634fedab..83c1bdf5d0b371 100644 --- a/tensorflow/compiler/xla/mlir/backends/cpu/transforms/sparse_rewrite_passes.cc +++ b/tensorflow/compiler/xla/mlir/backends/cpu/transforms/sparse_rewrite_passes.cc @@ -495,6 +495,54 @@ struct SparseSDDMMCallRewriter { } }; +// This rewriter rewrites 2:4 SpMM custom op to linalg.generic operator that +// carries the DENSE24 trait and does multiplication. +struct Sparse2To4SpMMCallRewriter { + LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) { + assert(op.getInputs().size() == 3 && "Need C, A, B matrices"); + assert(op.getResults().size() == 1 && "Need one output tensor"); + Location loc = op.getLoc(); + Value mat_c = op.getInputs()[0]; + Value mat_a = op.getInputs()[1]; + Value mat_b = op.getInputs()[2]; + + auto etp = mat_c.getType().dyn_cast().getElementType(); + // Build the enveloping generic op with the following trait: + // indexing_maps = [ + // affine_map<(i,j,k) -> (i,k)>, // A + // affine_map<(i,j,k) -> (k,j)>, // B + // affine_map<(i,j,k) -> (i,j)> // S + // ], + // iterator_types = ["parallel", "parallel", "reduction"], + // doc = "C(i,j) += SUM_k A(i,k) B(k,j)" + SmallVector iteratorTypes; + iteratorTypes.push_back(utils::IteratorType::parallel); + iteratorTypes.push_back(utils::IteratorType::parallel); + iteratorTypes.push_back(utils::IteratorType::reduction); + using MapList = ArrayRef>; + auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); }; + AffineExpr i, j, k; + bindDims(op.getContext(), i, j, k); + auto indexing_maps = infer({{i, k}, {k, j}, {i, j}}); + auto generic_op = rewriter.create( + loc, TypeRange{mat_c.getType()}, ValueRange{mat_a, mat_b}, + ValueRange{mat_c}, indexing_maps, iteratorTypes); + // Set DENSE24 attribute. + generic_op->setAttr("DENSE24", rewriter.getI32IntegerAttr(1)); + // Construct operations in the linalg.generic block. + Block* main = rewriter.createBlock(&generic_op.getRegion(), {}, + {etp, etp, etp}, {loc, loc, loc}); + Value arg_c = main->getArgument(2); + rewriter.setInsertionPointToStart(&generic_op.getRegion().front()); + auto mul = rewriter.create(loc, main->getArgument(0), + main->getArgument(1)); + auto add = rewriter.create(loc, mul.getResult(), arg_c); + rewriter.create(loc, add.getResult()); + rewriter.replaceOp(op, generic_op.getResults()); + return success(); + } +}; + class SparseCustomCallRewriter : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; using SparseCustomTargetRewriter = std::function { std::make_pair("sparse_tensor_transpose", SparseTransposeCallRewriter()), // User custom ops that need rewriting. std::make_pair("sparse_jax_sddmm", SparseSDDMMCallRewriter()), + std::make_pair("sparse_jax_2to4_spmm", Sparse2To4SpMMCallRewriter()), }; // Rewrites a CustomCallOp to corresponding sparse_tensor operation. From 57dd47d18a1c7036f121a2868d0925ebe84033f0 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 11 Jul 2023 14:19:09 -0700 Subject: [PATCH 150/376] Update TFRT dependency to use revision http://github.com/tensorflow/runtime/commit/2311b85fed9d2a38619e0188a0eabcb3f1ef1b95. PiperOrigin-RevId: 547295729 --- third_party/tf_runtime/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/tf_runtime/workspace.bzl b/third_party/tf_runtime/workspace.bzl index 49c3aa40d3faea..6e4b8512665e42 100644 --- a/third_party/tf_runtime/workspace.bzl +++ b/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "62d593e29280c8ef8dee7a5477b04b89ac77c06c" - TFRT_SHA256 = "31e762e2cdfd4c956ba92f9f90fe7d5f0896cb8ec3d52111cc261f797b8aba65" + TFRT_COMMIT = "2311b85fed9d2a38619e0188a0eabcb3f1ef1b95" + TFRT_SHA256 = "e175b71871e863c1b3dc767803f1cc70a48d27964286c2875db2451401f38db4" tf_http_archive( name = "tf_runtime", From e449b106afa8e6980807ab150ee63dbaf47189ff Mon Sep 17 00:00:00 2001 From: Jieying Luo Date: Tue, 11 Jul 2023 14:36:05 -0700 Subject: [PATCH 151/376] [TF:PJRT] Returns an error if the compilation result is TensorList. PiperOrigin-RevId: 547300343 --- tensorflow/compiler/jit/xla_launch_util.cc | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index a0ba7086f91306..0f6fcaa8913fc7 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -689,6 +689,11 @@ Status PopulateCtxOutputsFromPjRtExecutableOutputs( const DataType& type = compilation_result.outputs[i].type; VLOG(2) << "Populating output for retval " << i << " type " << DataTypeString(type); + if (type == DT_VARIANT) { + return absl::UnimplementedError( + "Support for TensorList crossing the XLA/TF boundary " + "is not implemented"); + } if (compilation_result.outputs[i].is_constant) { bool requires_copy_to_device = GetDeviceType(ctx) != DEVICE_CPU; From f687e26b27d44ebced4e5fad83838a52ac3fd448 Mon Sep 17 00:00:00 2001 From: Tao Wang Date: Tue, 11 Jul 2023 14:47:04 -0700 Subject: [PATCH 152/376] Enable to set fdo_profile through XLA python client. PiperOrigin-RevId: 547303330 --- tensorflow/compiler/xla/client/executable_build_options.h | 3 +++ tensorflow/compiler/xla/python/xla_client.py | 2 +- tensorflow/compiler/xla/python/xla_compiler.cc | 2 ++ tensorflow/compiler/xla/python/xla_extension/__init__.pyi | 1 + 4 files changed, 7 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/xla/client/executable_build_options.h b/tensorflow/compiler/xla/client/executable_build_options.h index c0603c89de4feb..97ceef5dca9059 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.h +++ b/tensorflow/compiler/xla/client/executable_build_options.h @@ -197,6 +197,9 @@ class ExecutableBuildOptions { } absl::string_view fdo_profile() const { return fdo_profile_; } + void set_fdo_profile(const std::string& fdo_profile) { + fdo_profile_ = fdo_profile; + } std::string* mutable_fdo_profile() { return &fdo_profile_; } // Returns a string representation of the build options, suitable for diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index 8b1deb8abf7372..0d6f2f6c0c6c47 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -44,7 +44,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. -_version = 165 +_version = 166 # Version number for MLIR:Python components. mlir_api_version = 51 diff --git a/tensorflow/compiler/xla/python/xla_compiler.cc b/tensorflow/compiler/xla/python/xla_compiler.cc index eb21fd9fa7788a..738e9672ab13d4 100644 --- a/tensorflow/compiler/xla/python/xla_compiler.cc +++ b/tensorflow/compiler/xla/python/xla_compiler.cc @@ -867,6 +867,8 @@ void BuildXlaCompilerSubmodule(py::module& m) { py::class_(m, "ExecutableBuildOptions") .def(py::init<>()) .def("__repr__", &ExecutableBuildOptions::ToString) + .def_property("fdo_profile", &ExecutableBuildOptions::fdo_profile, + &ExecutableBuildOptions::set_fdo_profile) .def_property( "result_layout", [](const ExecutableBuildOptions& options) -> std::optional { diff --git a/tensorflow/compiler/xla/python/xla_extension/__init__.pyi b/tensorflow/compiler/xla/python/xla_extension/__init__.pyi index 38cef1047a4d4f..98e3069d1bef9b 100644 --- a/tensorflow/compiler/xla/python/xla_extension/__init__.pyi +++ b/tensorflow/compiler/xla/python/xla_extension/__init__.pyi @@ -268,6 +268,7 @@ class ExecutableBuildOptions: def __init__(self) -> None: ... def __repr__(self) -> str: ... result_layout: Optional[Shape] + fdo_profile: Optional[bytes] num_replicas: int num_partitions: int debug_options: DebugOptions From d1e3c9c4ef16c3a7330441b77d83e7c93c72e57e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 11 Jul 2023 14:55:54 -0700 Subject: [PATCH 153/376] Fix typo in comment PiperOrigin-RevId: 547305655 --- .../compiler/xla/stream_executor/multi_platform_manager.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/compiler/xla/stream_executor/multi_platform_manager.cc b/tensorflow/compiler/xla/stream_executor/multi_platform_manager.cc index 2fd4a264077548..63ce190c5d5bb5 100644 --- a/tensorflow/compiler/xla/stream_executor/multi_platform_manager.cc +++ b/tensorflow/compiler/xla/stream_executor/multi_platform_manager.cc @@ -76,7 +76,7 @@ class MultiPlatformManagerImpl { tsl::StatusOr LookupByIdLocked(const Platform::Id& id) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); - // Returns the names of the initialied platforms satisfying the given filter. + // Returns the names of the initialized platforms satisfying the given filter. // By default, it will return all initialized platform names. std::vector InitializedPlatformNamesWithFilter( const std::function& filter = [](const Platform*) { From abae3ee267cb8680f8653a66be5767bffd1496a8 Mon Sep 17 00:00:00 2001 From: Xinyi Wang Date: Tue, 11 Jul 2023 15:00:10 -0700 Subject: [PATCH 154/376] Disable test that breaks tensorflow.gpu.pascal PiperOrigin-RevId: 547306800 --- tensorflow/python/ops/BUILD | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/ops/BUILD b/tensorflow/python/ops/BUILD index 3c6556c5506b96..83dc031e86559e 100644 --- a/tensorflow/python/ops/BUILD +++ b/tensorflow/python/ops/BUILD @@ -3670,7 +3670,10 @@ cuda_py_strict_test( srcs = ["nn_test.py"], main = "nn_test.py", python_version = "PY3", - tags = ["no_windows"], + tags = [ + "no_windows", + "notap", # TODO(b/290819913) + ], xla_tags = [ "no_cuda_asan", # times out ], From 930a3845bf7396ce80ba067545d714e4a4879baa Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 11 Jul 2023 15:00:28 -0700 Subject: [PATCH 155/376] Update rules_python version to 0.23.1 PiperOrigin-RevId: 547306878 --- WORKSPACE | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/WORKSPACE b/WORKSPACE index 389a4e5788011e..fb3af8a2bea085 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -14,9 +14,9 @@ http_archive( http_archive( name = "rules_python", - sha256 = "29a801171f7ca190c543406f9894abf2d483c206e14d6acbd695623662320097", - strip_prefix = "rules_python-0.18.1", - url = "https://github.com/bazelbuild/rules_python/releases/download/0.18.1/rules_python-0.18.1.tar.gz", + sha256 = "84aec9e21cc56fbc7f1335035a71c850d1b9b5cc6ff497306f84cced9a769841", + strip_prefix = "rules_python-0.23.1", + url = "https://github.com/bazelbuild/rules_python/releases/download/0.23.1/rules_python-0.23.1.tar.gz", ) load("@rules_python//python:repositories.bzl", "python_register_toolchains") From 6229f7468e13903e62fdf6ca2d52335a289aabb1 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 11 Jul 2023 15:12:07 -0700 Subject: [PATCH 156/376] Update TFRT dependency to use revision http://github.com/tensorflow/runtime/commit/a9afd3d20d538d81145a841ffdf20faf48dc69f8. PiperOrigin-RevId: 547310248 --- third_party/tf_runtime/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/tf_runtime/workspace.bzl b/third_party/tf_runtime/workspace.bzl index 6e4b8512665e42..45799d33c17189 100644 --- a/third_party/tf_runtime/workspace.bzl +++ b/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "2311b85fed9d2a38619e0188a0eabcb3f1ef1b95" - TFRT_SHA256 = "e175b71871e863c1b3dc767803f1cc70a48d27964286c2875db2451401f38db4" + TFRT_COMMIT = "a9afd3d20d538d81145a841ffdf20faf48dc69f8" + TFRT_SHA256 = "9ceb85b1bc9350c2c0a3f381fce8604173484f969f56d872239bac29a650f060" tf_http_archive( name = "tf_runtime", From eff4a808569baa4c3fe65a84483683cdf779db6b Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Tue, 11 Jul 2023 15:14:16 -0700 Subject: [PATCH 157/376] [XLA:GPU] Rollback cl/547196631. PiperOrigin-RevId: 547310976 --- .../compiler/xla/debug_options_flags.cc | 6 - tensorflow/compiler/xla/service/gpu/BUILD | 7 - .../xla/service/gpu/gemm_rewriter_triton.cc | 444 ++++++------------ .../xla/service/gpu/gemm_rewriter_triton.h | 49 +- .../service/gpu/gemm_rewriter_triton_test.cc | 156 +----- .../compiler/xla/service/gpu/gpu_compiler.cc | 37 +- .../xla/service/gpu/ir_emitter_triton.cc | 2 +- .../xla/service/gpu/ir_emitter_triton_test.cc | 161 ------- .../xla/service/gpu/triton_autotuner.cc | 11 +- tensorflow/compiler/xla/xla.proto | 4 +- 10 files changed, 172 insertions(+), 705 deletions(-) diff --git a/tensorflow/compiler/xla/debug_options_flags.cc b/tensorflow/compiler/xla/debug_options_flags.cc index b3d288bde79532..299635b30746e0 100644 --- a/tensorflow/compiler/xla/debug_options_flags.cc +++ b/tensorflow/compiler/xla/debug_options_flags.cc @@ -138,7 +138,6 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_enable_cudnn_int8x32_convolution_reordering(true); opts.set_xla_gpu_triton_gemm_any(false); opts.set_xla_gpu_enable_triton_softmax_fusion(false); - opts.set_xla_gpu_triton_fusion_level(1); // Moving reduce-scatter out of while loops can increase memory footprint, so // turning it off by default. @@ -1131,11 +1130,6 @@ void MakeDebugOptionsFlags(std::vector* flag_list, "Forces any reductions during matrix multiplications to use the " "accumulator type and not the output type. The precision of the dot " "operation may not increase that much if there is output fusion.")); - flag_list->push_back(tsl::Flag( - "xla_gpu_triton_fusion_level", - int32_setter_for(&DebugOptions::set_xla_gpu_triton_fusion_level), - debug_options->xla_gpu_triton_fusion_level(), - "Triton fusion level, higher levels mean more fused operations.")); } // NOLINT(readability/fn_size) // Allocates flag_values and flag_objects; this function must not be called more diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 53ee518bb2edbc..e473aa260f6aed 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -434,7 +434,6 @@ cc_library( "//tensorflow/tsl/platform:errors", "//tensorflow/tsl/platform:logging", "//tensorflow/tsl/platform:path", - "//tensorflow/tsl/platform:statusor", "//tensorflow/tsl/platform:tensor_float_32_hdr_lib", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", @@ -490,8 +489,6 @@ xla_test( "//tensorflow/compiler/xla:autotuning_proto_cc", "//tensorflow/compiler/xla:error_spec", "//tensorflow/compiler/xla/hlo/ir:hlo", - "//tensorflow/compiler/xla/service:pattern_matcher", - "//tensorflow/compiler/xla/service:pattern_matcher_gmock", "//tensorflow/compiler/xla/service/gpu/tests:gpu_codegen_test", "//tensorflow/compiler/xla/stream_executor:device_description", "//tensorflow/compiler/xla/stream_executor/cuda:cublas_plugin", @@ -1155,22 +1152,18 @@ cc_library( "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status", - "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/hlo/ir:hlo", "//tensorflow/compiler/xla/hlo/utils:hlo_query", "//tensorflow/compiler/xla/service:hlo_creation_utils", "//tensorflow/compiler/xla/service:hlo_pass", - "//tensorflow/compiler/xla/service:instruction_fusion", - "//tensorflow/compiler/xla/stream_executor:device_description", "//tensorflow/tsl/platform:errors", "//tensorflow/tsl/platform:status", "//tensorflow/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], diff --git a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.cc b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.cc index 20971738a95289..6b28352ccd61ab 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.cc @@ -22,15 +22,12 @@ limitations under the License. #include #include #include -#include #include #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" -#include "absl/log/check.h" #include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/autotuning.pb.h" @@ -40,7 +37,6 @@ limitations under the License. #include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_instructions.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_opcode.h" -#include "tensorflow/compiler/xla/hlo/ir/hlo_schedule.h" #include "tensorflow/compiler/xla/hlo/utils/hlo_query.h" #include "tensorflow/compiler/xla/layout.h" #include "tensorflow/compiler/xla/literal_util.h" @@ -50,12 +46,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/gpu/matmul_utils.h" #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" -#include "tensorflow/compiler/xla/service/instruction_fusion.h" #include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/stream_executor/device_description.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/tsl/platform/errors.h" @@ -64,25 +57,6 @@ limitations under the License. namespace xla { namespace gpu { - -bool TensorIterationSpec::operator==(const TensorIterationSpec& other) const { - for (int dim = 0; dim < TensorIterationSpec::kMaxDimsPerTensor; ++dim) { - if (dim_iteration_specs_[dim].size() != other[dim].size()) { - return false; - } - for (int fragment = 0; fragment < dim_iteration_specs_[dim].size(); - ++fragment) { - if (dim_iteration_specs_[dim][fragment].stride != - other[dim][fragment].stride || - dim_iteration_specs_[dim][fragment].count != - other[dim][fragment].count) { - return false; - } - } - } - return true; -} - namespace { // Batch dimensions of an operand of a dot instruction. @@ -121,10 +95,10 @@ int64_t NonContractingDimensionIndex(const HloInstruction& dot, } // Data types that are tested to work in the triton GEMM emitter. -bool IsSupportedDataType(PrimitiveType type, GpuVersion gpu_version) { +bool IsSupportedDataType(PrimitiveType t, GpuVersion gpu_version) { auto cuda_compute_capability = std::get(gpu_version); - switch (type) { + switch (t) { case PRED: case S8: case S16: @@ -140,19 +114,21 @@ bool IsSupportedDataType(PrimitiveType type, GpuVersion gpu_version) { } } -// Let input and output data volumes of a fusion grow by small amounts. -constexpr int64_t kIoToleranceBytes = 1024; - -// Difference of input and output data volumes of an instruction. -int64_t InputMinusOutputBytes(const HloInstruction& hlo) { - CHECK(!hlo.shape().IsTuple()); - int64_t output_size = ShapeUtil::ByteSizeOf(hlo.shape()); - int64_t input_size = 0; - for (const HloInstruction* operand : hlo.operands()) { - CHECK(!operand->shape().IsTuple()); - input_size += ShapeUtil::ByteSizeOf(operand->shape()); +Status RequireTritonFusibleConvert(const HloInstruction* input, + GpuVersion gpu_version) { + if (!IsSupportedDataType(input->operand(0)->shape().element_type(), + gpu_version)) { + return Unimplemented("unsupported data type"); } - return input_size - output_size; + // TODO(b/266862494): Can pick up almost any + // convert, but if it's reducing the data volume it should rather be fused + // to the output of the producer kernel. However not all operations support + // output fusion - then it should be fused here anyway! + if (ShapeUtil::ByteSizeOf(input->operand(0)->shape()) > + ShapeUtil::ByteSizeOf(input->shape())) { + return FailedPrecondition("narrowing conversion"); + } + return OkStatus(); } // Handles numbers of dimensions of a target HLO instruction @@ -166,13 +142,6 @@ class DimensionOrder { int64_t target_dim_number; int subdim_number; int64_t size; - bool operator==(const DimDescription& other) const { - return target_dim_number == other.target_dim_number && - subdim_number == other.subdim_number && size == other.size; - } - std::string ToString() const { - return absl::StrCat(target_dim_number, ":", subdim_number, ":", size); - } }; // Sequence describing all dimensions of HLO's output shape // in layout minor-to-major (physical) order. @@ -202,35 +171,34 @@ class DimensionOrder { // Transforms the DimensionOrder so that from a description of the output // of `hlo` it becomes a description of the input of `hlo`. - FusionDecision HandleInstruction(const HloInstruction* hlo) { + Status HandleInstruction(const HloInstruction* hlo) { VLOG(7) << hlo->ToString(); - if (hlo->opcode() == HloOpcode::kParameter || - hlo->opcode() == HloOpcode::kConstant) { - return FusionDecision{}; + if (hlo->opcode() == HloOpcode::kParameter) { + return OkStatus(); } else if (hlo->opcode() == HloOpcode::kTranspose || hlo->opcode() == HloOpcode::kCopy) { return HandleCopyOrTranspose(hlo); } else if (hlo->operand_count() > 0 && IsTritonSupportedElementwise( hlo->opcode(), hlo->operand(0)->shape().element_type())) { - return FusionDecision{}; + return OkStatus(); } else if (hlo->opcode() == HloOpcode::kBitcast) { return HandleBitcast(hlo); } else if (hlo->opcode() == HloOpcode::kReshape) { if (!ShapeUtil::ReshapeIsBitcast(hlo->operand(0)->shape(), hlo->shape())) { - return "Non-bitcast reshape."; + return Unimplemented("Non-bitcast reshape."); } return HandleBitcast(hlo); } else if (hlo_query::IsScalarConstant(hlo) || hlo_query::IsBroadcastOfScalarConstant(*hlo)) { // Dimension order collapses on a scalar, for simplicity leave it equal // to the output one for now. - return FusionDecision{}; + return OkStatus(); } else { - return "Unimplemented instruction."; + return Unimplemented("Instruction: %s", hlo->ToString()); } - return FusionDecision{}; + return OkStatus(); } // Get the raw data of the dimension order. @@ -242,32 +210,20 @@ class DimensionOrder { return splittable_dimension_index_; } - // Tells that two dimension orders describe the same tensor physical layout. - bool IsPhysicallyEquivalent(const DimensionOrder& other) const; - - std::string ToString() const { - return absl::StrJoin(dim_order_, "-", - [](std::string* out, const DimDescription& d) { - absl::StrAppend(out, d.ToString()); - }); - } - private: // See HandleInstruction() for the general description of Handle*(). - FusionDecision HandleBitcast(const HloInstruction* hlo); - FusionDecision HandleCopyOrTranspose(const HloInstruction* hlo); + Status HandleBitcast(const HloInstruction* hlo); + Status HandleCopyOrTranspose(const HloInstruction* hlo); DimOrderVector dim_order_; - const int64_t splittable_dimension_index_; + int64_t splittable_dimension_index_; }; -using DimIterationSpec = TensorIterationSpec::DimIterationSpec; - -TensorIterationSpec DimensionOrderToTensorIterationSpec( +DotFusionAnalysis::TensorIterationSpec DimensionOrderToTensorIterationSpec( const DimensionOrder& order) { const DimensionOrder::DimOrderVector& dim_order_vector = order.GetDimOrderVector(); - TensorIterationSpec tensor_spec; + DotFusionAnalysis::TensorIterationSpec tensor_spec; int64_t accumulated_stride = 1; for (int dim_order_index = 0; dim_order_index < dim_order_vector.size(); ++dim_order_index) { @@ -280,7 +236,8 @@ TensorIterationSpec DimensionOrderToTensorIterationSpec( continue; } - DimIterationSpec& dim_spec = tensor_spec[dim.target_dim_number]; + DotFusionAnalysis::DimIterationSpec& dim_spec = + tensor_spec[dim.target_dim_number]; if (dim_order_index > 0 && dim_order_vector[dim_order_index - 1].target_dim_number == dim.target_dim_number) { @@ -300,7 +257,7 @@ TensorIterationSpec DimensionOrderToTensorIterationSpec( accumulated_stride *= dim.size; } // Create all absent dimensions as degenerate ones to simplify later queries. - for (DimIterationSpec& dim_spec : tensor_spec) { + for (DotFusionAnalysis::DimIterationSpec& dim_spec : tensor_spec) { if (dim_spec.empty()) { dim_spec.push_back({/*stride=*/0, /*count=*/1, /*subfragments=*/{1}}); } @@ -308,11 +265,6 @@ TensorIterationSpec DimensionOrderToTensorIterationSpec( return tensor_spec; } -bool DimensionOrder::IsPhysicallyEquivalent(const DimensionOrder& other) const { - return DimensionOrderToTensorIterationSpec(*this) == - DimensionOrderToTensorIterationSpec(other); -} - DimensionOrder DimensionOrder::FromDotOperand(const HloInstruction& dot, const int operand_number, const int64_t split_k) { @@ -335,7 +287,7 @@ DimensionOrder DimensionOrder::FromDotOutput(const HloInstruction& dot) { return DimensionOrder(&dot); } -FusionDecision DimensionOrder::HandleBitcast(const HloInstruction* hlo) { +Status DimensionOrder::HandleBitcast(const HloInstruction* hlo) { const Shape& operand_shape = hlo->operand(0)->shape(); DimOrderVector operand_dim_order; operand_dim_order.reserve(dim_order_.size()); @@ -349,7 +301,7 @@ FusionDecision DimensionOrder::HandleBitcast(const HloInstruction* hlo) { ++out_dim) { if (operand_remaining_size >= out_dim->size) { if (operand_remaining_size % out_dim->size) { - return "Unsupported bitcast"; + return Unimplemented("Unsupported bitcast: %s", hlo->ToString()); } // Output dimension fragment completely fits into the operand one: // just copy it as is. @@ -367,7 +319,7 @@ FusionDecision DimensionOrder::HandleBitcast(const HloInstruction* hlo) { // If there is a remaining fragment of a previous operand dimension // assign it first. if (out_remaining_size % operand_remaining_size) { - return "Unsupported bitcast"; + return Unimplemented("Unsupported bitcast: %s", hlo->ToString()); } operand_dim_order.push_back( {out_dim->target_dim_number, subdim_index, operand_remaining_size}); @@ -385,7 +337,7 @@ FusionDecision DimensionOrder::HandleBitcast(const HloInstruction* hlo) { // assign the remainder of the output and carry over the remainder // of the operand. if (operand_dim_size % out_remaining_size) { - return "Unsupported bitcast"; + return Unimplemented("Unsupported bitcast: %s", hlo->ToString()); } operand_remaining_size = operand_dim_size / out_remaining_size; new_fragment_size = out_remaining_size; @@ -406,7 +358,7 @@ FusionDecision DimensionOrder::HandleBitcast(const HloInstruction* hlo) { int subdim_index = operand_dim_order.back().subdim_number + 1; while (operand_dim_iter != operand_shape.layout().minor_to_major().cend()) { if (operand_shape.dimensions(*operand_dim_iter) != 1) { - return "Unsupported bitcast"; + return Unimplemented("Unsupported bitcast: %s", hlo->ToString()); } operand_dim_order.push_back( {operand_dim_order.back().target_dim_number, subdim_index, 1}); @@ -415,11 +367,10 @@ FusionDecision DimensionOrder::HandleBitcast(const HloInstruction* hlo) { } dim_order_ = operand_dim_order; - return FusionDecision{}; + return OkStatus(); } -FusionDecision DimensionOrder::HandleCopyOrTranspose( - const HloInstruction* hlo) { +Status DimensionOrder::HandleCopyOrTranspose(const HloInstruction* hlo) { // Every HLO dimension can correspond to a group of subdimensions in // dim_order_. For the easier handling of permutations: group dim_order_ by // dimension, apply permutations, then finally remove the grouping. @@ -468,25 +419,25 @@ FusionDecision DimensionOrder::HandleCopyOrTranspose( dim_order_.push_back(subdim); } } - return FusionDecision{}; + return OkStatus(); } // Tells if the dimension order is supported by the triton GEMM emitter. // Only the dimension indicated by SplittableDimensionIndex() can be split // physically once by other dimensions. Other ones can be only split logically. // All subdimensions within a dimension have to be ordered. -FusionDecision RequireTritonGemmSupportedDimOrder(const DimensionOrder& order) { - std::array subdim_counters = { +Status RequireTritonGemmSupportedDimOrder(const DimensionOrder& order) { + std::array subdim_counters = { -1, -1, -1, -1}; - std::array split_counters = { + std::array split_counters = { -1, -1, -1, -1}; const DimensionOrder::DimOrderVector& dim_order_vector = order.GetDimOrderVector(); - VLOG(8) << order.ToString(); for (int i = 0; i < dim_order_vector.size(); i++) { const auto [dim_number, subdim_number, size] = dim_order_vector[i]; + VLOG(8) << dim_number << "\t" << subdim_number << "\t" << size; if (subdim_counters[dim_number] != subdim_number - 1) { - return "Transpose within a dimension."; + return Unimplemented("Transpose within a dimension."); } ++subdim_counters[dim_number]; if (size == 1) { @@ -496,185 +447,31 @@ FusionDecision RequireTritonGemmSupportedDimOrder(const DimensionOrder& order) { ++split_counters[dim_number]; if (dim_number == order.SplittableDimensionIndex()) { if (split_counters[dim_number] > 1) { - return "2nd split of a splittable dimension."; + return Unimplemented("2nd split of a splittable dimension."); } } else if (split_counters[dim_number] > 0) { - return "Split of a non-splittable dimension."; + return Unimplemented("Split of a non-splittable dimension."); } } } - return FusionDecision{}; -} - -// Tells if an instruction has no input into which it could be fused. -// More cases should be added here. -bool CanNotBeFusedIntoAProducer(const HloInstruction& hlo) { - return hlo_query::AllOperandsAreParametersOrConstants(hlo); -} - -// Tells that fusing an instruction is efficient. -bool IsInputWorthFusing(const HloInstruction& hlo) { - return hlo_query::AllOperandsAreParametersOrConstants(hlo) || - InputMinusOutputBytes(hlo) < kIoToleranceBytes; + return OkStatus(); } -// Checks if the instruction is possible and profitable to fuse. -// If so tries to transform dim_order describing output of `hlo` into a +// Transforms dim_order describing the output of `hlo` into a // description of its input if it is supported by the triton GEMM emitter. -FusionDecision CanFuse(const HloInstruction& hlo, DimensionOrder& dim_order, - const GpuVersion gpu_version) { - if (hlo.opcode() == HloOpcode::kTuple || - hlo.opcode() == HloOpcode::kGetTupleElement) { - return "Unsupported instruction."; - } - for (const HloInstruction* operand : hlo.operands()) { - if (!IsSupportedDataType(operand->shape().element_type(), gpu_version)) { - return "Unsupported input data type."; - } - } - if (!IsSupportedDataType(hlo.shape().element_type(), gpu_version)) { - return "Unsupported output data type."; - } - if (hlo.IsConstant()) { - return "Not fusing a constant."; - } - if (hlo.opcode() == HloOpcode::kBroadcast) { - return "Not fusing a broadcast."; - } - if (!CanNotBeFusedIntoAProducer(hlo) && !IsInputWorthFusing(hlo)) { - return "Not obviously profitable to fuse as input."; - } - if (hlo.IsElementwise() && hlo.opcode() != HloOpcode::kCopy && - hlo.opcode() != HloOpcode::kConvert && - hlo.GetModule()->config().debug_options().xla_gpu_triton_fusion_level() < - 2) { - return "Skipping most elementwise operations at low fusion levels."; - } - if (FusionDecision decision = dim_order.HandleInstruction(&hlo); !decision) { - return decision; +Status CanFuse(const HloInstruction* hlo, DimensionOrder& dim_order, + const GpuVersion gpu_version) { + if (hlo->opcode() == HloOpcode::kConvert) { + return RequireTritonFusibleConvert(hlo, gpu_version); + } else if (hlo->IsElementwise() && hlo->opcode() != HloOpcode::kCopy) { + // Temporarily forbid fusing elementwise operations + // other than copy and convert. + return Unimplemented("Unsupported elementwise operation"); } + TF_RETURN_IF_ERROR(dim_order.HandleInstruction(hlo)); return RequireTritonGemmSupportedDimOrder(dim_order); } -// Clone an instruction into the fusion. -void Fuse(HloInstruction& hlo, - absl::flat_hash_map& - old_to_new_mapping, - std::vector& call_operands, - HloComputation::Builder& builder) { - if (old_to_new_mapping.contains(&hlo)) { - return; - } - VLOG(3) << "Fusing " << hlo.ToString(); - auto get_or_add_parameter = [&](HloInstruction& instr) { - if (auto it = old_to_new_mapping.find(&instr); - it != old_to_new_mapping.end()) { - return it->second; - } - call_operands.push_back(&instr); - return old_to_new_mapping - .insert({&instr, - builder.AddInstruction(HloInstruction::CreateParameter( - call_operands.size() - 1, instr.shape(), - absl::StrCat("parameter_", call_operands.size() - 1)))}) - .first->second; - }; - if (hlo.opcode() == HloOpcode::kParameter || - hlo.opcode() == HloOpcode::kGetTupleElement) { - get_or_add_parameter(hlo); - } else { - std::vector hlo_new_operands; - for (HloInstruction* operand : hlo.operands()) { - hlo_new_operands.push_back(get_or_add_parameter(*operand)); - } - old_to_new_mapping[&hlo] = builder.AddInstruction( - hlo.CloneWithNewOperands(hlo.shape(), hlo_new_operands)); - } -} - -// Tells how many new parameters does a fusion gain by fusing the operation as -// an input. -int64_t NumAddedParameters(const HloInstruction& hlo) { - // Non-scalar constant is equivalent to a parameter: one input, one output. - if (hlo.opcode() == HloOpcode::kConstant && - !ShapeUtil::IsScalar(hlo.shape())) { - return 0; - } - // All other instructions add all own inputs and remove own single output. - return hlo.operand_count() - 1; -} - -// Fuse an instruction with all its fusible inputs. -// If an input is not fusible stop there and make a parameter of the new -// fusion, otherwise put it onto stack and check its own inputs first. -void FuseWithInputsRecursively( - HloInstruction* root, DimensionOrder root_dim_order, - // Dimension orders describing inputs of corresponding instructions. - absl::flat_hash_map& dim_orders, - const GpuVersion gpu_version, - absl::flat_hash_map& - old_to_new_mapping, - std::vector& call_operands, - HloComputation::Builder& builder) { - absl::flat_hash_set visited; - std::stack to_fuse; - // Instructions at the edge 'to_fuse' that can either get fused too or - // become parameters of the fusion. Used to track the number of parameters - // of the fusion. - absl::flat_hash_set inputs; - // Currently only one physically unique dim order per scope is supported. - // Let it change while the scope has one input; afterwards require all - // of them to be physically compatible. - const HloInstruction* reference_dim_order_hlo = nullptr; - if (CanFuse(*root, root_dim_order, gpu_version)) { - to_fuse.push(root); - inputs.insert(root->operands().begin(), root->operands().end()); - // root_dim_order went through output -> input transformation here. - CHECK(dim_orders.insert({root, root_dim_order}).second) << root->ToString(); - } - visited.insert(root); - while (!to_fuse.empty()) { - bool top_is_ready_to_fuse = true; - HloInstruction* hlo = to_fuse.top(); - if (reference_dim_order_hlo == nullptr && hlo->operand_count() > 1) { - reference_dim_order_hlo = hlo; - } - for (HloInstruction* operand : hlo->mutable_operands()) { - if (visited.insert(operand).second) { - // Stop adding new parameters. - if (inputs.size() >= DotFusionAnalysis::kMaxParameterPerScope && - NumAddedParameters(*operand) > 0) { - continue; - } - // Operand's output is described by its consumer's input. - DimensionOrder operand_dim_order(dim_orders.at(hlo)); - // CanFuse() makes output -> input transformation of - // operand_dim_order if succeeds. - if (CanFuse(*operand, operand_dim_order, gpu_version)) { - if (reference_dim_order_hlo != nullptr && - !operand_dim_order.IsPhysicallyEquivalent( - dim_orders.at(reference_dim_order_hlo))) { - continue; - } - to_fuse.push(operand); - if (operand->opcode() != HloOpcode::kParameter) { - inputs.erase(operand); - } - inputs.insert(operand->operands().begin(), operand->operands().end()); - // Save the dimension order description of operand's input. - CHECK(dim_orders.insert({operand, operand_dim_order}).second) - << operand->ToString(); - top_is_ready_to_fuse = false; - } - } - } - if (top_is_ready_to_fuse) { - Fuse(*hlo, old_to_new_mapping, call_operands, builder); - to_fuse.pop(); - } - } -} - // Extracts into fused computations parts of HLO graph including dot() // operations that can target the triton GEMM emitter. class GemmRewriterTritonVisitor : public DfsHloRewriteVisitor { @@ -686,9 +483,8 @@ class GemmRewriterTritonVisitor : public DfsHloRewriteVisitor { // and replaces the original dot() with a call to the computation. Status HandleDot(HloInstruction* dot) override { VLOG(5) << dot->ToString(); - FusionDecision can_handle = CanTritonHandleGEMM(*dot, gpu_version_); - if (!can_handle) { - VLOG(3) << can_handle.Explain(); + + if (!CanTritonHandleGEMM(*dot, gpu_version_)) { return OkStatus(); } @@ -707,28 +503,72 @@ class GemmRewriterTritonVisitor : public DfsHloRewriteVisitor { std::string suggested_name = absl::StrCat("triton_gemm_", dot->name()); HloComputation::Builder builder( absl::StrCat(suggested_name, "_computation")); - std::vector call_operands; // Original instruction -> fused one. absl::flat_hash_map old_to_new_mapping; - - auto fuse_inputs = [&](int operand_number) { - absl::flat_hash_map dim_orders; - int operand_count_before = call_operands.size(); - // Direct dot inputs have well defined dimension orders. - FuseWithInputsRecursively( - dot->mutable_operand(operand_number), - DimensionOrder::FromDotOperand(*dot, operand_number), dim_orders, - gpu_version_, old_to_new_mapping, call_operands, builder); - return call_operands.size() - operand_count_before; - }; - // Separate traversal from LHS and RHS inputs of the dot: they use - // differently shaped tiles but may go through same HLO graph nodes. - TF_RET_CHECK(fuse_inputs(0) <= DotFusionAnalysis::kMaxParameterPerScope); - TF_RET_CHECK(fuse_inputs(1) <= DotFusionAnalysis::kMaxParameterPerScope); - - Fuse(*dot, old_to_new_mapping, call_operands, builder); - + absl::flat_hash_set visited; + std::vector call_operands; + // Traverse and fuse dot() inputs bottom-up starting from direct operands. + // If an input is not fusible stop there and make it a parameter of the new + // fusion, otherwise put it onto stack and check its own inputs first. + std::stack to_fuse; + // Dimension orders describing inputs of corresponding instructions. + absl::flat_hash_map dim_orders; + to_fuse.push(dot); + while (!to_fuse.empty()) { + bool top_is_ready_to_fuse = true; + HloInstruction* hlo = to_fuse.top(); + for (HloInstruction* operand : hlo->mutable_operands()) { + if (visited.insert(operand).second) { + DimensionOrder operand_dim_order = [&] { + // Direct dot inputs are described by default dimension orders. + if (operand == dot->operand(0)) { + return DimensionOrder::FromDotOperand(*dot, 0); + } else if (operand == dot->operand(1)) { + return DimensionOrder::FromDotOperand(*dot, 1); + } + // Otherwise operand's output is described by its consumer's input. + return DimensionOrder(dim_orders.at(hlo)); + }(); + // CanFuse() makes output -> input transformation of + // operand_dim_order if succeeds. + if (CanFuse(operand, operand_dim_order, gpu_version_).ok()) { + VLOG(3) << "Fusing " << operand->ToString(); + to_fuse.push(operand); + // Save the dimension order description of operand's input. + dim_orders.insert({operand, operand_dim_order}); + top_is_ready_to_fuse = false; + } + } + } + if (top_is_ready_to_fuse) { + if (hlo->opcode() == HloOpcode::kParameter || + hlo->opcode() == HloOpcode::kGetTupleElement) { + old_to_new_mapping[hlo] = + builder.AddInstruction(HloInstruction::CreateParameter( + call_operands.size(), hlo->shape(), + absl::StrCat("parameter_", call_operands.size()))); + call_operands.push_back(hlo); + } else { + std::vector hlo_new_operands; + for (HloInstruction* operand : hlo->operands()) { + const auto iter = old_to_new_mapping.find(operand); + if (iter != old_to_new_mapping.end()) { + hlo_new_operands.push_back(iter->second); + } else { + hlo_new_operands.push_back( + builder.AddInstruction(HloInstruction::CreateParameter( + call_operands.size(), operand->shape(), + absl::StrCat("parameter_", call_operands.size())))); + call_operands.push_back(operand); + } + } + old_to_new_mapping[hlo] = builder.AddInstruction( + hlo->CloneWithNewOperands(hlo->shape(), hlo_new_operands)); + } + to_fuse.pop(); + } + } HloComputation* computation = dot->GetModule()->AddComputationAndUnifyNamesAndIds(builder.Build(), /*is_entry=*/false); @@ -752,7 +592,7 @@ class GemmRewriterTritonVisitor : public DfsHloRewriteVisitor { } else { TF_RETURN_IF_ERROR(ReplaceInstruction(dot, dot_fusion)); } - XLA_VLOG_LINES(5, computation->ToString()); + VLOG(5) << computation->ToString(); return OkStatus(); } @@ -803,7 +643,7 @@ StatusOr MakeSplitKOperand( for (const HloInstruction* param : analysis.ScopeParameters(scope)) { // If an operand of dot does not read any parameters its K dimension // does not need analysis for fragmentation. - const DimIterationSpec* spec = + const DotFusionAnalysis::DimIterationSpec* spec = analysis.IterSpec(scope, param, contracting_dim_idx); // Split contracting dimension is not implemented yet. CHECK_EQ(spec->size(), 1); @@ -1045,8 +885,8 @@ DotFusionAnalysis::DotFusionAnalysis(const HloComputation* dot_computation, absl::flat_hash_map dim_orders; DimensionOrder dot_operand_dim_order = DimensionOrder::FromDotOperand(*dot, operand_number, split_k); - CHECK(dot_operand_dim_order.HandleInstruction(dot_operand)); - CHECK(RequireTritonGemmSupportedDimOrder(dot_operand_dim_order)) + TF_CHECK_OK(dot_operand_dim_order.HandleInstruction(dot_operand)); + TF_CHECK_OK(RequireTritonGemmSupportedDimOrder(dot_operand_dim_order)) << dot_computation->ToString(); dim_orders.insert({dot_operand, dot_operand_dim_order}); visited.insert(dot_operand); @@ -1067,18 +907,14 @@ DotFusionAnalysis::DotFusionAnalysis(const HloComputation* dot_computation, {hlo_operand, DimensionOrder(dim_orders.at(hlo))}); CHECK(inserted); DimensionOrder& hlo_operand_dim_order = it->second; - CHECK(hlo_operand_dim_order.HandleInstruction(hlo_operand)); - CHECK(RequireTritonGemmSupportedDimOrder(hlo_operand_dim_order)) + TF_CHECK_OK(hlo_operand_dim_order.HandleInstruction(hlo_operand)); + TF_CHECK_OK(RequireTritonGemmSupportedDimOrder(hlo_operand_dim_order)) << " " << dot_computation->ToString(); to_process.push(hlo_operand); } } - // For now all parameters of one scope have to use the same tiling. for (const HloInstruction* parameter : parameters_[scope]) { - CHECK(dim_orders.at(parameter).IsPhysicallyEquivalent( - dim_orders.at(*parameters_[scope].cbegin()))) - << dot_computation->ToString(); iter_specs_[scope][parameter] = DimensionOrderToTensorIterationSpec(dim_orders.at(parameter)); } @@ -1090,22 +926,22 @@ DotFusionAnalysis::DotFusionAnalysis(const HloComputation* dot_computation, .second); } -const DimIterationSpec* DotFusionAnalysis::IterSpec( +const DotFusionAnalysis::DimIterationSpec* DotFusionAnalysis::IterSpec( const DotFusionAnalysis::Scope scope, const HloInstruction* hlo, const int dimension) const { auto ret = iter_specs_.at(scope).find(hlo); if (ret != iter_specs_.at(scope).end()) { - return &ret->second[dimension]; + return &ret->second.at(dimension); } return nullptr; } -FusionDecision CanTritonHandleGEMM(const HloInstruction& dot, - const GpuVersion gpu_version) { +bool CanTritonHandleGEMM(const HloInstruction& dot, + const GpuVersion gpu_version) { if (dot.opcode() != HloOpcode::kDot || absl::c_any_of(dot.precision_config().operand_precision(), [](int x) { return x != PrecisionConfig::DEFAULT; })) { - return "Non-default precision."; + return false; } auto supported_output_type = [&](const PrimitiveType t) { @@ -1125,21 +961,21 @@ FusionDecision CanTritonHandleGEMM(const HloInstruction& dot, // TODO(b/266862493): Support more output types. if (!supported_output_type(dot.shape().element_type())) { - return "Unsupported output data type."; + return false; } if (!IsSupportedDataType(dot.operand(0)->shape().element_type(), gpu_version) || !IsSupportedDataType(dot.operand(1)->shape().element_type(), gpu_version)) { - return "Unsupported input data type."; + return false; } const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers(); // TODO(b/269580541): support multiple batch dimensions. if (dim_numbers.lhs_batch_dimensions().size() > 1) { - return "Multiple batch dimensions."; + return false; } // Cases where lhs or rhs have no non-contracting dims are not handled. @@ -1149,10 +985,10 @@ FusionDecision CanTritonHandleGEMM(const HloInstruction& dot, dim_numbers.rhs_batch_dimensions().size() + dim_numbers.rhs_contracting_dimensions().size() == dot.operand(1)->shape().rank()) { - return "No non-contracting dimensions."; + return false; } - return FusionDecision{}; + return true; } bool ShouldTritonHandleGEMM(const HloInstruction& dot, @@ -1172,7 +1008,7 @@ bool ShouldTritonHandleGEMM(const HloInstruction& dot, while (!queue.empty()) { const HloInstruction* current = queue.front(); queue.pop(); - if (!CanFuse(*current, dim_order, gpu_version)) { + if (!CanFuse(current, dim_order, gpu_version).ok()) { continue; } // Stop as soon as a profitable operation is fused. diff --git a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.h b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.h index 0afc939b43ede2..715c79d9114659 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.h +++ b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.h @@ -29,7 +29,6 @@ limitations under the License. #include "tensorflow/compiler/xla/hlo/ir/hlo_opcode.h" #include "tensorflow/compiler/xla/service/gpu/gpu_types.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" -#include "tensorflow/compiler/xla/service/instruction_fusion.h" namespace xla { namespace gpu { @@ -53,13 +52,13 @@ Status MakeDotSplitKBatch(HloInstruction* dot_fusion, const AutotuneResult::TritonGemmKey& tiling); // Filters GEMMs which can be handled using Triton. -FusionDecision CanTritonHandleGEMM(const HloInstruction&, - GpuVersion gpu_version); +bool CanTritonHandleGEMM(const HloInstruction&, GpuVersion gpu_version); // Filters GEMMs which are better to handle using Triton. bool ShouldTritonHandleGEMM(const HloInstruction&, GpuVersion gpu_version); -class TensorIterationSpec { +// Analysis of iteration of HLO shapes within a fusion around dot(). +class DotFusionAnalysis { public: // Description of basic iteration: `count` elements separated by `stride`. struct IterationSpecFragment { @@ -69,42 +68,16 @@ class TensorIterationSpec { // of several HLO dimensions. Product of subfragments equals `count`. std::vector subfragments; }; + // Description of complex iteration over a sequence of several strides. // Describes a logically contiguous dimension of a tensor physically // separated into multiple fragments by other dimensions. using DimIterationSpec = std::vector; // At most: contracting, non-contracting, split-K, another batch. - static constexpr int kMaxDimsPerTensor = 4; - using StorageType = std::array; - - const DimIterationSpec& operator[](int dimension) const { - return dim_iteration_specs_[dimension]; - } - - DimIterationSpec& operator[](int dimension) { - return dim_iteration_specs_[dimension]; - } - - // Compares physical layouts of tensors ignoring subfragments of dimensions. - bool operator==(const TensorIterationSpec& other) const; - - StorageType::iterator begin() { return dim_iteration_specs_.begin(); } - StorageType::iterator end() { return dim_iteration_specs_.end(); } - StorageType::const_iterator cbegin() const { - return dim_iteration_specs_.cbegin(); - } - StorageType::const_iterator cend() const { - return dim_iteration_specs_.cend(); - } - - private: - StorageType dim_iteration_specs_; -}; + static const int kMaxDimsPerTensor = 4; + using TensorIterationSpec = std::array; -// Analysis of iteration of HLO shapes within a fusion around dot(). -class DotFusionAnalysis { - public: // Execute analysis of dot fusion computation. // split_k indicates whether this operation was converted to the split-K // form and tells the analysis how to interpret the batch dimensions. @@ -115,15 +88,9 @@ class DotFusionAnalysis { // defined by left operand, right operand and output. enum class Scope { LHS = 0, RHS = 1, OUTPUT = 2 }; - // Every parameter requires a separate piece of shared memory for asynchronous - // loads. Multiple parameters are approximately equivalent to multiple - // pipeline stages. - static constexpr int kMaxParameterPerScope = 4; - // Scope -> HLO -> dot dimension number -> iteration spec at the HLO's output. - const TensorIterationSpec::DimIterationSpec* IterSpec(Scope scope, - const HloInstruction*, - int dimension) const; + const DimIterationSpec* IterSpec(Scope scope, const HloInstruction*, + int dimension) const; // Parameter HLO instructions used in a scope of `dot`. const absl::flat_hash_set& ScopeParameters( const Scope scope) const { diff --git a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton_test.cc b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton_test.cc index 95eaf51915d2e5..d02faa5b3abdc9 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton_test.cc @@ -94,7 +94,7 @@ ENTRY e { GmockMatch(m::Fusion(m::Parameter(), m::Parameter()))); } -TEST_F(GemmRewriterTritonTest, DoNotFuseConstants) { +TEST_F(GemmRewriterTritonTest, DoNotFuseConstant) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(R"( HloModule m @@ -102,14 +102,14 @@ HloModule m ENTRY e { p0 = s8[60,5] parameter(0) c0 = f16[60,5] convert(p0) - cst1 = f16[] constant(1234) - r1 = f16[5,120] broadcast(cst1) + cst1 = f16[600] constant({...}) + r1 = f16[5,120] reshape(cst1) ROOT d = f16[60,120] dot(c0, r1), lhs_contracting_dims={1}, rhs_contracting_dims={0} })")); EXPECT_TRUE(GemmRewriterTriton(gpu_version_).Run(module.get()).value()); EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch(m::Fusion(m::Parameter(), m::Broadcast()))); + GmockMatch(m::Fusion(m::Constant(), m::Parameter()))); } using TritonDotAnalysisTest = HloTestBase; @@ -793,154 +793,6 @@ ENTRY e { EXPECT_TRUE(GemmRewriterTriton(cc).Run(module.get()).value()); } -class GemmRewriterTritonLevel2Test : public GemmRewriterTritonTest { - public: - DebugOptions GetDebugOptionsForTest() override { - DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); - debug_options.set_xla_gpu_triton_fusion_level(2); - return debug_options; - } -}; - -TEST_F(GemmRewriterTritonLevel2Test, DoNotFuseIncompatibleDimOrders) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( -HloModule m - -ENTRY e { - p0 = f16[5,3] parameter(0) - p1 = f16[5,7] parameter(1) - p2 = f16[7,5] parameter(2) - t = f16[5,7] transpose(p2), dimensions={1,0} - a = f16[5,7] add(t, p1) - ROOT d = f16[3,7] dot(p0, a), - lhs_contracting_dims={0}, rhs_contracting_dims={0} -})")); - - EXPECT_TRUE(GemmRewriterTriton(gpu_version_).Run(module.get()).value()); - EXPECT_THAT( - module->entry_computation()->root_instruction(), - GmockMatch(m::Fusion(m::Parameter(), m::Parameter(), m::Transpose()))); -} - -TEST_F(GemmRewriterTritonLevel2Test, DoNotFuseTooManyParameters) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( -ENTRY e { - tmp_0 = f32[] constant(1) - tmp_1 = f32[3,49]{1,0} broadcast(tmp_0), dimensions={} - tmp_2 = f32[3,49]{1,0} parameter(6) - tmp_3 = f32[] constant(0) - tmp_4 = f32[3,49]{1,0} broadcast(tmp_3), dimensions={} - tmp_5 = pred[3,49]{1,0} compare(tmp_2, tmp_4), direction=GT - tmp_6 = f32[3,49]{1,0} convert(tmp_5) - tmp_7 = f32[3,49]{1,0} subtract(tmp_1, tmp_6) - tmp_8 = s32[] parameter(13) - tmp_9 = f32[] convert(tmp_8) - tmp_10 = f32[] maximum(tmp_9, tmp_0) - tmp_11 = f32[] divide(tmp_3, tmp_10) - tmp_12 = f32[3,49]{1,0} broadcast(tmp_11), dimensions={} - tmp_13 = pred[3,49]{1,0} parameter(7) - tmp_14 = pred[3,49]{1,0} parameter(10) - tmp_15 = pred[3,49]{1,0} and(tmp_13, tmp_14) - tmp_16 = f32[3,49]{1,0} convert(tmp_15) - tmp_17 = f32[3,49]{1,0} multiply(tmp_12, tmp_16) - tmp_18 = f32[3,49]{1,0} negate(tmp_17) - tmp_19 = f32[3,49]{1,0} multiply(tmp_7, tmp_18) - tmp_20 = f32[3,49]{1,0} parameter(19) - tmp_21 = f32[3,49]{1,0} subtract(tmp_1, tmp_20) - tmp_22 = f32[3,49]{1,0} divide(tmp_19, tmp_21) - tmp_23 = f32[3,49]{1,0} negate(tmp_22) - tmp_24 = f32[3,49]{1,0} negate(tmp_6) - tmp_25 = f32[3,49]{1,0} multiply(tmp_24, tmp_17) - tmp_26 = f32[3,49]{1,0} divide(tmp_25, tmp_20) - tmp_27 = f32[3,49]{1,0} add(tmp_23, tmp_26) - tmp_28 = f32[3,49]{1,0} parameter(18) - tmp_29 = f32[3,49]{1,0} multiply(tmp_27, tmp_28) - tmp_30 = f32[3,49]{1,0} parameter(17) - tmp_31 = f32[3,49]{1,0} multiply(tmp_29, tmp_30) - tmp_32 = f32[3,49]{1,0} parameter(16) - tmp_33 = f32[3,49]{1,0} multiply(tmp_31, tmp_32) - tmp_34 = f32[3,49]{1,0} parameter(15) - tmp_35 = f32[3,49]{1,0} add(tmp_33, tmp_34) - tmp_36 = f32[3,49]{1,0} parameter(14) - tmp_37 = f32[3,49]{1,0} add(tmp_35, tmp_36) - tmp_38 = f32[1,1]{1,0} constant({ {0} }) - tmp_39 = f32[1,1]{1,0} broadcast(tmp_38), dimensions={0,1} - tmp_40 = f32[] reshape(tmp_39) - tmp_41 = f32[3,32]{1,0} broadcast(tmp_40), dimensions={} - tmp_42 = u32[48]{0} parameter(11) - tmp_43 = u32[48]{0} parameter(5) - tmp_44 = u32[96]{0} concatenate(tmp_42, tmp_43), dimensions={0} - tmp_45 = u32[3,32]{1,0} reshape(tmp_44) - tmp_46 = u32[96]{0} reshape(tmp_45) - tmp_47 = u32[] constant(1) - tmp_48 = u32[3,32]{1,0} broadcast(tmp_47), dimensions={} - tmp_49 = u32[96]{0} reshape(tmp_48) - tmp_50 = u32[96]{0} shift-right-logical(tmp_46, tmp_49) - tmp_51 = u32[3,32]{1,0} reshape(tmp_50) - tmp_52 = u32[3,32]{1,0} or(tmp_51, tmp_48) - tmp_53 = f32[3,32]{1,0} bitcast-convert(tmp_52) - tmp_54 = f32[3,32]{1,0} broadcast(tmp_0), dimensions={} - tmp_55 = f32[3,32]{1,0} subtract(tmp_53, tmp_54) - tmp_56 = f32[1,1]{1,0} constant({ {1} }) - tmp_57 = f32[1,1]{1,0} broadcast(tmp_56), dimensions={0,1} - tmp_58 = f32[] reshape(tmp_57) - tmp_59 = f32[3,32]{1,0} broadcast(tmp_58), dimensions={} - tmp_60 = f32[3,32]{1,0} multiply(tmp_55, tmp_59) - tmp_61 = f32[3,32]{1,0} add(tmp_60, tmp_41) - tmp_62 = f32[3,32]{1,0} maximum(tmp_41, tmp_61) - tmp_63 = f32[3,32]{1,0} broadcast(tmp_3), dimensions={} - tmp_64 = pred[3,32]{1,0} compare(tmp_62, tmp_63), direction=LT - tmp_65 = f32[3,32]{1,0} convert(tmp_64) - tmp_66 = f32[3,49]{1,0} parameter(9) - tmp_67 = f32[49]{0} parameter(4) - tmp_68 = f32[3,49]{1,0} broadcast(tmp_67), dimensions={1} - tmp_69 = f32[3,49]{1,0} add(tmp_66, tmp_68) - tmp_70 = f32[1,49]{1,0} parameter(12) - tmp_71 = f32[1,49]{1,0} broadcast(tmp_0), dimensions={} - tmp_72 = f32[1,49]{1,0} divide(tmp_70, tmp_71) - tmp_73 = f32[1,49]{1,0} broadcast(tmp_72), dimensions={0,1} - tmp_74 = f32[49]{0} reshape(tmp_73) - tmp_75 = f32[3,49]{1,0} broadcast(tmp_74), dimensions={1} - tmp_76 = f32[3,49]{1,0} subtract(tmp_69, tmp_75) - tmp_77 = f32[1,49]{1,0} parameter(3) - tmp_78 = f32[1,49]{1,0} parameter(8) - tmp_79 = f32[1,49]{1,0} divide(tmp_78, tmp_71) - tmp_80 = f32[1,49]{1,0} multiply(tmp_72, tmp_72) - tmp_81 = f32[1,49]{1,0} subtract(tmp_79, tmp_80) - tmp_82 = f32[1,49]{1,0} add(tmp_81, tmp_71) - tmp_83 = f32[1,49]{1,0} rsqrt(tmp_82) - tmp_84 = f32[1,49]{1,0} multiply(tmp_77, tmp_83) - tmp_85 = f32[1,49]{1,0} broadcast(tmp_84), dimensions={0,1} - tmp_86 = f32[49]{0} reshape(tmp_85) - tmp_87 = f32[3,49]{1,0} broadcast(tmp_86), dimensions={1} - tmp_88 = f32[3,49]{1,0} multiply(tmp_76, tmp_87) - tmp_89 = f32[1,49]{1,0} parameter(2) - tmp_90 = f32[1,49]{1,0} broadcast(tmp_89), dimensions={0,1} - tmp_91 = f32[49]{0} reshape(tmp_90) - tmp_92 = f32[3,49]{1,0} broadcast(tmp_91), dimensions={1} - tmp_93 = f32[3,49]{1,0} add(tmp_88, tmp_92) - tmp_94 = f32[49,32]{1,0} parameter(1) - tmp_95 = f32[3,32]{1,0} dot(tmp_93, tmp_94), lhs_contracting_dims={1}, rhs_contracting_dims={0} - tmp_96 = f32[32]{0} parameter(0) - tmp_97 = f32[3,32]{1,0} broadcast(tmp_96), dimensions={1} - tmp_98 = f32[3,32]{1,0} add(tmp_95, tmp_97) - tmp_99 = f32[3,32]{1,0} multiply(tmp_65, tmp_98) - tmp_100 = f32[3,32]{1,0} divide(tmp_99, tmp_63) - tmp_101 = f32[3,32]{1,0} maximum(tmp_100, tmp_63) - ROOT tmp_102 = f32[49,32]{1,0} dot(tmp_37, tmp_101), lhs_contracting_dims={0}, rhs_contracting_dims={0} -})")); - - EXPECT_TRUE(GemmRewriterTriton(gpu_version_).Run(module.get()).value()); - EXPECT_EQ(module->entry_computation()->root_instruction()->opcode(), - HloOpcode::kFusion); - EXPECT_EQ(module->entry_computation()->root_instruction()->fusion_kind(), - HloInstruction::FusionKind::kCustom); - EXPECT_LE(module->entry_computation()->root_instruction()->operand_count(), - DotFusionAnalysis::kMaxParameterPerScope * 2); -} - } // namespace } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index b3944952ac68da..f490f9b127e21a 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -973,29 +973,6 @@ Status GpuCompiler::OptimizeHloPostLayoutAssignment( }); } - GpuFloatSupport bf16_support(BF16); - GpuFloatSupport f8e5m2_support(F8E5M2); - GpuFloatSupport f8e4m3fn_support(F8E4M3FN); - FloatSupport f8e4m3b11fnuz_support(F8E4M3B11FNUZ); - FloatSupport f8e5m2fnuz_support(F8E5M2FNUZ); - FloatSupport f8e4m3fnuz_support(F8E4M3FNUZ); - - auto add_float_normalization = [&](HloPassPipeline& pipeline) { - auto& sub_pipeline = - pipeline.AddPass("float_normalization"); - sub_pipeline.AddPass(&bf16_support); - sub_pipeline.AddPass(&f8e5m2_support); - sub_pipeline.AddPass(&f8e4m3fn_support); - sub_pipeline.AddPass(&f8e4m3b11fnuz_support); - sub_pipeline.AddPass(&f8e5m2fnuz_support); - sub_pipeline.AddPass(&f8e4m3fnuz_support); - // Remove `f32 -> bf16 -> f32` casts inserted by bf16 normalization. - if (debug_options.xla_gpu_simplify_all_fp_conversions()) { - sub_pipeline.AddPass(); - } - }; - add_float_normalization(pipeline); - // By default use an externally provided thread pool. tsl::thread::ThreadPool* thread_pool = options.thread_pool; std::optional overriding_thread_pool; @@ -1017,8 +994,18 @@ Status GpuCompiler::OptimizeHloPostLayoutAssignment( &pipeline, hlo_module, stream_exec, debug_options, options, gpu_target_config, autotune_results, thread_pool)); - // The Triton autotuner can insert new reductions. - add_float_normalization(pipeline); + GpuFloatSupport bf16_support(BF16); + pipeline.AddPass(&bf16_support); + GpuFloatSupport f8e5m2_support(F8E5M2); + pipeline.AddPass(&f8e5m2_support); + GpuFloatSupport f8e4m3fn_support(F8E4M3FN); + pipeline.AddPass(&f8e4m3fn_support); + FloatSupport f8e4m3b11fnuz_support(F8E4M3B11FNUZ); + pipeline.AddPass(&f8e4m3b11fnuz_support); + FloatSupport f8e5m2fnuz_support(F8E5M2FNUZ); + pipeline.AddPass(&f8e5m2fnuz_support); + FloatSupport f8e4m3fnuz_support(F8E4M3FNUZ); + pipeline.AddPass(&f8e4m3fnuz_support); // Remove `f32 -> bf16 -> f32` casts inserted by bf16 normalization. if (debug_options.xla_gpu_simplify_all_fp_conversions()) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_triton.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_triton.cc index 7c9cd87953a848..709f3e40b52c3f 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_triton.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_triton.cc @@ -792,7 +792,7 @@ StatusOr MatMulImpl( if (!analysis.ScopeParameters(DotFusionAnalysis::Scope::LHS).empty()) { const HloInstruction* lhs_param0 = *analysis.ScopeParameters(DotFusionAnalysis::Scope::LHS).begin(); - const TensorIterationSpec::DimIterationSpec* lhs_nc_iter_spec = + const DotFusionAnalysis::DimIterationSpec* lhs_nc_iter_spec = analysis.IterSpec(DotFusionAnalysis::Scope::LHS, lhs_param0, lhs_noncontracting_dim_idx); lhs_nc_split = lhs_nc_iter_spec->size() > 1; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_triton_test.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_triton_test.cc index e87b973b4a60c8..fc4bb7204c1632 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_triton_test.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_triton_test.cc @@ -25,14 +25,11 @@ limitations under the License. #include "tensorflow/compiler/xla/autotuning.pb.h" #include "tensorflow/compiler/xla/error_spec.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_computation.h" -#include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/gpu/gpu_device_info_for_tests.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h" #include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" -#include "tensorflow/compiler/xla/service/pattern_matcher.h" -#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h" #include "tensorflow/compiler/xla/stream_executor/device_description.h" #include "tensorflow/compiler/xla/tests/verified_hlo_module.h" #include "tensorflow/tsl/lib/core/status_test_util.h" @@ -45,8 +42,6 @@ namespace xla { namespace gpu { namespace { -namespace m = ::xla::match; - class TritonGemmNoTF32Test : public GpuCodegenTest { public: void SetUp() override { @@ -720,162 +715,6 @@ ENTRY e { EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-6, /*arel=*/1e-6})); } -class TritonGemmLevel2Test : public TritonGemmTest { - public: - DebugOptions GetDebugOptionsForTest() override { - DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); - debug_options.set_xla_gpu_triton_fusion_level(2); - return debug_options; - } -}; - -TEST_F(TritonGemmLevel2Test, BinaryOperationWithSmallInputsIsFused) { - const std::string kHloText = R"( -HloModule m - -ENTRY e { - p0 = s8[7,3] parameter(0) - p1 = f32[3,16] parameter(1) - p2 = f32[3,16] parameter(2) - e = f32[3,16] exponential(p1) - a = f32[3,16] add(e, p2) - c = f32[7,3] convert(p0) - ROOT d = f32[7,16] dot(c, a), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -})"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - GetOptimizedModule(kHloText)); - - EXPECT_THAT( - module->entry_computation()->root_instruction(), - GmockMatch(m::Fusion(m::Parameter(), m::Parameter(), m::Parameter()) - .WithFusionKind(HloInstruction::FusionKind::kCustom))); - - EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-1, /*arel=*/1e-3})); -} - -TEST_F(TritonGemmLevel2Test, BinaryOperationWithLargeInputsIsNotFused) { - const std::string kHloText = R"( -HloModule m - -ENTRY e { - p0 = f16[333,1000] parameter(0) - p1 = f32[1000,333] parameter(1) - p1n = f32[1000,333] negate(p1) - p2 = f32[1000,333] parameter(2) - p2n = f32[1000,333] negate(p2) - s = f32[1000,333] subtract(p1n, p2n) - c = f32[333,1000] convert(p0) - ROOT d = f32[1000,1000] dot(s, c), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -})"; - - MatchOptimizedHlo(kHloText, R"( -; CHECK: fused_computation -; CHECK: negate -; CHECK: negate -; CHECK: ROOT -; CHECK-SAME: subtract -; CHECK: ENTRY -; CHECK: kLoop -; CHECK: kCustom -)"); - - EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-1, /*arel=*/1e-3})); -} - -TEST_F(TritonGemmLevel2Test, BinaryOperationOnLargeParametersIsFused) { - const std::string kHloText = R"( -HloModule m - -ENTRY e { - p0 = f16[1000,111] parameter(0) - p1 = f32[111,10000] parameter(1) - p2 = f32[111,10000] parameter(2) - s = f32[111,10000] subtract(p1, p2) - c = f32[1000,111] convert(p0) - ROOT d = f32[10000,1000] dot(s, c), - lhs_contracting_dims={0}, rhs_contracting_dims={1} -})"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - GetOptimizedModule(kHloText)); - - EXPECT_THAT( - module->entry_computation()->root_instruction(), - GmockMatch(m::Fusion(m::Parameter(), m::Parameter(), m::Parameter()) - .WithFusionKind(HloInstruction::FusionKind::kCustom))); - - EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-1, /*arel=*/1e-3})); -} - -TEST_F(TritonGemmLevel2Test, LinkingLibdeviceTwiceWorks) { - const std::string kHloText = R"( -HloModule m - -ENTRY e { - p0 = s8[7,3] parameter(0) - c0 = f32[7,3] convert(p0) - e0 = f32[7,3] exponential(c0) - p1 = f32[3,16] parameter(1) - e1 = f32[3,16] exponential(p1) - d0 = f32[7,16] dot(c0, e1), - lhs_contracting_dims={1}, rhs_contracting_dims={0} - d1 = f32[7,16] dot(e0, p1), - lhs_contracting_dims={1}, rhs_contracting_dims={0} - ROOT a = f32[7,16] add(d0, d1) -})"; - - MatchOptimizedHlo(kHloText, R"( -; CHECK: ENTRY -; CHECK-NEXT: parameter -; CHECK-NEXT: parameter -; CHECK-NEXT: kCustom -; CHECK-NEXT: kCustom -; CHECK-NEXT: ROOT -; CHECK-SAME: add -)"); - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - GetOptimizedModule(kHloText)); - - EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch(m::Add( - m::Fusion(m::Parameter(), m::Parameter()) - .WithFusionKind(HloInstruction::FusionKind::kCustom), - m::Fusion(m::Parameter(), m::Parameter()) - .WithFusionKind(HloInstruction::FusionKind::kCustom)))); - - EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2})); -} - -TEST_F(TritonGemmLevel2Test, BroadcastOfConstantIsNotFused) { - const std::string kHloText = R"( -HloModule m - -ENTRY e { - p0 = f16[70,30] parameter(0) - p0c = f32[70,30] convert(p0) - constant_3663 = f32[] constant(4321) - bc0 = f32[30,5] broadcast(constant_3663) - p1 = f32[30,5] parameter(1) - a = f32[30,5] add(p1, bc0) - ROOT d = f32[70,5] dot(p0c, a), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -})"; - - MatchOptimizedHlo(kHloText, R"( -; CHECK: ENTRY -; CHECK: constant -; CHECK: broadcast -; CHECK: fusion -; CHECK-SAME: kind=kCustom -)"); - - EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/2e-3, /*arel=*/2e-3})); -} - TEST_F(TritonGemmTest, Naming) { const char* hlo_text = R"( HloModule t diff --git a/tensorflow/compiler/xla/service/gpu/triton_autotuner.cc b/tensorflow/compiler/xla/service/gpu/triton_autotuner.cc index b8b8b5f6719931..440a9611a8fe27 100644 --- a/tensorflow/compiler/xla/service/gpu/triton_autotuner.cc +++ b/tensorflow/compiler/xla/service/gpu/triton_autotuner.cc @@ -418,11 +418,12 @@ std::vector GetExhaustiveMatmulAutotuneConfigs( std::vector GetFixedMatmulAutotuneConfigs( const se::CudaComputeCapability compute_capability) { std::vector configs = { - GemmKey(32, 32, 256, 1, 1, 4), GemmKey(64, 32, 32, 16, 1, 4), - GemmKey(32, 64, 64, 4, 1, 4), GemmKey(16, 16, 256, 1, 1, 4), - GemmKey(16, 128, 32, 16, 1, 4), GemmKey(16, 64, 128, 1, 1, 4), - GemmKey(16, 128, 32, 8, 1, 4), GemmKey(16, 16, 512, 1, 1, 4), - GemmKey(32, 16, 512, 1, 1, 4), GemmKey(64, 32, 64, 1, 2, 8)}; + GemmKey(32, 32, 256, 1, 1, 4), GemmKey(64, 32, 32, 16, 1, 4), + GemmKey(32, 64, 64, 4, 1, 4), GemmKey(128, 128, 64, 4, 1, 4), + GemmKey(16, 16, 256, 1, 1, 4), GemmKey(16, 128, 32, 16, 1, 4), + GemmKey(16, 64, 128, 1, 1, 4), GemmKey(16, 128, 32, 8, 1, 4), + GemmKey(16, 16, 512, 1, 1, 4), GemmKey(32, 16, 512, 1, 1, 4), + GemmKey(64, 32, 64, 1, 2, 8)}; if (compute_capability.IsAtLeast(se::CudaComputeCapability::AMPERE)) { absl::c_copy( std::vector{ diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index f41923aeae68b6..3eb4ae20db045d 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -570,9 +570,7 @@ message DebugOptions { bool xla_gpu_triton_gemm_disable_reduced_precision_reduction = 226; - int32 xla_gpu_triton_fusion_level = 229; - - // Next id: 230 + // Next id: 229 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend. From 6df4c01ac11ed54688c8d54dec0a1e385b8e1ef7 Mon Sep 17 00:00:00 2001 From: Fiona Lang Date: Tue, 11 Jul 2023 16:14:13 -0700 Subject: [PATCH 158/376] Update ops.Tensor references to //third_party/tensorflow/python/framework/tensor.py. PiperOrigin-RevId: 547326089 --- .../data/experimental/kernel_tests/BUILD | 3 +- .../make_batched_features_dataset_test.py | 6 +- .../parse_example_dataset_test.py | 7 ++- tensorflow/python/data/experimental/ops/BUILD | 4 +- .../data/experimental/ops/data_service_ops.py | 11 ++-- .../data/experimental/ops/lookup_ops.py | 10 ++-- tensorflow/python/data/kernel_tests/BUILD | 3 +- .../data/kernel_tests/checkpoint_test_base.py | 3 +- .../python/data/kernel_tests/iterator_test.py | 19 +++--- tensorflow/python/data/ops/BUILD | 3 +- tensorflow/python/data/ops/iterator_ops.py | 8 +-- tensorflow/python/data/ops/ragged_batch_op.py | 9 ++- .../data/ops/sample_from_datasets_op.py | 5 +- tensorflow/python/data/util/BUILD | 9 ++- tensorflow/python/data/util/sparse.py | 4 +- tensorflow/python/data/util/sparse_test.py | 49 +++++++-------- tensorflow/python/data/util/structure.py | 8 +-- tensorflow/python/data/util/structure_test.py | 59 +++++++++---------- 18 files changed, 114 insertions(+), 106 deletions(-) diff --git a/tensorflow/python/data/experimental/kernel_tests/BUILD b/tensorflow/python/data/experimental/kernel_tests/BUILD index d58641c73509fe..f659a77f5914fa 100644 --- a/tensorflow/python/data/experimental/kernel_tests/BUILD +++ b/tensorflow/python/data/experimental/kernel_tests/BUILD @@ -313,7 +313,7 @@ tf_py_strict_test( "//tensorflow/python/framework:combinations", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:errors", - "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:io_ops", "//tensorflow/python/ops:parsing_ops", "//tensorflow/python/platform:client_testlib", @@ -537,6 +537,7 @@ tf_py_strict_test( "//tensorflow/python/framework:errors", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:parsing_ops", "//tensorflow/python/ops/ragged:ragged_factory_ops", "//tensorflow/python/platform:client_testlib", diff --git a/tensorflow/python/data/experimental/kernel_tests/make_batched_features_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/make_batched_features_dataset_test.py index 6d7b26ce88c10a..bfba1b69547d5d 100644 --- a/tensorflow/python/data/experimental/kernel_tests/make_batched_features_dataset_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/make_batched_features_dataset_test.py @@ -25,7 +25,7 @@ from tensorflow.python.framework import combinations from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors -from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.ops import io_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.platform import test @@ -214,7 +214,7 @@ def testDropFinalBatch(self, batch_size, num_epochs): batch_size=batch_size, drop_final_batch=True) for tensor in nest.flatten(outputs): - if isinstance(tensor, ops.Tensor): # Guard against SparseTensor. + if isinstance(tensor, tensor_lib.Tensor): # Guard against SparseTensor. self.assertEqual(tensor.shape[0], batch_size) @combinations.generate(test_base.default_test_combinations()) @@ -227,7 +227,7 @@ def testIndefiniteRepeatShapeInference(self): for shape, clazz in zip( nest.flatten(dataset_ops.get_legacy_output_shapes(dataset)), nest.flatten(dataset_ops.get_legacy_output_classes(dataset))): - if issubclass(clazz, ops.Tensor): + if issubclass(clazz, tensor_lib.Tensor): self.assertEqual(32, shape[0]) @combinations.generate(test_base.default_test_combinations()) diff --git a/tensorflow/python/data/experimental/kernel_tests/parse_example_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/parse_example_dataset_test.py index 373aebc2d3e231..28e9c1379b9c73 100644 --- a/tensorflow/python/data/experimental/kernel_tests/parse_example_dataset_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/parse_example_dataset_test.py @@ -33,6 +33,7 @@ from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor from tensorflow.python.ops import parsing_ops from tensorflow.python.ops.ragged import ragged_factory_ops from tensorflow.python.platform import test @@ -96,8 +97,10 @@ def _test(self, # Check shapes; if serialized is a Tensor we need its size to # properly check. batch_size = ( - self.evaluate(input_tensor).size if isinstance(input_tensor, ops.Tensor) - else np.asarray(input_tensor).size) + self.evaluate(input_tensor).size + if isinstance(input_tensor, tensor.Tensor) + else np.asarray(input_tensor).size + ) for k, f in feature_val.items(): if isinstance(f, parsing_ops.FixedLenFeature) and f.shape is not None: self.assertEqual( diff --git a/tensorflow/python/data/experimental/ops/BUILD b/tensorflow/python/data/experimental/ops/BUILD index d2018a755e031f..9f96277f5dcd24 100644 --- a/tensorflow/python/data/experimental/ops/BUILD +++ b/tensorflow/python/data/experimental/ops/BUILD @@ -81,7 +81,7 @@ py_strict_library( "//tensorflow/python/eager:context", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", - "//tensorflow/python/framework:tensor_spec", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/ops:experimental_dataset_ops_gen", "//tensorflow/python/ops:string_ops", @@ -243,7 +243,7 @@ py_strict_library( ":cardinality", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", - "//tensorflow/python/framework:tensor_spec", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:experimental_dataset_ops_gen", "//tensorflow/python/ops:lookup_ops", "//tensorflow/python/ops:math_ops", diff --git a/tensorflow/python/data/experimental/ops/data_service_ops.py b/tensorflow/python/data/experimental/ops/data_service_ops.py index d2732ed973944c..3a003314ec49d5 100644 --- a/tensorflow/python/data/experimental/ops/data_service_ops.py +++ b/tensorflow/python/data/experimental/ops/data_service_ops.py @@ -30,7 +30,7 @@ from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_spec +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_util from tensorflow.python.ops import gen_experimental_dataset_ops from tensorflow.python.ops import string_ops @@ -200,7 +200,7 @@ def _get_compression_proto(compression): def _to_tensor(dataset_id): """Converts `dataset_id` to Tensor.""" - if isinstance(dataset_id, ops.Tensor): + if isinstance(dataset_id, tensor.Tensor): return dataset_id if isinstance(dataset_id, str) or isinstance(dataset_id, bytes): return ops.convert_to_tensor( @@ -212,7 +212,7 @@ def _to_tensor(dataset_id): def _to_string(dataset_id): """Converts `dataset_id` to string.""" - if isinstance(dataset_id, ops.Tensor): + if isinstance(dataset_id, tensor.Tensor): return (dataset_id if dataset_id.dtype == dtypes.string else string_ops.as_string(dataset_id)) return (dataset_id.decode() @@ -334,7 +334,7 @@ def __init__(self, uncompress_func = structured_function.StructuredFunctionWrapper( lambda x: compression_ops.uncompress(x, output_spec=element_spec), transformation_name="DataServiceDataset.uncompress()", - input_structure=tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant)) + input_structure=tensor.TensorSpec(shape=(), dtype=dtypes.variant)) cross_trainer_cache_options = ( cross_trainer_cache._to_proto().SerializeToString() if cross_trainer_cache else None) @@ -1004,7 +1004,8 @@ def _get_element_spec(): else: protocol, address = _parse_service(service) if job_name is not None: - if not isinstance(job_name, str) and not isinstance(job_name, ops.Tensor): + if not isinstance(job_name, str) and not isinstance( + job_name, tensor.Tensor): raise ValueError( "`job_name` must be a string or Tensor, but `job_name` was of type " f"{type(job_name)}. job_name={job_name}.") diff --git a/tensorflow/python/data/experimental/ops/lookup_ops.py b/tensorflow/python/data/experimental/ops/lookup_ops.py index 6fc0fef4761cf0..aef2902813eca1 100644 --- a/tensorflow/python/data/experimental/ops/lookup_ops.py +++ b/tensorflow/python/data/experimental/ops/lookup_ops.py @@ -17,7 +17,7 @@ from tensorflow.python.data.experimental.ops.cardinality import assert_cardinality from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_spec +from tensorflow.python.framework import tensor from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import math_ops @@ -35,10 +35,10 @@ def _check_table_initializer_element_spec(element_spec): f"{len(element_spec)} components instead of two " "(key, value) components. Full dataset element spec: " f"{element_spec}.") - if not isinstance(element_spec[0], tensor_spec.TensorSpec): + if not isinstance(element_spec[0], tensor.TensorSpec): raise ValueError(base_error + "However, the given dataset produces " f"non-Tensor keys of type {type(element_spec[0])}.") - if not isinstance(element_spec[1], tensor_spec.TensorSpec): + if not isinstance(element_spec[1], tensor.TensorSpec): raise ValueError(base_error + "However, the given dataset produces " f"non-Tensor values of type {type(element_spec[1])}.") if element_spec[0].shape.rank not in (None, 0): @@ -163,14 +163,14 @@ def table_from_dataset(dataset=None, if num_oov_buckets < 0: raise ValueError("`num_oov_buckets` must be greater than or equal to 0, " f"got {num_oov_buckets}.") - if (not isinstance(vocab_size, ops.Tensor) and vocab_size is not None and + if (not isinstance(vocab_size, tensor.Tensor) and vocab_size is not None and vocab_size < 1): raise ValueError(f"`vocab_size` must be greater than 0, got {vocab_size}.") if (not key_dtype.is_integer) and (dtypes.string != key_dtype.base_dtype): raise TypeError("`key_dtype` must be either an integer or string type, " f"but got {key_dtype}") if vocab_size is not None: - if isinstance(vocab_size, ops.Tensor): + if isinstance(vocab_size, tensor.Tensor): vocab_size = math_ops.cast(vocab_size, dtypes.int64) dataset = dataset.take(vocab_size) dataset = dataset.apply(assert_cardinality(vocab_size)) diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD index cfdf01e0e8a9e2..27c89b47d91777 100644 --- a/tensorflow/python/data/kernel_tests/BUILD +++ b/tensorflow/python/data/kernel_tests/BUILD @@ -172,6 +172,7 @@ py_strict_library( "//tensorflow/python/framework:errors", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:lookup_ops", "//tensorflow/python/ops:variables", "//tensorflow/python/ops/ragged:ragged_tensor_value", @@ -621,7 +622,7 @@ cuda_py_strict_test( "//tensorflow/python/framework:function", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", - "//tensorflow/python/framework:tensor_spec", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:data_flow_ops", diff --git a/tensorflow/python/data/kernel_tests/checkpoint_test_base.py b/tensorflow/python/data/kernel_tests/checkpoint_test_base.py index 5f42c9fc0053ff..21b9266c90b806 100644 --- a/tensorflow/python/data/kernel_tests/checkpoint_test_base.py +++ b/tensorflow/python/data/kernel_tests/checkpoint_test_base.py @@ -29,6 +29,7 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import variables from tensorflow.python.ops.ragged import ragged_tensor_value @@ -43,7 +44,7 @@ def remove_variants(get_next_op): """Remove variants from a nest structure, so sess.run will execute.""" def _remove_variant(x): - if isinstance(x, ops.Tensor) and x.dtype == dtypes.variant: + if isinstance(x, tensor.Tensor) and x.dtype == dtypes.variant: return () else: return x diff --git a/tensorflow/python/data/kernel_tests/iterator_test.py b/tensorflow/python/data/kernel_tests/iterator_test.py index e5cd37bb6e1db0..8d1e384e033683 100644 --- a/tensorflow/python/data/kernel_tests/iterator_test.py +++ b/tensorflow/python/data/kernel_tests/iterator_test.py @@ -35,7 +35,7 @@ from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor -from tensorflow.python.framework import tensor_spec +from tensorflow.python.framework import tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import data_flow_ops @@ -808,9 +808,8 @@ def testRepeatedGetNextWarning(self): combinations.times( test_base.default_test_combinations(), combinations.combine( - expected_element_structure=tensor_spec.TensorSpec([], - dtypes.float32), - expected_output_classes=ops.Tensor, + expected_element_structure=tensor.TensorSpec([], dtypes.float32), + expected_output_classes=tensor.Tensor, expected_output_types=dtypes.float32, expected_output_shapes=[[]]))) def testTensorIteratorStructure(self, expected_element_structure, @@ -872,13 +871,13 @@ def tf_value_fn(): combinations.combine( expected_element_structure={ "a": - tensor_spec.TensorSpec([], dtypes.float32), - "b": (tensor_spec.TensorSpec([1], dtypes.string), - tensor_spec.TensorSpec([], dtypes.string)) + tensor.TensorSpec([], dtypes.float32), + "b": (tensor.TensorSpec([1], dtypes.string), + tensor.TensorSpec([], dtypes.string)) }, expected_output_classes={ - "a": ops.Tensor, - "b": (ops.Tensor, ops.Tensor) + "a": tensor.Tensor, + "b": (tensor.Tensor, tensor.Tensor) }, expected_output_types={ "a": dtypes.float32, @@ -973,7 +972,7 @@ def finalize_fn(n): @def_function.function def fn(): - output_signature = tensor_spec.TensorSpec((), dtypes.int64) + output_signature = tensor.TensorSpec((), dtypes.int64) dataset = from_generator_op._GeneratorDataset(1, init_fn, next_fn, finalize_fn, output_signature) diff --git a/tensorflow/python/data/ops/BUILD b/tensorflow/python/data/ops/BUILD index b02e4d4985d959..0ca307c5b58c10 100644 --- a/tensorflow/python/data/ops/BUILD +++ b/tensorflow/python/data/ops/BUILD @@ -119,6 +119,7 @@ py_strict_library( "//tensorflow/python/framework:random_seed", "//tensorflow/python/framework:smart_cond", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:tensor_spec", "//tensorflow/python/framework:tensor_util", @@ -178,6 +179,7 @@ py_strict_library( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:errors", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_conversion", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:tensor_spec", @@ -192,7 +194,6 @@ py_strict_library( "//tensorflow/python/util:_pywrap_utils", "//tensorflow/python/util:compat", "//tensorflow/python/util:deprecation", - "//tensorflow/python/util:lazy_loader", "//tensorflow/python/util:nest", "//tensorflow/python/util:tf_export", "//third_party/py/numpy", diff --git a/tensorflow/python/data/ops/iterator_ops.py b/tensorflow/python/data/ops/iterator_ops.py index 5354345577e9c2..82defdbf3210ff 100644 --- a/tensorflow/python/data/ops/iterator_ops.py +++ b/tensorflow/python/data/ops/iterator_ops.py @@ -30,8 +30,8 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import type_spec from tensorflow.python.framework import type_utils from tensorflow.python.ops import gen_dataset_ops @@ -219,7 +219,7 @@ def from_structure(output_types, tensor_shape.as_shape, output_shapes) if output_classes is None: - output_classes = nest.map_structure(lambda _: ops.Tensor, output_types) + output_classes = nest.map_structure(lambda _: tensor.Tensor, output_types) nest.assert_same_structure(output_types, output_shapes) output_structure = structure.convert_legacy_structure( output_types, output_shapes, output_classes) @@ -293,7 +293,7 @@ def from_string_handle(string_handle, tensor_shape.as_shape, output_shapes) if output_classes is None: - output_classes = nest.map_structure(lambda _: ops.Tensor, output_types) + output_classes = nest.map_structure(lambda _: tensor.Tensor, output_types) nest.assert_same_structure(output_types, output_shapes) output_structure = structure.convert_legacy_structure( output_types, output_shapes, output_classes) @@ -930,7 +930,7 @@ def _serialize(self): @property def _component_specs(self): - return (tensor_spec.TensorSpec([], dtypes.resource),) + return (tensor.TensorSpec([], dtypes.resource),) def _to_components(self, value): return (value._iterator_resource,) # pylint: disable=protected-access diff --git a/tensorflow/python/data/ops/ragged_batch_op.py b/tensorflow/python/data/ops/ragged_batch_op.py index 02a147e796d375..bb886ca7bbd5ff 100644 --- a/tensorflow/python/data/ops/ragged_batch_op.py +++ b/tensorflow/python/data/ops/ragged_batch_op.py @@ -17,8 +17,7 @@ from tensorflow.python.data.ops import structured_function from tensorflow.python.data.util import nest from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_spec +from tensorflow.python.framework import tensor from tensorflow.python.ops.ragged import ragged_tensor @@ -56,7 +55,7 @@ def __init__(self, input_dataset, row_splits_dtype, name=None): # corresponding RaggedTensorSpec. def to_ragged_spec(spec): """Returns the new spec based on RaggedTensors.""" - if (not isinstance(spec, tensor_spec.TensorSpec) or + if (not isinstance(spec, tensor.TensorSpec) or spec.shape.rank is None or spec.shape.is_fully_defined()): return spec @@ -80,12 +79,12 @@ def to_ragged_spec(spec): # RaggedTensorSpec._from_tensor_list. def to_ragged_variant(value): """Re-encode Tensors as RaggedTensors.""" - if (not isinstance(value, ops.Tensor) or + if (not isinstance(value, tensor.Tensor) or value.shape.rank is None or value.shape.is_fully_defined()): return value else: - spec = to_ragged_spec(tensor_spec.TensorSpec.from_tensor(value)) + spec = to_ragged_spec(tensor.TensorSpec.from_tensor(value)) if spec._ragged_rank > 0: # pylint: disable=protected-access value = ragged_tensor.RaggedTensor.from_tensor( value, ragged_rank=spec._ragged_rank) # pylint: disable=protected-access diff --git a/tensorflow/python/data/ops/sample_from_datasets_op.py b/tensorflow/python/data/ops/sample_from_datasets_op.py index 38c53bfb072050..29fc0d627d1436 100644 --- a/tensorflow/python/data/ops/sample_from_datasets_op.py +++ b/tensorflow/python/data/ops/sample_from_datasets_op.py @@ -19,6 +19,7 @@ from tensorflow.python.data.ops import map_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_stateless_random_ops from tensorflow.python.ops import math_ops @@ -48,7 +49,7 @@ def _skip_datasets_with_zero_weight(datasets, weights): logits = [[1.0] * len(datasets)] else: - if isinstance(weights, ops.Tensor): + if isinstance(weights, tensor.Tensor): if not weights.shape.is_compatible_with([len(datasets)]): raise ValueError(f"Invalid `weights`. The shape of `weights` " f"should be compatible with `[len(datasets)]` " @@ -62,7 +63,7 @@ def _skip_datasets_with_zero_weight(datasets, weights): # Use the given `weights` as the probability of choosing the respective # input. - if not isinstance(weights, ops.Tensor): + if not isinstance(weights, tensor.Tensor): datasets, weights = _skip_datasets_with_zero_weight(datasets, weights) weights = ops.convert_to_tensor(weights, name="weights") if weights.dtype not in (dtypes.float32, dtypes.float64): diff --git a/tensorflow/python/data/util/BUILD b/tensorflow/python/data/util/BUILD index 126c2d12967616..965106bad275b1 100644 --- a/tensorflow/python/data/util/BUILD +++ b/tensorflow/python/data/util/BUILD @@ -43,8 +43,8 @@ py_strict_library( deps = [ ":nest", "//tensorflow/python/framework:dtypes", - "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/ops:sparse_ops", ], @@ -63,8 +63,8 @@ py_strict_test( "//tensorflow/python/framework:combinations", "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", - "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/platform:client_testlib", "@absl_py//absl/testing:parameterized", @@ -80,8 +80,8 @@ py_strict_library( "//tensorflow/python/framework:composite_tensor", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", - "//tensorflow/python/framework:tensor_spec", "//tensorflow/python/framework:type_spec", "//tensorflow/python/framework:type_spec_registry", "//tensorflow/python/ops:resource_variable_ops", @@ -110,10 +110,9 @@ py_strict_test( "//tensorflow/python/framework:combinations", "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", - "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", - "//tensorflow/python/framework:tensor_spec", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:tensor_array_ops", "//tensorflow/python/ops:variables", diff --git a/tensorflow/python/data/util/sparse.py b/tensorflow/python/data/util/sparse.py index f56a905058a0ff..1f1fc794b1ebae 100644 --- a/tensorflow/python/data/util/sparse.py +++ b/tensorflow/python/data/util/sparse.py @@ -15,8 +15,8 @@ """Python dataset sparse tensor utility functions.""" from tensorflow.python.data.util import nest from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import sparse_ops @@ -107,7 +107,7 @@ def get_classes(tensors): """ return nest.pack_sequence_as(tensors, [ sparse_tensor.SparseTensor - if isinstance(tensor, sparse_tensor.SparseTensor) else ops.Tensor + if isinstance(tensor, sparse_tensor.SparseTensor) else tensor_lib.Tensor for tensor in nest.flatten(tensors) ]) diff --git a/tensorflow/python/data/util/sparse_test.py b/tensorflow/python/data/util/sparse_test.py index eb1c592b975b20..c82067650d221d 100644 --- a/tensorflow/python/data/util/sparse_test.py +++ b/tensorflow/python/data/util/sparse_test.py @@ -24,8 +24,8 @@ from tensorflow.python.framework import combinations from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.platform import test @@ -39,11 +39,11 @@ def _test_any_sparse_combinations(): cases = [("TestCase_0", lambda: (), False), - ("TestCase_1", lambda: (ops.Tensor), False), - ("TestCase_2", lambda: (((ops.Tensor))), False), - ("TestCase_3", lambda: (ops.Tensor, ops.Tensor), False), + ("TestCase_1", lambda: (tensor.Tensor), False), + ("TestCase_2", lambda: (((tensor.Tensor))), False), + ("TestCase_3", lambda: (tensor.Tensor, tensor.Tensor), False), ("TestCase_4", lambda: - (ops.Tensor, sparse_tensor.SparseTensor), True), + (tensor.Tensor, sparse_tensor.SparseTensor), True), ("TestCase_5", lambda: (sparse_tensor.SparseTensor, sparse_tensor.SparseTensor), True), ("TestCase_6", lambda: (((sparse_tensor.SparseTensor))), True)] @@ -62,7 +62,8 @@ def _test_as_dense_shapes_combinations(): cases = [ ("TestCase_0", lambda: (), lambda: (), lambda: ()), - ("TestCase_1", lambda: tensor_shape.TensorShape([]), lambda: ops.Tensor, + ("TestCase_1", lambda: tensor_shape.TensorShape([]), + lambda: tensor.Tensor, lambda: tensor_shape.TensorShape([])), ( "TestCase_2", @@ -71,7 +72,7 @@ def _test_as_dense_shapes_combinations(): lambda: tensor_shape.unknown_shape() # pylint: disable=unnecessary-lambda ), ("TestCase_3", lambda: (tensor_shape.TensorShape([])), lambda: - (ops.Tensor), lambda: (tensor_shape.TensorShape([]))), + (tensor.Tensor), lambda: (tensor_shape.TensorShape([]))), ( "TestCase_4", lambda: (tensor_shape.TensorShape([])), @@ -79,9 +80,9 @@ def _test_as_dense_shapes_combinations(): lambda: (tensor_shape.unknown_shape()) # pylint: disable=unnecessary-lambda ), ("TestCase_5", lambda: (tensor_shape.TensorShape([]), ()), lambda: - (ops.Tensor, ()), lambda: (tensor_shape.TensorShape([]), ())), + (tensor.Tensor, ()), lambda: (tensor_shape.TensorShape([]), ())), ("TestCase_6", lambda: ((), tensor_shape.TensorShape([])), lambda: - ((), ops.Tensor), lambda: ((), tensor_shape.TensorShape([]))), + ((), tensor.Tensor), lambda: ((), tensor_shape.TensorShape([]))), ("TestCase_7", lambda: (tensor_shape.TensorShape([]), ()), lambda: (sparse_tensor.SparseTensor, ()), lambda: (tensor_shape.unknown_shape(), ())), @@ -90,14 +91,14 @@ def _test_as_dense_shapes_combinations(): (), tensor_shape.unknown_shape())), ("TestCase_9", lambda: (tensor_shape.TensorShape([]), (), tensor_shape.TensorShape([])), lambda: - (ops.Tensor, (), ops.Tensor), lambda: + (tensor.Tensor, (), tensor.Tensor), lambda: (tensor_shape.TensorShape([]), (), tensor_shape.TensorShape([]))), ("TestCase_10", lambda: (tensor_shape.TensorShape([]), (), tensor_shape.TensorShape([])), lambda: (sparse_tensor.SparseTensor, (), sparse_tensor.SparseTensor), lambda: (tensor_shape.unknown_shape(), (), tensor_shape.unknown_shape())), ("TestCase_11", lambda: ((), tensor_shape.TensorShape([]), ()), lambda: - ((), ops.Tensor, ()), lambda: ((), tensor_shape.TensorShape([]), ())), + ((), tensor.Tensor, ()), lambda: ((), tensor_shape.TensorShape([]), ())), ("TestCase_12", lambda: ((), tensor_shape.TensorShape([]), ()), lambda: ((), sparse_tensor.SparseTensor, ()), lambda: ((), tensor_shape.unknown_shape(), ())) @@ -118,29 +119,30 @@ def reduce_fn(x, y): def _test_as_dense_types_combinations(): cases = [ ("TestCase_0", lambda: (), lambda: (), lambda: ()), - ("TestCase_1", lambda: dtypes.int32, lambda: ops.Tensor, + ("TestCase_1", lambda: dtypes.int32, lambda: tensor.Tensor, lambda: dtypes.int32), ("TestCase_2", lambda: dtypes.int32, lambda: sparse_tensor.SparseTensor, lambda: dtypes.variant), - ("TestCase_3", lambda: (dtypes.int32), lambda: (ops.Tensor), lambda: + ("TestCase_3", lambda: (dtypes.int32), lambda: (tensor.Tensor), lambda: (dtypes.int32)), ("TestCase_4", lambda: (dtypes.int32), lambda: (sparse_tensor.SparseTensor), lambda: (dtypes.variant)), ("TestCase_5", lambda: (dtypes.int32, ()), lambda: - (ops.Tensor, ()), lambda: (dtypes.int32, ())), + (tensor.Tensor, ()), lambda: (dtypes.int32, ())), ("TestCase_6", lambda: ((), dtypes.int32), lambda: - ((), ops.Tensor), lambda: ((), dtypes.int32)), + ((), tensor.Tensor), lambda: ((), dtypes.int32)), ("TestCase_7", lambda: (dtypes.int32, ()), lambda: (sparse_tensor.SparseTensor, ()), lambda: (dtypes.variant, ())), ("TestCase_8", lambda: ((), dtypes.int32), lambda: ((), sparse_tensor.SparseTensor), lambda: ((), dtypes.variant)), ("TestCase_9", lambda: (dtypes.int32, (), dtypes.int32), lambda: - (ops.Tensor, (), ops.Tensor), lambda: (dtypes.int32, (), dtypes.int32)), + (tensor.Tensor, (), tensor.Tensor), + lambda: (dtypes.int32, (), dtypes.int32)), ("TestCase_10", lambda: (dtypes.int32, (), dtypes.int32), lambda: (sparse_tensor.SparseTensor, (), sparse_tensor.SparseTensor), lambda: (dtypes.variant, (), dtypes.variant)), ("TestCase_11", lambda: ((), dtypes.int32, ()), lambda: - ((), ops.Tensor, ()), lambda: ((), dtypes.int32, ())), + ((), tensor.Tensor, ()), lambda: ((), dtypes.int32, ())), ("TestCase_12", lambda: ((), dtypes.int32, ()), lambda: ((), sparse_tensor.SparseTensor, ()), lambda: ((), dtypes.variant, ())), ] @@ -163,11 +165,12 @@ def _test_get_classes_combinations(): ("TestCase_1", lambda: sparse_tensor.SparseTensor( indices=[[0]], values=[1], dense_shape=[1]), lambda: sparse_tensor.SparseTensor), - ("TestCase_2", lambda: constant_op.constant([1]), lambda: ops.Tensor), + ("TestCase_2", lambda: constant_op.constant([1]), lambda: tensor.Tensor), ("TestCase_3", lambda: (sparse_tensor.SparseTensor(indices=[[0]], values=[1], dense_shape=[1])), lambda: (sparse_tensor.SparseTensor)), - ("TestCase_4", lambda: (constant_op.constant([1])), lambda: (ops.Tensor)), + ("TestCase_4", lambda: (constant_op.constant([1])), + lambda: (tensor.Tensor)), ("TestCase_5", lambda: (sparse_tensor.SparseTensor(indices=[[0]], values=[1], dense_shape=[1]), ()), lambda: (sparse_tensor.SparseTensor, ())), @@ -176,19 +179,19 @@ def _test_get_classes_combinations(): sparse_tensor.SparseTensor(indices=[[0]], values=[1], dense_shape=[1])), lambda: ((), sparse_tensor.SparseTensor)), ("TestCase_7", lambda: (constant_op.constant([1]), ()), lambda: - (ops.Tensor, ())), + (tensor.Tensor, ())), ("TestCase_8", lambda: ((), constant_op.constant([1])), lambda: - ((), ops.Tensor)), + ((), tensor.Tensor)), ("TestCase_9", lambda: (sparse_tensor.SparseTensor(indices=[[0]], values=[1], dense_shape=[1]), (), constant_op.constant([1])), lambda: (sparse_tensor.SparseTensor, - (), ops.Tensor)), + (), tensor.Tensor)), ("TestCase_10", lambda: ((), sparse_tensor.SparseTensor(indices=[[0]], values=[1], dense_shape=[1]), ()), lambda: ((), sparse_tensor.SparseTensor, ())), ("TestCase_11", lambda: ((), constant_op.constant([1]), ()), lambda: - ((), ops.Tensor, ())), + ((), tensor.Tensor, ())), ] def reduce_fn(x, y): diff --git a/tensorflow/python/data/util/structure.py b/tensorflow/python/data/util/structure.py index 14e9cb0e4ff9a4..56d89095da40bc 100644 --- a/tensorflow/python/data/util/structure.py +++ b/tensorflow/python/data/util/structure.py @@ -23,8 +23,8 @@ from tensorflow.python.framework import composite_tensor from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import type_spec from tensorflow.python.framework import type_spec_registry from tensorflow.python.ops import resource_variable_ops @@ -41,7 +41,7 @@ @tf_export(v1=["data.experimental.TensorStructure"]) @deprecation.deprecated(None, "Use `tf.TensorSpec` instead.") def _TensorStructure(dtype, shape): - return tensor_spec.TensorSpec(shape, dtype) + return tensor_lib.TensorSpec(shape, dtype) @tf_export(v1=["data.experimental.SparseTensorStructure"]) @@ -171,8 +171,8 @@ def convert_legacy_structure(output_types, output_shapes, output_classes): flat_ret.append(flat_class) elif issubclass(flat_class, sparse_tensor.SparseTensor): flat_ret.append(sparse_tensor.SparseTensorSpec(flat_shape, flat_type)) - elif issubclass(flat_class, ops.Tensor): - flat_ret.append(tensor_spec.TensorSpec(flat_shape, flat_type)) + elif issubclass(flat_class, tensor_lib.Tensor): + flat_ret.append(tensor_lib.TensorSpec(flat_shape, flat_type)) elif issubclass(flat_class, tensor_array_ops.TensorArray): # We sneaked the dynamic_size and infer_shape into the legacy shape. flat_ret.append( diff --git a/tensorflow/python/data/util/structure_test.py b/tensorflow/python/data/util/structure_test.py index b2d33c8247ef48..6832a3b3d7c51b 100644 --- a/tensorflow/python/data/util/structure_test.py +++ b/tensorflow/python/data/util/structure_test.py @@ -28,10 +28,9 @@ from tensorflow.python.framework import combinations from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_spec from tensorflow.python.ops import array_ops from tensorflow.python.ops import tensor_array_ops from tensorflow.python.ops import variables @@ -50,7 +49,7 @@ def _test_flat_structure_combinations(): cases = [ ("Tensor", lambda: constant_op.constant(37.0), - lambda: tensor_spec.TensorSpec, lambda: [dtypes.float32], lambda: [[]]), + lambda: tensor.TensorSpec, lambda: [dtypes.float32], lambda: [[]]), ("TensorArray", lambda: tensor_array_ops.TensorArray( dtype=dtypes.float32, element_shape=(3,), size=0), lambda: tensor_array_ops.TensorArraySpec, lambda: [dtypes.variant], @@ -336,8 +335,8 @@ def reduce_fn(x, y): def _test_convert_legacy_structure_combinations(): cases = [ - (dtypes.float32, tensor_shape.TensorShape([]), ops.Tensor, - tensor_spec.TensorSpec([], dtypes.float32)), + (dtypes.float32, tensor_shape.TensorShape([]), tensor.Tensor, + tensor.TensorSpec([], dtypes.float32)), (dtypes.int32, tensor_shape.TensorShape([2, 2]), sparse_tensor.SparseTensor, sparse_tensor.SparseTensorSpec([2, 2], dtypes.int32)), @@ -369,13 +368,13 @@ def _test_convert_legacy_structure_combinations(): "a": tensor_shape.TensorShape([]), "b": (tensor_shape.TensorShape([2, 2]), tensor_shape.TensorShape([])) }, { - "a": ops.Tensor, - "b": (sparse_tensor.SparseTensor, ops.Tensor) + "a": tensor.Tensor, + "b": (sparse_tensor.SparseTensor, tensor.Tensor) }, { "a": - tensor_spec.TensorSpec([], dtypes.float32), + tensor.TensorSpec([], dtypes.float32), "b": (sparse_tensor.SparseTensorSpec([2, 2], dtypes.int32), - tensor_spec.TensorSpec([], dtypes.string)) + tensor.TensorSpec([], dtypes.string)) }) ] @@ -392,10 +391,10 @@ def reduce_fn(x, y): def _test_batch_combinations(): cases = [ - (tensor_spec.TensorSpec([], dtypes.float32), 32, - tensor_spec.TensorSpec([32], dtypes.float32)), - (tensor_spec.TensorSpec([], dtypes.float32), None, - tensor_spec.TensorSpec([None], dtypes.float32)), + (tensor.TensorSpec([], dtypes.float32), 32, + tensor.TensorSpec([32], dtypes.float32)), + (tensor.TensorSpec([], dtypes.float32), None, + tensor.TensorSpec([None], dtypes.float32)), (sparse_tensor.SparseTensorSpec([None], dtypes.float32), 32, sparse_tensor.SparseTensorSpec([32, None], dtypes.float32)), (sparse_tensor.SparseTensorSpec([4], dtypes.float32), None, @@ -406,14 +405,14 @@ def _test_batch_combinations(): ragged_tensor.RaggedTensorSpec([None, 4, None], dtypes.float32, 2)), ({ "a": - tensor_spec.TensorSpec([], dtypes.float32), + tensor.TensorSpec([], dtypes.float32), "b": (sparse_tensor.SparseTensorSpec([2, 2], dtypes.int32), - tensor_spec.TensorSpec([], dtypes.string)) + tensor.TensorSpec([], dtypes.string)) }, 128, { "a": - tensor_spec.TensorSpec([128], dtypes.float32), + tensor.TensorSpec([128], dtypes.float32), "b": (sparse_tensor.SparseTensorSpec([128, 2, 2], dtypes.int32), - tensor_spec.TensorSpec([128], dtypes.string)) + tensor.TensorSpec([128], dtypes.string)) }), ] @@ -429,10 +428,10 @@ def reduce_fn(x, y): def _test_unbatch_combinations(): cases = [ - (tensor_spec.TensorSpec([32], dtypes.float32), - tensor_spec.TensorSpec([], dtypes.float32)), - (tensor_spec.TensorSpec([None], dtypes.float32), - tensor_spec.TensorSpec([], dtypes.float32)), + (tensor.TensorSpec([32], dtypes.float32), + tensor.TensorSpec([], dtypes.float32)), + (tensor.TensorSpec([None], dtypes.float32), + tensor.TensorSpec([], dtypes.float32)), (sparse_tensor.SparseTensorSpec([32, None], dtypes.float32), sparse_tensor.SparseTensorSpec([None], dtypes.float32)), (sparse_tensor.SparseTensorSpec([None, 4], dtypes.float32), @@ -443,14 +442,14 @@ def _test_unbatch_combinations(): ragged_tensor.RaggedTensorSpec([None, None], dtypes.float32, 1)), ({ "a": - tensor_spec.TensorSpec([128], dtypes.float32), + tensor.TensorSpec([128], dtypes.float32), "b": (sparse_tensor.SparseTensorSpec([128, 2, 2], dtypes.int32), - tensor_spec.TensorSpec([None], dtypes.string)) + tensor.TensorSpec([None], dtypes.string)) }, { "a": - tensor_spec.TensorSpec([], dtypes.float32), + tensor.TensorSpec([], dtypes.float32), "b": (sparse_tensor.SparseTensorSpec([2, 2], dtypes.int32), - tensor_spec.TensorSpec([], dtypes.string)) + tensor.TensorSpec([], dtypes.string)) }), ] @@ -838,9 +837,9 @@ def testConvertLegacyStructureFail(self): @combinations.generate(test_base.default_test_combinations()) def testNestedNestedStructure(self): - s = (tensor_spec.TensorSpec([], dtypes.int64), - (tensor_spec.TensorSpec([], dtypes.float32), - tensor_spec.TensorSpec([], dtypes.string))) + s = (tensor.TensorSpec([], dtypes.int64), + (tensor.TensorSpec([], dtypes.float32), + tensor.TensorSpec([], dtypes.string))) int64_t = constant_op.constant(37, dtype=dtypes.int64) float32_t = constant_op.constant(42.0) @@ -917,7 +916,7 @@ def testToBatchedTensorList(self, value_fn, element_0_fn): def testDatasetSpecConstructor(self): rt_spec = ragged_tensor.RaggedTensorSpec([10, None], dtypes.int32) st_spec = sparse_tensor.SparseTensorSpec([10, 20], dtypes.float32) - t_spec = tensor_spec.TensorSpec([10, 8], dtypes.string) + t_spec = tensor.TensorSpec([10, 8], dtypes.string) element_spec = {"rt": rt_spec, "st": st_spec, "t": t_spec} ds_struct = dataset_ops.DatasetSpec(element_spec, [5]) self.assertEqual(ds_struct._element_spec, element_spec) @@ -929,7 +928,7 @@ def testCustomMapping(self): elem = CustomMap(foo=constant_op.constant(37.)) spec = structure.type_spec_from_value(elem) self.assertIsInstance(spec, CustomMap) - self.assertEqual(spec["foo"], tensor_spec.TensorSpec([], dtypes.float32)) + self.assertEqual(spec["foo"], tensor.TensorSpec([], dtypes.float32)) @combinations.generate(test_base.default_test_combinations()) def testObjectProxy(self): From dac9af0b3e861336b7db97171a19e83f6365ed51 Mon Sep 17 00:00:00 2001 From: Jieying Luo Date: Tue, 11 Jul 2023 16:20:29 -0700 Subject: [PATCH 159/376] [TF:PJRT] Use ShapeUtil::Compatible when checking the compatibility between buffer on_device_shape and expected execution shape. Buffer on_device_shape may be dynamic. ShapeUtil::Equal will fail if it is dynamic. An alternative fix is to compare logical_on_device_shape. But getting logical_on_device_shape is blocking and may have performance impact. PiperOrigin-RevId: 547327689 --- .../compiler/xla/pjrt/pjrt_stream_executor_client.cc | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc index 11947167552f16..928704167ce290 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc +++ b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc @@ -1751,29 +1751,29 @@ struct TupleHandle { }; Status CheckCompatibleShapes(bool strict_shape_checking, - const Shape& buffer_shape, + const Shape& buffer_on_device_shape, const Shape& execution_shape, const TransferManager& transfer_manager, int parameter_index) { // TODO(misard) Support casting of tuple parameters. - if (strict_shape_checking || buffer_shape.IsTuple()) { - if (!ShapeUtil::Equal(buffer_shape, execution_shape)) { + if (strict_shape_checking || buffer_on_device_shape.IsTuple()) { + if (!ShapeUtil::Compatible(buffer_on_device_shape, execution_shape)) { return InvalidArgument( "Executable expected shape %s for argument %d but got " "incompatible " "shape %s", ShapeUtil::HumanStringWithLayout(execution_shape), parameter_index, - ShapeUtil::HumanStringWithLayout(buffer_shape)); + ShapeUtil::HumanStringWithLayout(buffer_on_device_shape)); } } else { - if (transfer_manager.GetByteSizeRequirement(buffer_shape) != + if (transfer_manager.GetByteSizeRequirement(buffer_on_device_shape) != transfer_manager.GetByteSizeRequirement(execution_shape)) { return InvalidArgument( "Executable expected shape %s for argument %d but got " "incompatible " "shape %s", ShapeUtil::HumanStringWithLayout(execution_shape), parameter_index, - ShapeUtil::HumanStringWithLayout(buffer_shape)); + ShapeUtil::HumanStringWithLayout(buffer_on_device_shape)); } } return OkStatus(); From e713567d9bd9e4132bb794506ba55185a095e339 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 11 Jul 2023 16:46:42 -0700 Subject: [PATCH 160/376] [XLA] Extended associative reordering to work with arbitrary contracting dimensions PiperOrigin-RevId: 547334655 --- tensorflow/compiler/xla/service/BUILD | 1 + .../xla/service/algebraic_simplifier.cc | 206 +++++++++++++----- .../xla/service/algebraic_simplifier_test.cc | 20 +- 3 files changed, 159 insertions(+), 68 deletions(-) diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 41ea31d95a685b..100727bdd53f42 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -2309,6 +2309,7 @@ cc_library( srcs = ["algebraic_simplifier.cc"], hdrs = ["algebraic_simplifier.h"], deps = [ + ":hlo_cost_analysis", ":hlo_creation_utils", ":hlo_pass", ":pattern_matcher", diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 16c374647b7aef..17980a6c84413e 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -51,6 +51,7 @@ limitations under the License. #include "tensorflow/compiler/xla/overflow_util.h" #include "tensorflow/compiler/xla/permutation_util.h" #include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" #include "tensorflow/compiler/xla/service/pattern_matcher.h" #include "tensorflow/compiler/xla/shape.h" @@ -395,19 +396,70 @@ bool ValidateTilingOfBitcast( return true; } -double GetDotFlops(const HloInstruction* dot) { - // A dot of arrays of size ab and bc requires ac(2b-1) flops - // In general, we compute the flops per element in the output shape - double contraction_prod = 1; - auto lhs_contracting_dims = - dot->dot_dimension_numbers().lhs_contracting_dimensions(); - for (auto dim : lhs_contracting_dims) { - contraction_prod *= dot->operand(0)->shape().dimensions(dim); +// Constructs the maps that take dims of A and dims of B to dims of AB, mapping +// to -1 for dimensions not present in AB. For an example, consider we are +// computing a dot whose operands have shapes [m,n,p] and [n,q]. Assuming we +// contract over n, this produces an array with shape [m,p,q]. This function +// will return vectors map_a_ab = {0, -1, 1} and map_b_ab = {-1, 2} +std::pair, std::vector> ConstructToDotMaps( + DotDimensionNumbers dnums, const Shape& a_shape, const Shape& b_shape) { + std::vector map_a_ab, map_b_ab; + int ab_index = 0; + // Extract a and b contraction dimensions from dnums + auto a_contracting_dims = dnums.lhs_contracting_dimensions(); + auto b_contracting_dims = dnums.rhs_contracting_dimensions(); + // Iterating through the dimensions of a + for (int a_index = 0; a_index < a_shape.rank(); a_index++) { + if (absl::c_linear_search(a_contracting_dims, a_index)) { + map_a_ab.push_back(-1); + } else { + map_a_ab.push_back(ab_index); + ab_index++; + } + } + // Iterating through the dimensions of b + for (int b_index = 0; b_index < b_shape.rank(); b_index++) { + if (absl::c_linear_search(b_contracting_dims, b_index)) { + map_b_ab.push_back(-1); + } else { + map_b_ab.push_back(ab_index); + ab_index++; + } } - // Flops include multiplications and adds - double flops_per_output_elem = 2 * contraction_prod - 1; - // We then multiply this number by the number of elements in the output shape - return flops_per_output_elem * ShapeUtil::ElementsIn(dot->shape()); + return {map_a_ab, map_b_ab}; +} + +// Constructs the maps that take dims of AB to dims of A and dims of B mapping +// to -1 for dimensions not present in A/B. For an example, consider we are +// computing a dot whose operands have shapes [m,n,p] and [n,q]. Assuming we +// contract over n, this produces an array with shape [m,p,q]. This function +// will return vectors map_ab_a = {0,2,-1} and map_ab_b = {-1,-1,1} +std::pair, std::vector> ConstructFromDotMaps( + const HloInstruction* dot, const Shape& a_shape, const Shape& b_shape) { + // Reserve space for new maps + std::vector map_ab_a, map_ab_b; + map_ab_a.resize(dot->shape().rank(), -1); + map_ab_b.resize(dot->shape().rank(), -1); + // Construct the maps going in the opposite direction + std::vector map_a_ab, map_b_ab; + std::tie(map_a_ab, map_b_ab) = + ConstructToDotMaps(dot->dot_dimension_numbers(), a_shape, b_shape); + // Construct these maps by inverting those above + int a_index = 0; + for (auto ab_index : map_a_ab) { + if (ab_index != -1) { + map_ab_a[ab_index] = a_index; + } + a_index++; + } + int b_index = 0; + for (auto ab_index : map_b_ab) { + if (ab_index != -1) { + map_ab_b[ab_index] = b_index; + } + b_index++; + } + return {map_ab_a, map_ab_b}; } } // namespace @@ -2832,6 +2884,90 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { return ReplaceWithNewInstruction(dot, std::move(new_instruction)); } + // Reorder nested dots with associativity using flops as a heuristic + if (options_.use_associative_reordering()) { + HloInstruction *a, *b, *c; + HloInstruction *old_inner, *old_outer, *new_inner, *new_outer; + DotDimensionNumbers ab_dnums, ac_dnums, bc_dnums; + // Here we extract the contracting dimensions shared between A and B, A and + // C, and B and C, and use these to build up the dimension numbers for the + // reordered dot A(BC). + if (Match(dot, m::Dot(m::Dot(m::Op(&a), m::Op(&b)), m::Op(&c))) && + dot->dot_dimension_numbers().lhs_batch_dimensions_size() == 0) { + // We already have the ab_dnums for free + ab_dnums = dot->operand(0)->dot_dimension_numbers(); + // Get maps for converting AB dimensions to A and B + std::vector map_ab_a, map_ab_b; + std::tie(map_ab_a, map_ab_b) = + ConstructFromDotMaps(dot->operand(0), a->shape(), b->shape()); + // Recover ac_dnums and bc_dnums from ab_c_dnums + DotDimensionNumbers ab_c_dnums = dot->dot_dimension_numbers(); + for (int i = 0; i < ab_c_dnums.lhs_contracting_dimensions_size(); i++) { + auto ab_index = ab_c_dnums.lhs_contracting_dimensions(i); + auto c_index = ab_c_dnums.rhs_contracting_dimensions(i); + if (map_ab_b[ab_index] == -1) { + ac_dnums.add_lhs_contracting_dimensions(map_ab_a[ab_index]); + ac_dnums.add_rhs_contracting_dimensions(c_index); + } else { + bc_dnums.add_lhs_contracting_dimensions(map_ab_b[ab_index]); + bc_dnums.add_rhs_contracting_dimensions(c_index); + } + } + + // Get maps for converting B and C dimensions to BC + std::vector map_b_bc, map_c_bc; + std::tie(map_b_bc, map_c_bc) = + ConstructToDotMaps(bc_dnums, b->shape(), c->shape()); + // Now build a_bc_dnums from ab_dnums and bc_dnums + DotDimensionNumbers a_bc_dnums; + for (int i = 0; i < ab_dnums.lhs_contracting_dimensions_size(); i++) { + auto a_index = ab_dnums.lhs_contracting_dimensions(i); + auto b_index = ab_dnums.rhs_contracting_dimensions(i); + a_bc_dnums.add_lhs_contracting_dimensions(a_index); + a_bc_dnums.add_rhs_contracting_dimensions(map_b_bc[b_index]); + } + for (int i = 0; i < ac_dnums.lhs_contracting_dimensions_size(); i++) { + auto a_index = ac_dnums.lhs_contracting_dimensions(i); + auto c_index = ac_dnums.rhs_contracting_dimensions(i); + a_bc_dnums.add_lhs_contracting_dimensions(a_index); + a_bc_dnums.add_rhs_contracting_dimensions(map_c_bc[c_index]); + } + + // Make Hlo for reordering dot + old_inner = lhs; + old_outer = dot; + TF_ASSIGN_OR_RETURN(new_inner, + MakeDotHlo(b, c, bc_dnums, dot->precision_config(), + dot->shape().element_type())); + TF_ASSIGN_OR_RETURN(new_outer, MakeDotHlo(a, new_inner, a_bc_dnums, + dot->precision_config(), + dot->shape().element_type())); + + // Use HloCostAnalysis to compute flops for both the original and + // reordered instructions, and reorder if doing so decreases flops by a + // factor of the reordering threshold. + const int64_t old_flops = + HloCostAnalysis::GetDotFlops(old_inner->operand(0)->shape(), + old_inner->shape(), + old_inner->dot_dimension_numbers()) + + HloCostAnalysis::GetDotFlops(old_outer->operand(0)->shape(), + old_outer->shape(), + old_outer->dot_dimension_numbers()); + const int64_t new_flops = + HloCostAnalysis::GetDotFlops(new_inner->operand(0)->shape(), + new_inner->shape(), + new_inner->dot_dimension_numbers()) + + HloCostAnalysis::GetDotFlops(new_outer->operand(0)->shape(), + new_outer->shape(), + new_outer->dot_dimension_numbers()); + if (old_flops / new_flops > options_.associative_reordering_threshold()) { + VLOG(10) << "Reordering with associativity"; + return ReplaceInstruction(dot, new_outer); + } + } + // TODO(b/289120301) Implement other direction after first looks good + } + // If the lhs or rhs have only batch and contracting dimensions, a dot can be // rewritten as reduce(mul(broadcast(transpose(x)),broadcast(transpose(y)))) if (!is_packed_nibble && options_.enable_dot_strength_reduction() && @@ -2944,52 +3080,6 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { if (removed_transposes) { return OkStatus(); } - - // Reorder nested dots with associativity using flops as a heuristic - if (options_.use_associative_reordering()) { - // TODO(b/289120301): Update with symmetric contraction form - HloInstruction *a, *b, *c; - HloInstruction *dot_a_b, *dot_ab_c, *dot_b_c, *dot_a_bc; - int64_t left_first_flops, right_first_flops; - if (Match(dot, m::Dot(m::Dot(m::Op(&a), m::Op(&b)), m::Op(&c)))) { - dot_a_b = lhs; - dot_ab_c = dot; - TF_ASSIGN_OR_RETURN( - dot_b_c, - MakeDotHlo(b, c, dot->dot_dimension_numbers(), - dot->precision_config(), dot->shape().element_type())); - TF_ASSIGN_OR_RETURN( - dot_a_bc, - MakeDotHlo(a, dot_b_c, dot_a_b->dot_dimension_numbers(), - dot->precision_config(), dot->shape().element_type())); - left_first_flops = GetDotFlops(dot_a_b) + GetDotFlops(dot_ab_c); - right_first_flops = GetDotFlops(dot_b_c) + GetDotFlops(dot_a_bc); - if (left_first_flops > - options_.associative_reordering_threshold() * right_first_flops) { - return ReplaceInstruction(dot, dot_a_bc); - } - } else if (Match(dot, m::Dot(m::Op(&a), m::Dot(m::Op(&b), m::Op(&c))))) { - dot_b_c = rhs; - dot_a_bc = dot; - TF_ASSIGN_OR_RETURN( - dot_a_b, - MakeDotHlo(a, b, dot->dot_dimension_numbers(), - dot->precision_config(), dot->shape().element_type())); - TF_ASSIGN_OR_RETURN( - dot_ab_c, - MakeDotHlo(dot_a_b, c, dot_b_c->dot_dimension_numbers(), - dot->precision_config(), dot->shape().element_type())); - left_first_flops = GetDotFlops(dot_a_b) + GetDotFlops(dot_ab_c); - right_first_flops = GetDotFlops(dot_b_c) + GetDotFlops(dot_a_bc); - if (right_first_flops > - options_.associative_reordering_threshold() * left_first_flops) { - return ReplaceInstruction(dot, dot_ab_c); - } - } else { - return OkStatus(); - } - } - return OkStatus(); } diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 162aac070b98b7..b94c0eb3941502 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -5695,20 +5695,20 @@ TEST_F(AlgebraicSimplifierTest, TransposeOfDot) { PrecisionConfig::HIGHEST); } -TEST_F(AlgebraicSimplifierTest, DotAttentionReorder) { +TEST_F(AlgebraicSimplifierTest, DotAssociativeReorder) { const char* hlo_string = R"( HloModule module ENTRY test { - a = f32[1024,2] parameter(0) - b = f32[2,1024] parameter(1) - c = f32[1024,2] parameter(2) - inner_dot = f32[1024,1024] dot(a,b), - lhs_contracting_dims={1}, - rhs_contracting_dims={0} - ROOT outer_dot = f32[1024,2] dot(inner_dot, c), - lhs_contracting_dims={1}, - rhs_contracting_dims={0} + a = f32[2,3,4,5] parameter(0) + b = f32[6,7,5] parameter(1) + c = f32[4,7] parameter(2) + inner_dot = f32[2,3,4,6,7] dot(a,b), + lhs_contracting_dims={3}, + rhs_contracting_dims={2} + ROOT outer_dot = f32[2,3,6] dot(inner_dot,c), + lhs_contracting_dims={2,4}, + rhs_contracting_dims={0,1} } )"; TF_ASSERT_OK_AND_ASSIGN(auto module, From e87d1c975a2f1ea4de60a865b6561ae12c1c8c33 Mon Sep 17 00:00:00 2001 From: Richard Levasseur Date: Tue, 11 Jul 2023 16:55:45 -0700 Subject: [PATCH 161/376] Internal Code Change PiperOrigin-RevId: 547336766 --- tensorflow/py.default.bzl | 12 ++++++++++++ tensorflow/tensorflow.bzl | 36 ++++++++++++++++++++++++++++-------- 2 files changed, 40 insertions(+), 8 deletions(-) create mode 100644 tensorflow/py.default.bzl diff --git a/tensorflow/py.default.bzl b/tensorflow/py.default.bzl new file mode 100644 index 00000000000000..bad528e901bbd1 --- /dev/null +++ b/tensorflow/py.default.bzl @@ -0,0 +1,12 @@ +"""Shims for loading the plain Python rules. + +These are used to make internal/external code transformations managable. Once +Tensorflow is loading the Python rules directly from rules_python, these shims +can be removed. +""" + +py_test = native.py_test + +py_binary = native.py_binary + +py_library = native.py_library diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index efda8f31ff6542..7f32fa50914c7c 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -61,6 +61,12 @@ load( "//third_party/llvm_openmp:openmp.bzl", "windows_llvm_openmp_linkopts", ) +load( + "//tensorflow:py.default.bzl", + _plain_py_binary = "py_binary", + _plain_py_library = "py_library", + _plain_py_test = "py_test", +) load("@bazel_skylib//lib:new_sets.bzl", "sets") load("@bazel_skylib//rules:common_settings.bzl", "BuildSettingInfo") @@ -1331,7 +1337,7 @@ def tf_gen_op_wrapper_py( testonly = False, copts = [], extra_py_deps = None, - py_lib_rule = native.py_library): + py_lib_rule = _plain_py_library): """Generates a Python library target wrapping the ops registered in "deps". Args: @@ -2267,7 +2273,8 @@ def tf_custom_op_py_library( deps = [], **kwargs): _ignore = [kernels] - native.py_library( + _make_tags_mutable(kwargs) + _plain_py_library( name = name, data = dso, srcs = srcs, @@ -2454,7 +2461,7 @@ def pywrap_tensorflow_macro_opensource( # link the pyd (which is just a dll) because of missing dependencies. _create_symlink("ml_dtypes.so", "//tensorflow/tsl/python/lib/core:ml_dtypes.so") - native.py_library( + _plain_py_library( name = name, srcs = [":" + name + ".py"], srcs_version = "PY3", @@ -2489,10 +2496,11 @@ pywrap_tensorflow_macro = pywrap_tensorflow_macro_opensource # Note that this only works on Windows. See the definition of # //third_party/tensorflow/tools/pip_package:win_pip_package_marker for specific reasons. # 2. When --define=no_tensorflow_py_deps=false (by default), it's a normal py_test. -def py_test(deps = [], data = [], kernels = [], exec_properties = None, test_rule = native.py_test, **kwargs): +def py_test(deps = [], data = [], kernels = [], exec_properties = None, test_rule = _plain_py_test, **kwargs): if not exec_properties: exec_properties = tf_exec_properties(kwargs) + _make_tags_mutable(kwargs) test_rule( deps = select({ "//conditions:default": deps, @@ -2516,13 +2524,14 @@ register_extension_info( # See https://github.com/tensorflow/tensorflow/issues/22390 def py_binary(name, deps = [], **kwargs): # Add an extra target for dependencies to avoid nested select statement. - native.py_library( + _plain_py_library( name = name + "_deps", deps = deps, ) # Python version placeholder - native.py_binary( + _make_tags_mutable(kwargs) + _plain_py_binary( name = name, deps = select({ "//conditions:default": [":" + name + "_deps"], @@ -2533,7 +2542,18 @@ def py_binary(name, deps = [], **kwargs): def pytype_library(name, pytype_deps = [], pytype_srcs = [], **kwargs): # Types not enforced in OSS. - native.py_library(name = name, **kwargs) + _make_tags_mutable(kwargs) + _plain_py_library(name = name, **kwargs) + +# Tensorflow uses rules_python 0.0.1, and in that version of rules_python, +# the rules require the tags value to be a mutable list because they +# modify it in-place. Later versions of rules_python don't have this +# requirement. +def _make_tags_mutable(kwargs): + if "tags" in kwargs and kwargs["tags"] != None: + # The value might be a frozen list, which looks just like + # a regular list. So always make a copy. + kwargs["tags"] = list(kwargs["tags"]) def tf_py_test( name, @@ -3142,7 +3162,7 @@ def pybind_extension_opensource( testonly = testonly, ) - native.py_library( + _plain_py_library( name = name, data = select({ clean_dep("//tensorflow:windows"): [pyd_file], From 41f4706e7e3de4ff33735f00e792083436729123 Mon Sep 17 00:00:00 2001 From: Reed Wanderman-Milne Date: Tue, 11 Jul 2023 17:09:02 -0700 Subject: [PATCH 162/376] Rollback 8b116e21d125efe69764316a86af4f109bb3d5b6. It breaks FP8 gemms on Hopper PiperOrigin-RevId: 547339624 --- .../compiler/xla/service/gpu/compile_module_to_llvm_ir.cc | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/xla/service/gpu/compile_module_to_llvm_ir.cc b/tensorflow/compiler/xla/service/gpu/compile_module_to_llvm_ir.cc index c91d9cce339c53..1f9d1b00094a69 100644 --- a/tensorflow/compiler/xla/service/gpu/compile_module_to_llvm_ir.cc +++ b/tensorflow/compiler/xla/service/gpu/compile_module_to_llvm_ir.cc @@ -428,7 +428,10 @@ Status CompileModuleToLlvmIrImpl( RecordHloToLlvmDuration(end_usecs - start_usecs); } - if (IsXlaRuntimeExecutableEnabled(hlo_module->config())) { + // TODO(ezhulenev): Remove the FP8 check once https://reviews.llvm.org/D140088 + // is submitted. Currently we can't emit LLVM IR with fp8 types. + if (IsXlaRuntimeExecutableEnabled(hlo_module->config()) && + !HasFp8(*hlo_module)) { std::vector buffer_sizes; llvm::transform( results->allocations, std::back_inserter(buffer_sizes), From 1295151bb6dd0fcdfe960309052f53fc7234e1cd Mon Sep 17 00:00:00 2001 From: Rahul Joshi Date: Tue, 11 Jul 2023 17:14:58 -0700 Subject: [PATCH 163/376] [XLA] Fix `IsPerIdOffset` to check IsEffectiveScalar - Fix the function to check if the multiply is effectively scalar as opposed to true scalar - This fixes a regression in RS pattern matching caused by `ReshapeMover` pass, which pushes down the reshape in the offset computation generated by SPMD partitioner, causing the RS pattern matching to fail PiperOrigin-RevId: 547340786 --- .../compiler/xla/service/gpu/tests/BUILD | 4 -- .../tests/gpu_reduce_scatter_creator_test.cc | 46 +++++++++++++++++-- .../xla/service/reduce_scatter_utils.cc | 4 +- 3 files changed, 44 insertions(+), 10 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD index a810a5fcdd2670..22b5a70b580938 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD @@ -99,13 +99,9 @@ xla_cc_test( "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/hlo/ir:hlo", "//tensorflow/compiler/xla/hlo/utils:hlo_matchers", - "//tensorflow/compiler/xla/service:hlo_parser", - "//tensorflow/compiler/xla/service:hlo_pass_pipeline", - "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/service/gpu:gpu_reduce_scatter_creator", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/tsl/lib/core:status_test_util", "//tensorflow/tsl/platform:test", ], ) diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_reduce_scatter_creator_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_reduce_scatter_creator_test.cc index a80f897735fd89..c2a4bf7213a0d9 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_reduce_scatter_creator_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_reduce_scatter_creator_test.cc @@ -15,18 +15,16 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/gpu_reduce_scatter_creator.h" +#include +#include + #include "tensorflow/compiler/xla/hlo/ir/hlo_casting_utils.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_instructions.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_module.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_opcode.h" #include "tensorflow/compiler/xla/hlo/utils/hlo_matchers.h" -#include "tensorflow/compiler/xla/service/hlo_parser.h" -#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" -#include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/tsl/lib/core/status_test_util.h" namespace xla { namespace gpu { @@ -98,6 +96,44 @@ ENTRY %AllReduce { EXPECT_EQ(AllReduceCount(module), 0); } +TEST_F(GpuReduceScatterCreatorTest, AllReplicasWithOffsetReshape) { + absl::string_view hlo_string = R"( +HloModule AllReduce + +%sum { + %a = f32[] parameter(0) + %b = f32[] parameter(1) + ROOT %add = f32[] add(%a, %b) +} + +ENTRY %AllReduce { + %param = f32[32,8,128]{2,1,0} parameter(0) + %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param), + replica_groups={}, to_apply=%sum + %table = s32[8]{0} constant({0,1,2,3,4,5,6,7}) + %rid = u32[] replica-id() + %id = s32[1] dynamic-slice(%table, %rid), dynamic_slice_sizes={1} + %slice_size = s32[1] constant({4}) + %offset = s32[1] multiply(%id, %slice_size) + %reshape = s32[] reshape(%offset) + %zero = s32[] constant(0) + ROOT %dynamic-slice = f32[4,8,128] dynamic-slice(%all-reduce, %reshape, %zero, %zero), + dynamic_slice_sizes={4,8,128} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string, + /*num_replicas=*/8, + /*num_partitions=*/1, + /*expect_change=*/true)); + ASSERT_THAT(module->entry_computation()->root_instruction(), + op::ReduceScatter(op::Parameter(0))); + const auto *rs = Cast( + module->entry_computation()->root_instruction()); + EXPECT_EQ(rs->scatter_dimension(), 0) << rs->ToString(); + EXPECT_EQ(AllReduceCount(module), 0); +} + TEST_F(GpuReduceScatterCreatorTest, AllReplicasWithReshape) { absl::string_view hlo_string = R"( HloModule AllReduce diff --git a/tensorflow/compiler/xla/service/reduce_scatter_utils.cc b/tensorflow/compiler/xla/service/reduce_scatter_utils.cc index ae689d2a96cff6..bc179cbfa83550 100644 --- a/tensorflow/compiler/xla/service/reduce_scatter_utils.cc +++ b/tensorflow/compiler/xla/service/reduce_scatter_utils.cc @@ -17,6 +17,8 @@ limitations under the License. #include #include +#include +#include #include "tensorflow/compiler/xla/hlo/ir/hlo_instructions.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_module.h" @@ -146,7 +148,7 @@ bool IsPerIdOffset(const HloInstruction* offset, int64_t shard_size, if (offset->opcode() == HloOpcode::kMultiply) { // Check if it's constant * IsPerIdOffset(..., shard_size / constant, ...) - if (offset->shape().rank() != 0) { + if (!ShapeUtil::IsEffectiveScalar(offset->shape())) { VLOG(2) << "Offset is not a scalar " << offset->ToString(); return false; } From 6c7ebc9aad9a238a1b2d91aef51723e26ffdf076 Mon Sep 17 00:00:00 2001 From: Luke Boyer Date: Tue, 11 Jul 2023 17:22:38 -0700 Subject: [PATCH 164/376] Benchmarks from MLIR for tfl tensorlists PiperOrigin-RevId: 547342340 --- tensorflow/compiler/mlir/lite/tests/legalize-tensorlist.mlir | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tensorlist.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tensorlist.mlir index b2e2aac35983df..f97bab00c89503 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize-tensorlist.mlir +++ b/tensorflow/compiler/mlir/lite/tests/legalize-tensorlist.mlir @@ -79,6 +79,7 @@ func.func @listFromTensor(%tensor: tensor<3xi32>, %shape : tensor) -> ten func.return %0 : tensor>> // CHECK: %0 = "tfl.custom"(%arg0, %arg1) {custom_code = "TensorListFromTensor", custom_option = #tfl} : (tensor<3xi32>, tensor) -> tensor>> } + // ----- // CHECK-LABEL: typeNotSupportedNotLegalized From ce78c0d4716318db33a0ce0c4894583b09b4b092 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 11 Jul 2023 19:12:21 -0700 Subject: [PATCH 165/376] [xla:runtime] Use volatile store to encode args/rets Encoded args/rets create a lot of store instructions that LLVM tries to optimize very hard, but we don't really expect any optimizations to improve performance. By marking store instructions volatile we suppress most of the expensive LLVM optimizations. PiperOrigin-RevId: 547359822 --- .../mlir/runtime/transforms/custom_call_encoding.cc | 3 ++- .../compiler/xla/mlir/runtime/transforms/rt_to_llvm.cc | 10 ++++++---- .../xla/mlir/runtime/transforms/tests/rt_to_llvm.mlir | 8 ++++---- 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/tensorflow/compiler/xla/mlir/runtime/transforms/custom_call_encoding.cc b/tensorflow/compiler/xla/mlir/runtime/transforms/custom_call_encoding.cc index a83f477729ee54..b23c7f0d1bd86f 100644 --- a/tensorflow/compiler/xla/mlir/runtime/transforms/custom_call_encoding.cc +++ b/tensorflow/compiler/xla/mlir/runtime/transforms/custom_call_encoding.cc @@ -422,7 +422,8 @@ static LLVM::AllocaOp PackValue(ImplicitLocOpBuilder &b, Allocas &a, LLVM::AllocaOp alloca = a.GetOrCreate(b, value.getType()); // Start the lifetime of encoded value. b.create(b.getI64IntegerAttr(-1), alloca); - b.create(value, alloca); + // Use volatile store to suppress expensive LLVM optimizations. + b.create(value, alloca, /*alignment=*/0, /*isVolatile=*/true); return alloca; } diff --git a/tensorflow/compiler/xla/mlir/runtime/transforms/rt_to_llvm.cc b/tensorflow/compiler/xla/mlir/runtime/transforms/rt_to_llvm.cc index 399ad8de6fbbae..00df6405898bd1 100644 --- a/tensorflow/compiler/xla/mlir/runtime/transforms/rt_to_llvm.cc +++ b/tensorflow/compiler/xla/mlir/runtime/transforms/rt_to_llvm.cc @@ -337,8 +337,9 @@ static FailureOr EncodeArguments( // Start the lifetime of the encoded arguments pointers. b.create(b.getI64IntegerAttr(-1), alloca); - // Store constructed arguments pointers array into the alloca. - b.create(arr, alloca.getRes()); + // Store constructed arguments pointers array into the alloca. Use volatile + // store to suppress expensive LLVM optimizations. + b.create(arr, alloca, /*alignment=*/0, /*isVolatile=*/true); // Alloca that encodes the custom call arguments. arguments.encoded = alloca; @@ -431,8 +432,9 @@ static FailureOr EncodeResults( // Start the lifetime of the encoded results pointers allocation. b.create(b.getI64IntegerAttr(-1), alloca); - // Store constructed results pointers array on the stack - b.create(arr, alloca); + // Store constructed results pointers array on the stack. Use volatile + // store to suppress expensive LLVM optimizations. + b.create(arr, alloca, /*alignment=*/0, /*isVolatile=*/true); // Alloca that encodes the custom call returns. results.encoded = alloca; diff --git a/tensorflow/compiler/xla/mlir/runtime/transforms/tests/rt_to_llvm.mlir b/tensorflow/compiler/xla/mlir/runtime/transforms/tests/rt_to_llvm.mlir index b26b0b6f6062c4..38af292598b59d 100644 --- a/tensorflow/compiler/xla/mlir/runtime/transforms/tests/rt_to_llvm.mlir +++ b/tensorflow/compiler/xla/mlir/runtime/transforms/tests/rt_to_llvm.mlir @@ -335,7 +335,7 @@ func.func @custom_call(%arg0: !rt.execution_context, %arg1 : f32) { // CHECK-DAG: %[[ARGS:.*]] = llvm.alloca {{.*}} x !llvm.array<3 x ptr> // CHECK-DAG: %[[N_ARGS:.*]] = llvm.mlir.addressof @__rt_num_args - // CHECK-DAG: llvm.store %[[ARG]], %[[MEM]] + // CHECK-DAG: llvm.store volatile %[[ARG]], %[[MEM]] // CHECK: %[[ARGS_TYPES:.*]] = llvm.mlir.addressof @__rt_args_type_table // CHECK: llvm.insertvalue %[[ARGS_TYPES]], {{.*}}[1] : !llvm.array<3 x ptr> @@ -460,7 +460,7 @@ func.func @opaque_arg(%ctx: !rt.execution_context, %arg: !rt.opaque) { func.func @opaque_custom_call_arg(%ctx: !rt.execution_context, %arg: !rt.opaque) { // CHECK: %[[ALLOCA:.*]] = llvm.alloca {{.*}} x !llvm.ptr - // CHECK: llvm.store %[[ARG1]], %[[ALLOCA]] : !llvm.ptr + // CHECK: llvm.store volatile %[[ARG1]], %[[ALLOCA]] : !llvm.ptr // CHECK: call @target %status = rt.call %ctx["target"] (%arg) : (!rt.opaque) -> () return @@ -627,7 +627,7 @@ func.func @custom_call(%arg0: !rt.execution_context, %arg1: f32) { // CHECK-NOT: llvm.alloca // llvm.intr.lifetime.start -1, %[[ARG_ALLOCA]] : !llvm.ptr - // CHECK: llvm.store %[[ARG]], %[[ARG_ALLOCA]] : f32, !llvm.ptr + // CHECK: llvm.store volatile %[[ARG]], %[[ARG_ALLOCA]] : f32, !llvm.ptr // llvm.intr.lifetime.start -1, %[[ARGS]] : !llvm.ptr // CHECK: llvm.store {{.*}}, %[[ARGS]] // CHECK: call @target @@ -636,7 +636,7 @@ func.func @custom_call(%arg0: !rt.execution_context, %arg1: f32) { // llvm.intr.lifetime.end -1, %[[ARG_ALLOCA]] : !llvm.ptr // llvm.intr.lifetime.start -1, %[[ARG_ALLOCA]] : !llvm.ptr - // CHECK: llvm.store %[[ARG]], %[[ARG_ALLOCA]] : f32, !llvm.ptr + // CHECK: llvm.store volatile %[[ARG]], %[[ARG_ALLOCA]] : f32, !llvm.ptr // llvm.intr.lifetime.start -1, %[[ARGS]] : !llvm.ptr // CHECK: llvm.store {{.*}}, %[[ARGS]] // CHECK: call @target From 221af6a546ba5c8e470d2b80877a47b34c9d6b94 Mon Sep 17 00:00:00 2001 From: Jieying Luo Date: Tue, 11 Jul 2023 20:49:17 -0700 Subject: [PATCH 166/376] [PJRT C API] Add host_layout to ToHostBufferArg. The input of C++ API ToLiteral can have a specific host layout (tile dimensions not supported right now). Add it to corresponding PJRT C API. PiperOrigin-RevId: 547374250 --- tensorflow/compiler/xla/pjrt/c/BUILD | 3 ++ tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h | 6 +++- .../compiler/xla/pjrt/c/pjrt_c_api_test.cc | 29 +++++++++++++++++++ .../xla/pjrt/c/pjrt_c_api_wrapper_impl.cc | 22 ++++++++++++-- .../compiler/xla/pjrt/pjrt_c_api_client.cc | 12 ++++++++ 5 files changed, 68 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/xla/pjrt/c/BUILD b/tensorflow/compiler/xla/pjrt/c/BUILD index 169ad42fc3eb4a..d1531a0bf1c523 100644 --- a/tensorflow/compiler/xla/pjrt/c/BUILD +++ b/tensorflow/compiler/xla/pjrt/c/BUILD @@ -160,6 +160,8 @@ cc_library( deps = [ ":pjrt_c_api_hdrs", ":pjrt_c_api_helpers", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla:xla_proto_cc", "//tensorflow/compiler/xla/client:xla_builder", @@ -168,6 +170,7 @@ cc_library( "//tensorflow/compiler/xla/pjrt:pjrt_future", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:hlo_proto_cc", + "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/tsl/platform:errors", "//tensorflow/tsl/platform:status", "//tensorflow/tsl/platform:statusor", diff --git a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h index b94359e9e4c48a..9f7926e0d403f5 100644 --- a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h +++ b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h @@ -53,7 +53,7 @@ extern "C" { // Changes include: // * Adding a new field to the PJRT_Api or argument structs // * Renaming a method or argument (doesn't affect ABI) -#define PJRT_API_MINOR 7 +#define PJRT_API_MINOR 8 // The plugin should set the major_version and minor_version of // PJRT_Api.pjrt_api_version to be the `PJRT_API_MAJOR` and `PJRT_API_MINOR` in @@ -1352,6 +1352,10 @@ struct PJRT_Buffer_ToHostBuffer_Args { void* priv; PJRT_Buffer* src; + // The caller can specify an optional host layout. If nullptr, the layout of + // the src buffer will be used. The caller is responsible to keep the data + // (tiled or strides) in the host_layout alive during the call. + PJRT_Buffer_MemoryLayout* host_layout; // `dst` can be nullptr to query required size which will be set into // `dst_size`. void* dst; // in/out diff --git a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_test.cc b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_test.cc index b4d51f2d18924a..4bde253fee9a68 100644 --- a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_test.cc +++ b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_test.cc @@ -32,12 +32,15 @@ limitations under the License. #include "absl/types/span.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h" #include "tensorflow/compiler/xla/pjrt/c/pjrt_c_api_helpers.h" #include "tensorflow/compiler/xla/pjrt/pjrt_client.h" #include "tensorflow/compiler/xla/pjrt/pjrt_future.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/xla.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/tsl/platform/errors.h" @@ -283,6 +286,7 @@ class PjrtCApiTest : public ::testing::Test { .struct_size = PJRT_Buffer_ToHostBuffer_Args_STRUCT_SIZE, .priv = nullptr, .src = src_buffer, + .host_layout = nullptr, .dst = nullptr, .dst_size = 0, .event = nullptr, @@ -864,6 +868,31 @@ TEST_F(PjrtCApiBufferTest, ReadyEvent) { EXPECT_EQ(error, nullptr); } +TEST_F(PjrtCApiBufferTest, ToHostBufferNoHostLayout) { + PJRT_Buffer_ToHostBuffer_Args args; + args.struct_size = PJRT_Buffer_ToHostBuffer_Args_STRUCT_SIZE; + args.priv = nullptr; + args.src = buffer_.get(); + Shape host_shape = ShapeUtil::MakeShape(F32, {4}); + auto literal = std::make_shared(host_shape); + args.host_layout = nullptr; + args.dst = literal->untyped_data(); + args.dst_size = ShapeUtil::ByteSizeOfElements(host_shape); + args.event = nullptr; + + PJRT_Error* error = api_->PJRT_Buffer_ToHostBuffer(&args); + PjRtFuture transfer_to_host = + ::pjrt::ConvertCEventToCppFuture(args.event, api_); + TF_CHECK_OK(transfer_to_host.Await()); + + EXPECT_EQ(error, nullptr); + ASSERT_EQ(literal->data().size(), 4); + std::vector float_data(4); + std::iota(float_data.begin(), float_data.end(), 41.0f); + EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1(float_data), + *literal)); +} + // --------------------------------- Helpers ----------------------------------- class PjrtCommonCApiHelpersTest : public PjrtCApiTest {}; diff --git a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc index 32bac4e95fc26f..4bf5c5cc796c9d 100644 --- a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc +++ b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc @@ -471,7 +471,7 @@ PJRT_Error* PJRT_Client_BufferFromHostBuffer( PJRT_Buffer_MemoryLayout_Type_Strides: { PJRT_RETURN_IF_ERROR(absl::InvalidArgumentError(absl::StrCat( "PJRT_Buffer_MemoryLayout_Type_Strides in device_layout is not " - "supported in PJRT_Client_BufferFromHostBuffer for platform '%s'", + "supported in PJRT_Client_BufferFromHostBuffer for platform ", args->client->client->platform_name()))); break; } @@ -1347,8 +1347,24 @@ PJRT_Error* PJRT_Buffer_ToHostBuffer(PJRT_Buffer_ToHostBuffer_Args* args) { } else { device_shape = args->src->buffer->on_device_shape(); } - const xla::Shape& host_shape = - xla::ShapeUtil::DeviceShapeToHostShape(device_shape); + xla::Shape host_shape = xla::ShapeUtil::DeviceShapeToHostShape(device_shape); + if (args->host_layout != nullptr) { + if (args->host_layout->type == + PJRT_Buffer_MemoryLayout_Type::PJRT_Buffer_MemoryLayout_Type_Strides) { + PJRT_RETURN_IF_ERROR(absl::InvalidArgumentError( + absl::StrCat("PJRT_Buffer_ToHostBuffer does not support host_layout " + "with strides for platform ", + args->src->buffer->client()->platform_name()))); + } + if (args->host_layout->tiled.num_tiles > 0) { + PJRT_RETURN_IF_ERROR(absl::InvalidArgumentError( + absl::StrCat("PJRT_Buffer_ToHostBuffer does not support host_layout " + "with tiled dimension for platform ", + args->src->buffer->client()->platform_name()))); + } + PJRT_ASSIGN_OR_RETURN(*host_shape.mutable_layout(), + ConvertToLayout(args->host_layout->tiled)); + } size_t host_buffer_size = xla::ShapeUtil::ByteSizeOfElements(host_shape); diff --git a/tensorflow/compiler/xla/pjrt/pjrt_c_api_client.cc b/tensorflow/compiler/xla/pjrt/pjrt_c_api_client.cc index 4bbf4663891f62..44e0accce8f7b1 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_c_api_client.cc +++ b/tensorflow/compiler/xla/pjrt/pjrt_c_api_client.cc @@ -1441,6 +1441,18 @@ PjRtFuture PjRtCApiBuffer::ToLiteral(MutableLiteralBase* literal) { args.dst_size = ShapeUtil::ByteSizeOfElements(shape); args.dst = literal->untyped_data(); + xla::StatusOr c_layout_data; + if (literal->shape().has_layout()) { + c_layout_data = + pjrt::ConvertToBufferMemoryLayoutData(&literal->shape().layout()); + if (!c_layout_data.ok()) { + return PjRtFuture(c_layout_data.status()); + } + args.host_layout = &(c_layout_data->c_layout); + } else { + args.host_layout = nullptr; + } + const PJRT_Api* api = pjrt_c_api(); std::unique_ptr error{ From ee4da7b1fb5a1a62f52103307f258f912eaa450a Mon Sep 17 00:00:00 2001 From: Kuangyuan Chen Date: Tue, 11 Jul 2023 20:59:00 -0700 Subject: [PATCH 167/376] Add an e2e test for a SAX model with streaming using TFRT PiperOrigin-RevId: 547375586 --- .../mlir/tfrt/translate/import_model.cc | 14 +++---- tensorflow/core/tfrt/graph_executor/BUILD | 2 + .../graph_executor/graph_execution_options.h | 6 +++ .../tfrt/graph_executor/graph_executor.cc | 39 +++++++++++++++---- .../core/tfrt/graph_executor/graph_executor.h | 19 +++++++-- .../graph_executor/graph_executor_test.cc | 7 +++- 6 files changed, 68 insertions(+), 19 deletions(-) diff --git a/tensorflow/compiler/mlir/tfrt/translate/import_model.cc b/tensorflow/compiler/mlir/tfrt/translate/import_model.cc index 6045d40d56692a..080439159e628e 100644 --- a/tensorflow/compiler/mlir/tfrt/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tfrt/translate/import_model.cc @@ -166,7 +166,13 @@ Status ConvertTfMlirToRuntimeExecutable( } } - if (options.device_target == TfrtDeviceInfraTarget::kTpurt) { + if (options.backend_compiler != nullptr) { + if (VLOG_IS_ON(1)) { + tensorflow::DumpMlirOpToFile("tf_dialect_before_backend_compile", module); + } + TF_RETURN_IF_ERROR( + options.backend_compiler->CompileTensorflow(model_context, module)); + } else if (options.device_target == TfrtDeviceInfraTarget::kTpurt) { VLOG(1) << "Running MLIR TPU bridge for tpurt"; if (VLOG_IS_ON(1)) { tensorflow::DumpMlirOpToFile("tpu_bct_conversion_before", module); @@ -212,12 +218,6 @@ Status ConvertTfMlirToRuntimeExecutable( TF_RETURN_IF_ERROR(fallback_state->AddFunctionDef(func_def)); } } - } else if (options.backend_compiler != nullptr) { - if (VLOG_IS_ON(1)) { - tensorflow::DumpMlirOpToFile("tf_dialect_before_backend_compile", module); - } - TF_RETURN_IF_ERROR( - options.backend_compiler->CompileTensorflow(model_context, module)); } if (VLOG_IS_ON(1)) { diff --git a/tensorflow/core/tfrt/graph_executor/BUILD b/tensorflow/core/tfrt/graph_executor/BUILD index 1e4e07907211df..5a3a69305074df 100644 --- a/tensorflow/core/tfrt/graph_executor/BUILD +++ b/tensorflow/core/tfrt/graph_executor/BUILD @@ -29,6 +29,7 @@ cc_library( ":config", "//tensorflow/compiler/mlir/tfrt:tfrt_compile_options", "//tensorflow/core:core_cpu", + "//tensorflow/core/framework:tensor", "//tensorflow/core/protobuf:for_core_protos_cc", "//tensorflow/core/tfrt/runtime:work_queue_interface", "//tensorflow/core/tfrt/utils:bridge_graph_analysis", @@ -88,6 +89,7 @@ cc_library( "//tensorflow/core/tfrt/mlrt/interpreter:execute", "//tensorflow/core/tfrt/mlrt/kernel:context", "//tensorflow/core/tfrt/runtime", + "//tensorflow/core/tfrt/runtime:stream", "//tensorflow/core/tfrt/runtime:work_queue_interface", "//tensorflow/core/tfrt/stubs:tfrt_native_lowering_stub", "//tensorflow/core/tfrt/utils", diff --git a/tensorflow/core/tfrt/graph_executor/graph_execution_options.h b/tensorflow/core/tfrt/graph_executor/graph_execution_options.h index bb18cc39d38ad0..09f48b592f2f56 100644 --- a/tensorflow/core/tfrt/graph_executor/graph_execution_options.h +++ b/tensorflow/core/tfrt/graph_executor/graph_execution_options.h @@ -15,11 +15,14 @@ limitations under the License. #ifndef TENSORFLOW_CORE_TFRT_GRAPH_EXECUTOR_GRAPH_EXECUTION_OPTIONS_H_ #define TENSORFLOW_CORE_TFRT_GRAPH_EXECUTOR_GRAPH_EXECUTION_OPTIONS_H_ +#include #include #include +#include #include "absl/types/optional.h" #include "tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h" +#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/tfrt/graph_executor/config.h" @@ -101,6 +104,9 @@ struct GraphExecutionRunOptions { // If true, just-in-time host compilation is disabled, and then if the // specified graph is not compiled, the execution will return an error. bool disable_compilation = false; + + std::function)> + streamed_output_callback; }; // Creates the default `SessionOptions` from a `GraphExecutionOptions`. diff --git a/tensorflow/core/tfrt/graph_executor/graph_executor.cc b/tensorflow/core/tfrt/graph_executor/graph_executor.cc index d409b9af6a1af9..0e7059be656fad 100644 --- a/tensorflow/core/tfrt/graph_executor/graph_executor.cc +++ b/tensorflow/core/tfrt/graph_executor/graph_executor.cc @@ -70,6 +70,7 @@ limitations under the License. #include "tensorflow/core/tfrt/mlrt/interpreter/execute.h" #include "tensorflow/core/tfrt/mlrt/kernel/context.h" #include "tensorflow/core/tfrt/runtime/runtime.h" +#include "tensorflow/core/tfrt/runtime/stream.h" #include "tensorflow/core/tfrt/runtime/work_queue_interface.h" #include "tensorflow/core/tfrt/stubs/tfrt_native_lowering_stub.h" #include "tensorflow/core/tfrt/utils/fallback_tensor.h" @@ -267,7 +268,8 @@ tensorflow::Status GraphExecutionRunOnFunction( tfd::FallbackResourceArray* resource_array, const Runtime& runtime, const FallbackState& fallback_state, tfrt::RequestDeadlineTracker* req_deadline_tracker, - CostRecorder* cost_recorder) { + CostRecorder* cost_recorder, + std::optional stream_callback_id) { TF_ASSIGN_OR_RETURN( auto request_info, CreateRequestInfo(options, run_options, run_options.work_queue, @@ -275,10 +277,10 @@ tensorflow::Status GraphExecutionRunOnFunction( runner_table, resource_array, fallback_state, cost_recorder)); + int64_t request_id = request_info->tfrt_request_context->id(); tensorflow::profiler::TraceMeProducer traceme( // To TraceMeConsumers in RunHandlerThreadPool::WorkerLoop. - [request_id = request_info->tfrt_request_context->id(), signature_name, - &options, symbol_uids] { + [request_id, signature_name, &options, symbol_uids] { return tensorflow::profiler::TraceMeEncode( "TfrtModelRun", {{"_r", 1}, @@ -289,8 +291,7 @@ tensorflow::Status GraphExecutionRunOnFunction( {"tf_symbol_uid", symbol_uids.tf_symbol_uid}, {"tfrt_symbol_uid", symbol_uids.tfrt_symbol_uid}}); }, - tensorflow::profiler::ContextType::kTfrtExecutor, - request_info->tfrt_request_context->id()); + tensorflow::profiler::ContextType::kTfrtExecutor, request_id); // Only configure timer when the deadline is set. if (run_options.deadline.has_value()) { @@ -306,6 +307,23 @@ tensorflow::Status GraphExecutionRunOnFunction( deadline, request_info->tfrt_request_context); } + ScopedStreamCallback scoped_stream_callback; + + if (stream_callback_id.has_value()) { + if (!run_options.streamed_output_callback) { + return absl::InvalidArgumentError( + "streamed_output_callback is not provided for a streaming model."); + } + + auto streamed_output_callback = run_options.streamed_output_callback; + + TF_ASSIGN_OR_RETURN( + scoped_stream_callback, + GetGlobalStreamCallbackRegistry().Register( + options.model_metadata.name(), *stream_callback_id, + StepId(request_id), std::move(streamed_output_callback))); + } + if (loaded_executable) { auto function = loaded_executable->GetFunction(signature_name); if (!function) { @@ -558,7 +576,8 @@ tensorflow::Status GraphExecutor::Run( &executable_context->resource_context, &loaded_client_graph.runner_table(), &loaded_client_graph.resource_array(), runtime(), fallback_state_, - &req_deadline_tracker_, cost_recorder.get())); + &req_deadline_tracker_, cost_recorder.get(), + loaded_client_graph.stream_callback_id())); if (cost_recorder != nullptr) { TF_RETURN_IF_ERROR( @@ -597,6 +616,11 @@ GraphExecutor::ImportAndCompileClientGraph( registry, mlir::MLIRContext::Threading::DISABLED); ASSIGN_OR_RETURN_IN_IMPORT( auto module, ImportClientGraphToMlirModule(client_graph, context.get())); + + TF_ASSIGN_OR_RETURN( + auto stream_callback_id, + CreateStreamCallbackId(options().model_metadata.name(), module.get())); + // TODO(b/278143179): Upload module w/o control flow. SymbolUids symbol_uids; symbol_uids.tf_symbol_uid = MaybeUploadMlirToXsymbol(module.get()); @@ -660,7 +684,8 @@ GraphExecutor::ImportAndCompileClientGraph( return std::make_unique( client_graph.name, std::move(symbol_uids), this, std::move(context), std::move(module_with_op_keys), std::move(module), - std::move(executable_context), options_.enable_online_cost_analysis); + std::move(executable_context), options_.enable_online_cost_analysis, + std::move(stream_callback_id)); } StatusOr> diff --git a/tensorflow/core/tfrt/graph_executor/graph_executor.h b/tensorflow/core/tfrt/graph_executor/graph_executor.h index a50f222d351c77..bde356e9107fd7 100644 --- a/tensorflow/core/tfrt/graph_executor/graph_executor.h +++ b/tensorflow/core/tfrt/graph_executor/graph_executor.h @@ -39,6 +39,7 @@ limitations under the License. #include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h" #include "tensorflow/core/tfrt/mlrt/interpreter/context.h" #include "tensorflow/core/tfrt/runtime/runtime.h" +#include "tensorflow/core/tfrt/runtime/stream.h" #include "tensorflow/core/tfrt/runtime/work_queue_interface.h" #include "tensorflow/core/tfrt/utils/tfrt_graph_execution_state.h" #include "tensorflow/tsl/platform/thread_annotations.h" @@ -93,6 +94,9 @@ StatusOr> CreateRequestInfo( // Note: `resource_context` is per-graph-executor and // `client_graph_resource_context` is per-loaded-client-graph. See the comment // above `GraphExecutor::resource_context_` about the todo to merge these two. +// +// TODO(chky): Refactor this function to take `LoadedClientGraph` instead of +// having a long list of parameters. tensorflow::Status GraphExecutionRunOnFunction( const GraphExecutionOptions& options, const GraphExecutionRunOptions& run_options, @@ -106,7 +110,8 @@ tensorflow::Status GraphExecutionRunOnFunction( tfd::FallbackResourceArray* resource_array, const Runtime& runtime, const FallbackState& fallback_state, tfrt::RequestDeadlineTracker* req_deadline_tracker, - CostRecorder* cost_recorder = nullptr); + CostRecorder* cost_recorder = nullptr, + std::optional stream_callback_id = std::nullopt); // Runs a MLRT function for executing tensorflow graphs. tensorflow::Status RunMlrtFunction( @@ -133,12 +138,14 @@ class GraphExecutor { mlir::OwningOpRef tf_mlir_with_op_keys, mlir::OwningOpRef tfrt_mlir, std::shared_ptr executable_context, - bool enable_online_cost_analysis) + bool enable_online_cost_analysis, + std::optional stream_callback_id) : name_(std::move(name)), symbol_uids_(std::move(symbol_uids)), graph_executor_(graph_executor), mlir_context_(std::move(mlir_context)), - executable_context_(std::move(executable_context)) { + executable_context_(std::move(executable_context)), + stream_callback_id_(std::move(stream_callback_id)) { if (enable_online_cost_analysis) { tf_mlir_with_op_keys_ = std::move(tf_mlir_with_op_keys); tfrt_mlir_ = std::move(tfrt_mlir); @@ -167,6 +174,10 @@ class GraphExecutor { tfd::FallbackResourceArray& resource_array() { return resource_array_; } SyncResourceState& sync_resource_state() { return sync_resource_state_; } + const std::optional& stream_callback_id() const { + return stream_callback_id_; + } + private: std::string name_; SymbolUids symbol_uids_; @@ -187,6 +198,8 @@ class GraphExecutor { TF_GUARDED_BY(executable_context_mu_); mutable absl::once_flag create_cost_recorder_once_; SyncResourceState sync_resource_state_; + + std::optional stream_callback_id_; }; // A subgraph constructed by specifying input/output tensors. diff --git a/tensorflow/core/tfrt/graph_executor/graph_executor_test.cc b/tensorflow/core/tfrt/graph_executor/graph_executor_test.cc index 064c32fe47c78a..34505ce6bcac25 100644 --- a/tensorflow/core/tfrt/graph_executor/graph_executor_test.cc +++ b/tensorflow/core/tfrt/graph_executor/graph_executor_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/tfrt/graph_executor/graph_executor.h" #include +#include #include #include #include @@ -256,14 +257,16 @@ TEST_F(GraphExecutorTest, DoOnlineCostAnalysisExactlyOnce) { /*mlir_context=*/nullptr, /*tf_mlir_with_op_keys=*/{}, /*tfrt_mlir=*/{}, /*executable_context=*/nullptr, - /*enable_online_cost_analysis=*/true); + /*enable_online_cost_analysis=*/true, + /*stream_callback_id=*/std::nullopt); GraphExecutor::LoadedClientGraph loaded_client_graph_1( "name1", /*symbol_uids=*/{}, /*graph_executor=*/nullptr, /*mlir_context=*/nullptr, /*tf_mlir_with_op_keys=*/{}, /*tfrt_mlir=*/{}, /*executable_context=*/nullptr, - /*enable_online_cost_analysis=*/true); + /*enable_online_cost_analysis=*/true, + /*stream_callback_id=*/std::nullopt); // For each `LoadedClientGraph`, `MaybeCreateCostRecorder()` only returns a // cost recorder for once. From 2821ab8d269886b1f400c4fb9495edf801eb241d Mon Sep 17 00:00:00 2001 From: Jieying Luo Date: Tue, 11 Jul 2023 21:23:13 -0700 Subject: [PATCH 168/376] [PJRT C API] Support passing allow_devices as an option in PJRT GPU plugin. PiperOrigin-RevId: 547380253 --- .../xla/pjrt/c/pjrt_c_api_gpu_internal.cc | 31 ++++++++++++------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_gpu_internal.cc b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_gpu_internal.cc index 75f211487c2d71..42f5f8700892b0 100644 --- a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_gpu_internal.cc +++ b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_gpu_internal.cc @@ -15,8 +15,10 @@ limitations under the License. #include #include +#include #include #include +#include #include "tensorflow/compiler/xla/pjrt/c/pjrt_c_api_helpers.h" #include "tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h" @@ -36,11 +38,19 @@ PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) { args->num_options); const auto kExpectedOptionNameAndTypes = absl::flat_hash_map( - {{"node_id", PJRT_NamedValue_Type::PJRT_NamedValue_kInt64}, + {{"visible_devices", + PJRT_NamedValue_Type::PJRT_NamedValue_kInt64List}, + {"node_id", PJRT_NamedValue_Type::PJRT_NamedValue_kInt64}, {"num_nodes", PJRT_NamedValue_Type::PJRT_NamedValue_kInt64}}); PJRT_RETURN_IF_ERROR( ValidateCreateOptions(create_options, kExpectedOptionNameAndTypes)); + std::optional> visible_devices; + if (auto it = create_options.find("visible_devices"); + it != create_options.end()) { + const auto& vec = std::get>(it->second); + visible_devices->insert(vec.begin(), vec.end()); + } int node_id = 0; if (auto it = create_options.find("node_id"); it != create_options.end()) { node_id = std::get(it->second); @@ -53,16 +63,15 @@ PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) { // TODO(b/261916900) initializing allocator_config is important as should be // passed through the args later. xla::GpuAllocatorConfig allocator_config; - PJRT_ASSIGN_OR_RETURN( - std::unique_ptr client, - xla::GetStreamExecutorGpuClient( - /*asynchronous=*/true, allocator_config, node_id, num_nodes, - /*allowed_devices=*/std::nullopt, - /*platform_name=*/std::nullopt, true, - pjrt::ToCppKeyValueGetCallback(args->kv_get_callback, - args->kv_get_user_arg), - pjrt::ToCppKeyValuePutCallback(args->kv_put_callback, - args->kv_put_user_arg))); + PJRT_ASSIGN_OR_RETURN(std::unique_ptr client, + xla::GetStreamExecutorGpuClient( + /*asynchronous=*/true, allocator_config, node_id, + num_nodes, visible_devices, + /*platform_name=*/std::nullopt, true, + pjrt::ToCppKeyValueGetCallback( + args->kv_get_callback, args->kv_get_user_arg), + pjrt::ToCppKeyValuePutCallback( + args->kv_put_callback, args->kv_put_user_arg))); args->client = pjrt::CreateWrapperClient(std::move(client)); return nullptr; } From e07f2547f4cc4d8587e62c14e5072e43b98a9c92 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 11 Jul 2023 21:43:25 -0700 Subject: [PATCH 169/376] Internal Code Change PiperOrigin-RevId: 547383616 --- tensorflow/compiler/mlir/tfrt/BUILD | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tensorflow/compiler/mlir/tfrt/BUILD b/tensorflow/compiler/mlir/tfrt/BUILD index e76b65b6172349..dabc0ece9e6303 100644 --- a/tensorflow/compiler/mlir/tfrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/BUILD @@ -365,13 +365,10 @@ cc_library( srcs = ["transforms/gpu_passes.cc"], hdrs = ["transforms/gpu_passes.h"], deps = [ - "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tfrt/ir:tfrt_gpu_opdefs", "@llvm-project//llvm:Support", - "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", - "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", ], ) @@ -471,7 +468,6 @@ cc_library( deps = [ ":tf_to_tfrt", ":tfrt_compile_options", - "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", "//tensorflow/compiler/mlir/tensorflow:error_util", "//tensorflow/compiler/mlir/tensorflow:import_model", @@ -683,7 +679,6 @@ cc_library( "//tensorflow/compiler/mlir/tfrt/ir:tfrt_fallback_async_opdefs", "//tensorflow/compiler/mlir/tfrt/ir:tfrt_fallback_opdefs", "//tensorflow/compiler/mlir/tfrt/ir:tfrt_fallback_sync_opdefs", - "//tensorflow/compiler/mlir/tfrt/ir:tfrt_gpu_opdefs", "//tensorflow/compiler/mlir/tfrt/ir/mlrt:mlrt_ops", "//tensorflow/compiler/mlir/tfrt/ir/mlrt:tf_mlrt_ops", "//tensorflow/compiler/mlir/tfrt/jit/transforms:tf_jitrt_passes", From 61df417a7413d0224720d91e919af874f0555bff Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 11 Jul 2023 21:44:08 -0700 Subject: [PATCH 170/376] Internal Code Change PiperOrigin-RevId: 547383747 --- tensorflow/core/runtime_fallback/kernel/BUILD | 2 -- 1 file changed, 2 deletions(-) diff --git a/tensorflow/core/runtime_fallback/kernel/BUILD b/tensorflow/core/runtime_fallback/kernel/BUILD index 2aea6feb9c657d..23ca51aa4db1cf 100644 --- a/tensorflow/core/runtime_fallback/kernel/BUILD +++ b/tensorflow/core/runtime_fallback/kernel/BUILD @@ -492,10 +492,8 @@ cc_library( ], deps = [ ":kernel_fallback_compat_request_state", - ":kernel_fallback_tensor", ":kernel_fallback_utils", ":tensor_util", - "//tensorflow/core/runtime_fallback/runtime:kernel_utils", "//tensorflow/core/tfrt/utils:fallback_tensor", "//tensorflow/core/tfrt/utils:gpu_variables_table", "//tensorflow/core/tfrt/utils:tensor_util", From bc3e83d7a2edcf8ef7a7d0d84b14a9b0f4aa6daf Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Tue, 11 Jul 2023 23:51:47 -0700 Subject: [PATCH 171/376] Remove unused include (NFC) This is a leftover from when we removed the handling for the special Softmax custom call. PiperOrigin-RevId: 547402693 --- tensorflow/compiler/xla/service/gpu/BUILD | 1 - tensorflow/compiler/xla/service/gpu/gpu_hlo_cost_analysis.cc | 1 - 2 files changed, 2 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index e473aa260f6aed..c386958cbee47e 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -2954,7 +2954,6 @@ cc_library( deps = [ ":backend_configs_cc", ":cublas_cudnn", - ":ir_emission_utils", "//tensorflow/compiler/xla/hlo/ir:hlo", "//tensorflow/compiler/xla/service:elemental_ir_emitter", "//tensorflow/compiler/xla/service:hlo_cost_analysis", diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_cost_analysis.cc index 211dcc285df083..d9f4927628b43e 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_cost_analysis.cc @@ -23,7 +23,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/gpu/cublas_cudnn.h" -#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" namespace xla { namespace gpu { From 21eb51467976f311af039352628da6f1bfe26617 Mon Sep 17 00:00:00 2001 From: George Necula Date: Wed, 12 Jul 2023 00:22:35 -0700 Subject: [PATCH 172/376] Refactor xla_call_module_loader to share refine_polymorphic_shapes Previously, we had duplicated functionality for the refinement of polymorphic shapes in refine_polymorphic_shapes (used from JAX) and xla_call_module_loader (used by tf.XlaCallModule). We now consolidate and share this functionality in refine_polymorphic_shapes. We move incorporate ValidateStaticShapes into RefinePolymorphicShapes. This is in preparation for augmenting the refine_polymorphic_shapes with shape assertion handling. PiperOrigin-RevId: 547409633 --- tensorflow/compiler/tf2xla/kernels/BUILD | 1 + .../tf2xla/kernels/xla_call_module_loader.cc | 65 ++----------- .../tf2xla/kernels/xla_call_module_op.cc | 1 - tensorflow/compiler/xla/python/BUILD | 4 +- .../xla/python/refine_polymorphic_shapes.cc | 96 ++++++++++++++----- .../xla/python/refine_polymorphic_shapes.h | 15 ++- 6 files changed, 94 insertions(+), 88 deletions(-) diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 6d940eccedc72f..1c60ba5874746b 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -394,6 +394,7 @@ cc_library( "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/mlir_hlo", "//tensorflow/compiler/xla/mlir_hlo:mhlo_passes", + "//tensorflow/compiler/xla/python:refine_polymorphic_shapes", "//tensorflow/compiler/xla/translate/hlo_to_mhlo:hlo_utils", "//tensorflow/compiler/xla/translate/mhlo_to_hlo:mlir_hlo_to_hlo", "//tensorflow/tsl/platform:errors", diff --git a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc index e265184ad2d42d..0b84d5434776bd 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc @@ -54,6 +54,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/passes.h" +#include "tensorflow/compiler/xla/python/refine_polymorphic_shapes.h" #include "tensorflow/compiler/xla/translate/hlo_to_mhlo/hlo_utils.h" #include "tensorflow/compiler/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" #include "tensorflow/tsl/platform/errors.h" @@ -379,7 +380,7 @@ tsl::Status XlaCallModuleLoader::RefineDynamicShapes( auto arg = main_body.getArgument(i); arg.setType(static_array_input_types[i]); // If the argument is used by `func.return`, then we also need to - // update function result types. It's not great that we need this hack, + // update the function result types. It's not great that we need this hack, // but in the future when we have stablehlo.func, stablehlo.return, etc, // this will not be needed. // TODO(burmako): Once https://github.com/openxla/stablehlo/issues/425 is @@ -396,33 +397,10 @@ tsl::Status XlaCallModuleLoader::RefineDynamicShapes( DumpMlirOpToFile("xla_call_module.after_refined_input_types", *module_); } - // Verify the module before running passes on it. - // If the module doesn't pass verification, all sorts of weirdness might - // happen if we run the pass manager. - { - mlir::StatusScopedDiagnosticHandler diag_handler(module_->getContext()); - - if (failed(verify(*module_))) { - return absl::InvalidArgumentError( - absl::StrCat("Module verification failed: ", - diag_handler.ConsumeStatus().ToString())); - } - - mlir::PassManager pm(module_->getContext()); - applyTensorflowAndCLOptions(pm); - pm.addPass(mlir::createCSEPass()); - pm.addPass(mlir::stablehlo::createStablehloRefineShapesPass()); - pm.addNestedPass( - mlir::stablehlo::createStablehloCanonicalizeDynamismPass()); - if (mlir::failed(pm.run(*module_))) { - return absl::InvalidArgumentError( - absl::StrCat("Module shape refinement failed: ", - diag_handler.ConsumeStatus().ToString())); - } + TF_RETURN_IF_ERROR(xla::RefinePolymorphicShapes(*module_)); - if (VLOG_IS_ON(3)) { - DumpMlirOpToFile("xla_call_module.after_shape_refinement", *module_); - } + if (VLOG_IS_ON(3)) { + DumpMlirOpToFile("xla_call_module.after_shape_refinement", *module_); } return tsl::OkStatus(); } @@ -467,7 +445,7 @@ tsl::Status XlaCallModuleLoader::LoadAndPreprocessModule( return absl::InvalidArgumentError("Cannot deserialize computation"); } - VLOG(3) << "Parsed serialized module (version " << version + VLOG(3) << "Parsed serialized module (version = " << version << ", platforms = [" << absl::StrJoin(platforms, ", ") << "], loading_platform = " << loading_platform << ", dim_args_spec = [" << absl::StrJoin(dim_args_spec_, ", ") @@ -559,37 +537,6 @@ tsl::Status XlaCallModuleLoader::ValidateDialect() { return tsl::OkStatus(); } -tsl::Status XlaCallModuleLoader::ValidateStaticShapes() { - mlir::StatusScopedDiagnosticHandler diag_handler(module_->getContext()); - bool moduleHasDynamicShapes = false; - - module_->walk([&](mlir::Operation *op) { - // It's sufficient to only check results because operands either come from - // results or from block arguments which are checked below. - auto hasDynamicShape = [](mlir::Value value) { - auto shaped_type = value.getType().dyn_cast(); - return shaped_type ? !shaped_type.hasStaticShape() : false; - }; - bool opHasDynamicShapes = false; - opHasDynamicShapes |= llvm::any_of(op->getResults(), hasDynamicShape); - for (mlir::Region ®ion : op->getRegions()) { - opHasDynamicShapes |= - llvm::any_of(region.getArguments(), hasDynamicShape); - } - if (opHasDynamicShapes) { - moduleHasDynamicShapes = true; - op->emitOpError() << "has dynamic shapes"; - } - }); - - if (moduleHasDynamicShapes) { - return absl::InvalidArgumentError( - absl::StrCat("Module has dynamic shapes: ", - diag_handler.ConsumeStatus().ToString())); - } - return tsl::OkStatus(); -} - absl::Status XlaCallModuleLoader::LowerModuleToMhlo() { mlir::StatusScopedDiagnosticHandler diag_handler(module_->getContext()); diff --git a/tensorflow/compiler/tf2xla/kernels/xla_call_module_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_call_module_op.cc index 5445c93c7be650..cd1a09de334223 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_call_module_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_call_module_op.cc @@ -217,7 +217,6 @@ class XlaCallModuleOp : public XlaOpKernel { input_shapes.push_back(*std::move(shape)); } OP_REQUIRES_OK(ctx, loader_->RefineDynamicShapes(input_shapes)); - OP_REQUIRES_OK(ctx, loader_->ValidateStaticShapes()); OP_REQUIRES_OK(ctx, loader_->LowerModuleToMhlo()); if (!function_list_.empty()) { OP_REQUIRES_OK(ctx, LowerTfFunctionCalls(ctx)); diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index 9967de1243a745..cd68ae2423b4f5 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -746,9 +746,9 @@ cc_library( srcs = ["refine_polymorphic_shapes.cc"], hdrs = ["refine_polymorphic_shapes.h"], deps = [ - "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla/mlir/utils:error_util", - "@com_google_absl//absl/log", + "//tensorflow/tsl/platform:errors", + "//tensorflow/tsl/platform:logging", "@com_google_absl//absl/status", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", diff --git a/tensorflow/compiler/xla/python/refine_polymorphic_shapes.cc b/tensorflow/compiler/xla/python/refine_polymorphic_shapes.cc index 66b0a459438cdf..91b16e3b6b0e4a 100644 --- a/tensorflow/compiler/xla/python/refine_polymorphic_shapes.cc +++ b/tensorflow/compiler/xla/python/refine_polymorphic_shapes.cc @@ -15,12 +15,12 @@ limitations under the License. #include "tensorflow/compiler/xla/python/refine_polymorphic_shapes.h" -#include "absl/log/log.h" #include "absl/status/status.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Bytecode/BytecodeWriter.h" // from @llvm-project #include "mlir/Dialect/Func/Extensions/AllExtensions.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Verifier.h" // from @llvm-project #include "mlir/Parser/Parser.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project @@ -29,49 +29,101 @@ limitations under the License. #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "stablehlo/transforms/Passes.h" // from @stablehlo #include "tensorflow/compiler/xla/mlir/utils/error_util.h" -#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/tsl/platform/errors.h" namespace xla { -xla::Status RefinePolymorphicShapes(llvm::StringRef module_str, - llvm::raw_ostream &os) { - mlir::MLIRContext context; - if (VLOG_IS_ON(3)) context.disableMultithreading(); - context.loadDialect(); - context.loadDialect(); - context.loadDialect(); +absl::Status RefinePolymorphicShapes(mlir::ModuleOp module) { + mlir::MLIRContext *context = module->getContext(); + if (VLOG_IS_ON(3)) context->disableMultithreading(); - mlir::DialectRegistry registry; - mlir::func::registerAllExtensions(registry); - context.appendDialectRegistry(registry); + // Verify the module before running passes on it. + // If the module doesn't pass verification, all sorts of weirdness might + // happen if we run the pass manager. + mlir::BaseScopedDiagnosticHandler diag_handler(context); - auto module = mlir::parseSourceString( - llvm::StringRef(module_str.data(), module_str.size()), &context); - if (!module || failed(module->verifyInvariants())) { - return absl::InvalidArgumentError("Cannot parse module."); + if (mlir::failed(mlir::verify(module))) { + return absl::InvalidArgumentError( + absl::StrCat("Module verification failed: ", + diag_handler.ConsumeStatus().ToString())); } - mlir::PassManager pm(&context); + mlir::PassManager pm(context); if (VLOG_IS_ON(3)) { auto print_before = [](mlir::Pass *, mlir::Operation *) { return true; }; auto print_after = [](mlir::Pass *, mlir::Operation *) { return true; }; pm.enableIRPrinting(print_before, print_after, /*printModuleScope=*/true, - /*printAfterOnlyOnChange=*/false); + /*printAfterOnlyOnChange=*/true); } + // TODO(necula): we should not need the inliner. pm.addPass(mlir::createInlinerPass()); pm.addPass(mlir::createCSEPass()); pm.addPass(mlir::stablehlo::createStablehloRefineShapesPass()); pm.addNestedPass( mlir::stablehlo::createStablehloCanonicalizeDynamismPass()); - if (!mlir::succeeded(pm.run(*module))) { - return absl::InternalError("Cannot refine shapes."); + if (!mlir::succeeded(pm.run(module))) { + return absl::InvalidArgumentError( + absl::StrCat("Module shape refinement failed: ", + diag_handler.ConsumeStatus().ToString())); } + return ValidateStaticShapes(module); +} - if (failed(mlir::writeBytecodeToFile(*module, os))) { +absl::Status RefinePolymorphicShapes(llvm::StringRef module_str, + llvm::raw_ostream &os) { + mlir::MLIRContext context; + if (VLOG_IS_ON(3)) context.disableMultithreading(); + context.loadDialect(); + context.loadDialect(); + context.loadDialect(); + + mlir::DialectRegistry registry; + mlir::func::registerAllExtensions(registry); + context.appendDialectRegistry(registry); + + mlir::OwningOpRef module = + mlir::parseSourceString( + llvm::StringRef(module_str.data(), module_str.size()), &context); + if (!module) { + return absl::InvalidArgumentError("Cannot parse module."); + } + TF_RETURN_IF_ERROR(RefinePolymorphicShapes(*module)); + if (mlir::failed(mlir::writeBytecodeToFile(*module, os))) { return absl::InternalError("Cannot serialize module."); } - return xla::OkStatus(); + return absl::OkStatus(); +} + +absl::Status ValidateStaticShapes(mlir::ModuleOp module) { + mlir::BaseScopedDiagnosticHandler diag_handler(module->getContext()); + bool moduleHasDynamicShapes = false; + + module->walk([&](mlir::Operation *op) { + // It's sufficient to only check results because operands either come from + // results or from block arguments which are checked below. + auto hasDynamicShape = [](mlir::Value value) { + auto shaped_type = value.getType().dyn_cast(); + return shaped_type ? !shaped_type.hasStaticShape() : false; + }; + bool opHasDynamicShapes = false; + opHasDynamicShapes |= llvm::any_of(op->getResults(), hasDynamicShape); + for (mlir::Region ®ion : op->getRegions()) { + opHasDynamicShapes |= + llvm::any_of(region.getArguments(), hasDynamicShape); + } + if (opHasDynamicShapes) { + moduleHasDynamicShapes = true; + op->emitOpError() << "has dynamic shapes"; + } + }); + + if (moduleHasDynamicShapes) { + return absl::InvalidArgumentError( + absl::StrCat("Module has dynamic shapes: ", + diag_handler.ConsumeStatus().ToString())); + } + return absl::OkStatus(); } } // namespace xla diff --git a/tensorflow/compiler/xla/python/refine_polymorphic_shapes.h b/tensorflow/compiler/xla/python/refine_polymorphic_shapes.h index ac020be1d75977..4f553d5f64c73d 100644 --- a/tensorflow/compiler/xla/python/refine_polymorphic_shapes.h +++ b/tensorflow/compiler/xla/python/refine_polymorphic_shapes.h @@ -16,16 +16,23 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_REFINE_POLYMORPHIC_SHAPES_H_ #define TENSORFLOW_COMPILER_XLA_PYTHON_REFINE_POLYMORPHIC_SHAPES_H_ +#include "absl/status/status.h" #include "llvm/Support/raw_ostream.h" -#include "tensorflow/compiler/xla/status.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project namespace xla { // Refines the dynamic shapes for a module whose "main" has static shapes // and all the intermediate dynamic shapes depend only on the input static -// shapes. Serializes the refined module to the given `os`. -xla::Status RefinePolymorphicShapes(llvm::StringRef module_str, - llvm::raw_ostream &os); +// shapes. +absl::Status RefinePolymorphicShapes(mlir::ModuleOp module); + +// Like the above but with serialized input and output modules. +absl::Status RefinePolymorphicShapes(llvm::StringRef module_str, + llvm::raw_ostream &os); + +// Validates that the module has only static shapes. +absl::Status ValidateStaticShapes(mlir::ModuleOp module); } // namespace xla From 941c1ee647e7b52d94cb22fa384afbc45f07a051 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 12 Jul 2023 00:47:29 -0700 Subject: [PATCH 173/376] Internal Code Change PiperOrigin-RevId: 547416046 --- tensorflow/core/grappler/utils/pattern_utils.cc | 3 ++- tensorflow/core/grappler/utils/scc_test.cc | 5 ++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/grappler/utils/pattern_utils.cc b/tensorflow/core/grappler/utils/pattern_utils.cc index 2d4c0a9b5a1c05..1bf827fcc6f98e 100644 --- a/tensorflow/core/grappler/utils/pattern_utils.cc +++ b/tensorflow/core/grappler/utils/pattern_utils.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/grappler/utils/pattern_utils.h" #include +#include #include "absl/container/flat_hash_set.h" @@ -171,7 +172,7 @@ bool SubGraphMatcher::GetMatchedNodes( MutableNodeView* node_view, std::map* matched_nodes_map, std::set* remove_node_indices) { bool found_match = false; - match_.reset(new NodeViewMatch()); + match_ = std::make_unique(); if (DoesOpTypePatternMatch(pattern, node_view, match_.get())) { if (IsSafeNodesToRemove(nodes_to_preserve)) { found_match = true; diff --git a/tensorflow/core/grappler/utils/scc_test.cc b/tensorflow/core/grappler/utils/scc_test.cc index b5fa76ef8bf4fc..b43fc1c40fdf1a 100644 --- a/tensorflow/core/grappler/utils/scc_test.cc +++ b/tensorflow/core/grappler/utils/scc_test.cc @@ -14,6 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/grappler/utils/scc.h" + +#include + #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/grappler/clusters/virtual_cluster.h" #include "tensorflow/core/grappler/grappler_item.h" @@ -31,7 +34,7 @@ class SCCTest : public ::testing::Test { std::unordered_map devices; DeviceProperties unknown_device; devices["MY_DEVICE"] = unknown_device; - cluster_.reset(new VirtualCluster(devices)); + cluster_ = std::make_unique(devices); TF_CHECK_OK(cluster_->Provision()); } From 94eeda928e422ea5ee2b0e9c2ae74efdc63afda7 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 12 Jul 2023 00:52:23 -0700 Subject: [PATCH 174/376] Integrate LLVM at llvm/llvm-project@5671f023042b Updates LLVM usage to match [5671f023042b](https://github.com/llvm/llvm-project/commit/5671f023042b) PiperOrigin-RevId: 547417183 --- third_party/llvm/generated.patch | 28 +++++++++++++++++----------- third_party/llvm/workspace.bzl | 4 ++-- 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch index a7f99f08514996..5539280dba4e32 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch @@ -1,12 +1,18 @@ Auto generated patch. Do not edit or delete it, even if empty. -diff -ruN --strip-trailing-cr a/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyMCCodeEmitter.cpp b/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyMCCodeEmitter.cpp ---- a/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyMCCodeEmitter.cpp -+++ b/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyMCCodeEmitter.cpp -@@ -127,6 +127,7 @@ - Ctx.reportError( - SMLoc(), - Twine("Wasm globals should only be accessed symbolically!")); -+ break; - default: - encodeULEB128(uint64_t(MO.getImm()), OS); - } +diff -ruN --strip-trailing-cr a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp +--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp ++++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp +@@ -79,10 +79,10 @@ + void EmulateFloatPattern::rewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const { + Location loc = op->getLoc(); ++ TypeConverter *converter = getTypeConverter(); + SmallVector resultTypes; +- assert( +- succeeded(getTypeConverter()->convertTypes(op->getResultTypes(), resultTypes)) && +- "type conversions shouldn't fail in this pass"); ++ LogicalResult pass = converter->convertTypes(op->getResultTypes(), resultTypes); ++ (void) pass; + Operation *expandedOp = + rewriter.create(loc, op->getName().getIdentifier(), operands, resultTypes, + op->getAttrs(), op->getSuccessors(), /*regions=*/{}); diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index f3bd820fe1217d..25289428d729d0 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "be29fe2f987b5bf58d7f6aa77c06e58d9402064a" - LLVM_SHA256 = "96b8dbd215400b2434823ae57a5dd53f84cb2162001a31f2ea65fdfe3c06e9ab" + LLVM_COMMIT = "5671f023042b558d38c3b777ee4ae0ad037b1867" + LLVM_SHA256 = "353607dd4ca5b20e6a2ec6650353dd5de006829e5be502716383624152bb1f0f" tf_http_archive( name = name, From 08ecf8a38f5ea8a0cd993c7e685d19ec15ba0f2a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 12 Jul 2023 01:26:29 -0700 Subject: [PATCH 175/376] Update TFRT dependency to use revision http://github.com/tensorflow/runtime/commit/cf061a8afb57bad642bcc01442c414bf76fc3074. PiperOrigin-RevId: 547424096 --- third_party/tf_runtime/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/tf_runtime/workspace.bzl b/third_party/tf_runtime/workspace.bzl index 45799d33c17189..b9bb8c2fba04f5 100644 --- a/third_party/tf_runtime/workspace.bzl +++ b/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "a9afd3d20d538d81145a841ffdf20faf48dc69f8" - TFRT_SHA256 = "9ceb85b1bc9350c2c0a3f381fce8604173484f969f56d872239bac29a650f060" + TFRT_COMMIT = "cf061a8afb57bad642bcc01442c414bf76fc3074" + TFRT_SHA256 = "d4d05de303b5126d0e648a40e1c74091013d022179a5a0fa3fe2d13f7a73f2de" tf_http_archive( name = "tf_runtime", From aeb69db9d47f4e67eeba025a94181ebf59263b7b Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Wed, 12 Jul 2023 01:43:49 -0700 Subject: [PATCH 176/376] Add gpu_asm_compiler and gpu_asm_opts_util deps behind a guard. The includes and the usages are behind a guard, so the dependency can be behind a guard as well. PiperOrigin-RevId: 547427378 --- tensorflow/compiler/xla/service/gpu/BUILD | 8 +++++--- tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc | 2 +- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index c386958cbee47e..ca2ee0cb3c66b7 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -984,8 +984,8 @@ cc_library( srcs = ["ir_emission_utils.cc"], hdrs = ["ir_emission_utils.h"], compatible_with = get_compatible_with_portable(), + defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), deps = [ - ":gpu_asm_opts_util", ":target_util", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/hlo/ir:hlo", @@ -996,12 +996,14 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:llvm_type_conversion_util", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/compiler/xla/stream_executor", - "//tensorflow/compiler/xla/stream_executor/gpu:asm_compiler", "//tensorflow/compiler/xla/translate/mhlo_to_hlo:type_to_shape", "@com_google_absl//absl/container:flat_hash_set", "@llvm-project//llvm:Core", "@llvm-project//mlir:ArithDialect", - ], + ] + if_cuda_is_configured([ + ":gpu_asm_opts_util", + "//tensorflow/compiler/xla/stream_executor/gpu:asm_compiler", + ]), ) xla_cc_test( diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index b2fb99ac46c0f4..1c63929cf17674 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -32,7 +32,6 @@ limitations under the License. #include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_opcode.h" #include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.h" -#include "tensorflow/compiler/xla/service/gpu/gpu_asm_opts_util.h" #include "tensorflow/compiler/xla/service/gpu/target_util.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_type_conversion_util.h" @@ -42,6 +41,7 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #ifdef GOOGLE_CUDA +#include "tensorflow/compiler/xla/service/gpu/gpu_asm_opts_util.h" #include "tensorflow/compiler/xla/stream_executor/gpu/asm_compiler.h" #endif // GOOGLE_CUDA From 2eec7c50ae23e8693b4e574c441e89a317de2530 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Wed, 12 Jul 2023 01:57:32 -0700 Subject: [PATCH 177/376] Also simplify Bitcast(Broadcast) -> Broadcast if possible. If the bitcast is a bitcast transpose, we can do the same kind of simplification as for Transpose(Broadcast). Reuse the existing code, and add a check for whether the simplification would create a broadcast that does an implicit transpose. In such a case, don't apply the simplification. PiperOrigin-RevId: 547429982 --- .../xla/service/algebraic_simplifier.cc | 141 ++++++++++++------ .../xla/service/algebraic_simplifier.h | 5 + .../xla/service/algebraic_simplifier_test.cc | 36 +++++ 3 files changed, 140 insertions(+), 42 deletions(-) diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 17980a6c84413e..63889515ff4a37 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -1130,8 +1130,26 @@ Status AlgebraicSimplifierVisitor::HandleBitcast(HloInstruction* bitcast) { ReplaceWithNewInstruction(bitcast, std::move(new_bitcast))); bitcast = new_bitcast_ptr; } + // All bitcasts can be eliminated (assuming layout constraints are satisfied). - ReplaceInstructionIfCompatible(bitcast, bitcast->mutable_operand(0)); + HloInstruction* new_bitcast = bitcast->mutable_operand(0); + if (ReplaceInstructionIfCompatible(bitcast, new_bitcast)) { + bitcast = new_bitcast; + } + + // Check whether we can potentially simplify the bitcast into a broadcast + // operand. + if (bitcast->opcode() == HloOpcode::kBitcast && + bitcast->operand(0)->opcode() == HloOpcode::kBroadcast) { + // DeduceTransposeDimensionsForBitcast() checks whether the bitcast is a + // transpose and returns the dimensions attribute if it is. + auto dimensions = ShapeUtil::DeduceTransposeDimensionsForBitcast( + bitcast->operand(0)->shape(), bitcast->shape()); + if (dimensions.has_value()) { + return SimplifyTransposeOfBroadcast(bitcast, dimensions.value()); + } + } + return OkStatus(); } @@ -2232,6 +2250,84 @@ StatusOr AlgebraicSimplifierVisitor::RemoveDegenerateDimensionFromDot( return true; } +// transpose(broadcast(x)) -> broadcast(x), if the transpose leaves the relative +// order of the dimensions of `x` unchanged. +// +// To understand the permutations logic here, consider a simple case. +// +// bcast = f32[1,2,3,4] broadcast(f32[2,4] x), dimensions={1,3} +// trans = f32[2,3,1,4] transpose(f32[1,2,3,4] bcast), dimensions={1,2,0,3} +// +// We want to transform this into +// +// bcast' = f32[2,3,1,4] broadcast(f32[2,4] x), dimensions={0,3} +Status AlgebraicSimplifierVisitor::SimplifyTransposeOfBroadcast( + HloInstruction* transpose, absl::Span dimensions) { + HloInstruction* broadcast = transpose->mutable_operand(0); + if (broadcast->opcode() != HloOpcode::kBroadcast || + !absl::c_is_sorted(broadcast->dimensions())) { + return OkStatus(); + } + + // The algorithm to compute bcast'.dimensions() is: + // + // * Let p' be the inverse of trans.dimensions(); in the example, {2,0,1,3}. + // * bcast'.dimensions() is [p'[dim] for dim in bcast.dimensions()]. In the + // example, p'[1] = 0, meaning that broadcast dim 1 (size 2) ends up at + // index 0 after the transpose. + // + // We also need to check that bcast'.dimensions() is "sorted the same" as + // bcast.dimensions() -- otherwise, we're simply moving the transpose into the + // broadcast op. For now we cowardly refuse to consider broadcasts except + // where their dimensions() are sorted, so we need only check that + // bcast'.dimensions() is sorted. + // + // No one-user requirement on the transpose because having two different + // broadcasts of x should be cheap -- certainly cheaper than using the + // fully-materialized broadcasted+transposed value. + + auto inv_perm = InversePermutation(dimensions); + absl::InlinedVector new_bcast_dims; + for (int64_t dim : broadcast->dimensions()) { + new_bcast_dims.push_back(inv_perm[dim]); + } + if (!absl::c_is_sorted(new_bcast_dims)) { + return OkStatus(); + } + // We don't want to create broadcasts that create implicit transposes. Check + // whether the relative order of the layout of the broadcasted dimensions is + // the same as the broadcast operand layout. + if (options_.is_layout_sensitive()) { + std::vector perm1(new_bcast_dims.size()); + absl::c_iota(perm1, 0); + std::vector perm2 = perm1; + Layout operand_layout = broadcast->mutable_operand(0)->shape().layout(); + absl::c_sort(perm1, [&](int a, int b) { + return operand_layout.minor_to_major(a) < + operand_layout.minor_to_major(b); + }); + Layout transpose_layout = transpose->shape().layout(); + // Extract the part of the layout that corresponds to the broadcasted + // dimensions. + std::vector extracted_layout; + extracted_layout.reserve(new_bcast_dims.size()); + for (int64_t dim : transpose_layout.minor_to_major()) { + if (absl::c_binary_search(new_bcast_dims, dim)) { + extracted_layout.push_back(dim); + } + } + absl::c_sort(perm2, [&](int a, int b) { + return extracted_layout[a] < extracted_layout[b]; + }); + if (perm1 != perm2) { + return OkStatus(); + } + } + return ReplaceInstruction( + transpose, MakeBroadcastHlo(broadcast->mutable_operand(0), new_bcast_dims, + transpose->shape())); +} + StatusOr AlgebraicSimplifierVisitor::RemoveTransposesFromDotOperands( HloInstruction* dot) { const int64_t rank = dot->shape().rank(); @@ -6975,47 +7071,8 @@ Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) { } } - // transpose(broadcast(x)) -> broadcast(x), if the transpose leaves the - // relative order of the dimensions of `x` unchanged. - // - // To understand the permutations logic here, consider a simple case. - // - // bcast = f32[1,2,3,4] broadcast(f32[2,4] x), dimensions={1,3} - // trans = f32[2,3,1,4] transpose(f32[1,2,3,4] bcast), dimensions={1,2,0,3} - // - // We want to transform this into - // - // bcast' = f32[2,3,1,4] broadcast(f32[2,4] x), dimensions={0,3} - // - // The algorithm to compute bcast'.dimensions() is: - // - // * Let p' be the inverse of trans.dimensions(); in the example, {2,0,1,3}. - // * bcast'.dimensions() is [p'[dim] for dim in bcast.dimensions()]. In the - // example, p'[1] = 0, meaning that broadcast dim 1 (size 2) ends up at - // index 0 after the transpose. - // - // We also need to check that bcast'.dimensions() is "sorted the same" as - // bcast.dimensions() -- otherwise, we're simply moving the transpose into the - // broadcast op. For now we cowardly refuse to consider broadcasts except - // where their dimensions() are sorted, so we need only check that - // bcast'.dimensions() is sorted. - // - // No one-user requirement on the transpose because having two different - // broadcasts of x should be cheap -- certainly cheaper than using the - // fully-materialized broadcasted+transposed value. - if (operand->opcode() == HloOpcode::kBroadcast && - absl::c_is_sorted(operand->dimensions())) { - auto inv_perm = InversePermutation(transpose->dimensions()); - absl::InlinedVector new_bcast_dims; - for (int64_t dim : operand->dimensions()) { - new_bcast_dims.push_back(inv_perm[dim]); - } - if (absl::c_is_sorted(new_bcast_dims)) { - return ReplaceInstruction( - transpose, MakeBroadcastHlo(operand->mutable_operand(0), - new_bcast_dims, transpose->shape())); - } - } + TF_RETURN_IF_ERROR( + SimplifyTransposeOfBroadcast(transpose, transpose->dimensions())); return OkStatus(); } diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h index 0d016cf2ffae3a..0d5a13db092a9c 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.h +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h @@ -443,6 +443,11 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor { // Removes degenerate dimension from dot. StatusOr RemoveDegenerateDimensionFromDot(HloInstruction* dot); + // Moves the transpose to the broadcast if possible. Can also be called with a + // bitcast transpose. + Status SimplifyTransposeOfBroadcast(HloInstruction* transpose, + absl::Span dimensions); + // Converts to primitive type if the input hlo is not that type, otherwise // returns the original hlo. HloInstruction* AsType(HloInstruction* hlo, diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index b94c0eb3941502..18eea15b0b1066 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -9451,6 +9451,42 @@ TEST_F(AlgebraicSimplifierTest, TransposeOfBroadcast) { }))); } +TEST_F(AlgebraicSimplifierTest, TransposeBitcastOfBroadcast) { + const char* kModuleStr = R"( + HloModule m + test { + bcast = f32[10,2,3,4]{3,2,1,0} broadcast(f32[2,4]{1,0} parameter(0)), dimensions={1,3} + ROOT trans = f32[2,3,10,4]{3,1,0,2} bitcast(bcast) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + AlgebraicSimplifierOptions options; + options.set_is_layout_sensitive(true); + EXPECT_TRUE(RunHloPass(AlgebraicSimplifier(options), m.get()).value()); + SCOPED_TRACE(m->ToString()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch( + m::Broadcast(m::Parameter(0)) + .WithPredicate([](const HloInstruction* instr) { + return instr->dimensions() == std::vector({0, 3}); + }))); +} + +TEST_F(AlgebraicSimplifierTest, TransposeOfBroadcastWithLayoutCheckSkipped) { + const char* kModuleStr = R"( + HloModule m + test { + bcast = f32[10,2,3,4]{3,2,1,0} broadcast(f32[2,4]{1,0} parameter(0)), dimensions={1,3} + ROOT trans = f32[2,3,10,4]{0,1,2,3} transpose(bcast), dimensions={1,2,0,3} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + AlgebraicSimplifierOptions options; + options.set_is_layout_sensitive(true); + EXPECT_FALSE(RunHloPass(AlgebraicSimplifier(options), m.get()).value()); +} + TEST_F(AlgebraicSimplifierTest, TransposeOfBroadcastSkipped) { const char* kModuleStr = R"( HloModule m From 0ee90b91c96ddf3b5bade97e767138ec45fafe7a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 12 Jul 2023 02:03:31 -0700 Subject: [PATCH 178/376] compat: Update forward compatibility horizon to 2023-07-12 PiperOrigin-RevId: 547431222 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 35e0da462e4704..c2a858c6505dee 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 7, 11) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 7, 12) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From f5ca7efbd6dd60c8fbcf44965310fb3757010f77 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 12 Jul 2023 02:03:31 -0700 Subject: [PATCH 179/376] Update GraphDef version to 1555. PiperOrigin-RevId: 547431223 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 28e0085b3d3abf..4beda718dbecb6 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1554 // Updated: 2023/7/11 +#define TF_GRAPH_DEF_VERSION 1555 // Updated: 2023/7/12 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From 121ae3648eacd5defb02b8a41b623f63ccffd252 Mon Sep 17 00:00:00 2001 From: Shiqing Yan Date: Wed, 12 Jul 2023 02:53:28 -0700 Subject: [PATCH 180/376] [DelegatePerformance] Move the MiniBenchmark C APIs to shims. PiperOrigin-RevId: 547440488 --- .../acceleration/mini_benchmark/c/BUILD | 68 ++++++++++++++++ .../acceleration/mini_benchmark/c/c_api.cc | 6 +- .../acceleration/mini_benchmark/c/c_api.h | 81 +++++++++++++++++++ .../mini_benchmark/c/c_api_test.cc | 4 +- .../mini_benchmark/c/c_api_types.h | 8 +- .../acceleration/mini_benchmark/BUILD | 3 + .../mini_benchmark/build_defs.bzl | 50 +++++++++++- .../acceleration/mini_benchmark/c/BUILD | 52 ++++-------- .../acceleration/mini_benchmark/c/c_api.h | 60 +------------- .../mini_benchmark/special_rules.bzl | 1 + tensorflow/opensource_only.files | 1 + 11 files changed, 225 insertions(+), 109 deletions(-) create mode 100644 tensorflow/lite/core/experimental/acceleration/mini_benchmark/c/BUILD rename tensorflow/lite/{ => core}/experimental/acceleration/mini_benchmark/c/c_api.cc (97%) create mode 100644 tensorflow/lite/core/experimental/acceleration/mini_benchmark/c/c_api.h rename tensorflow/lite/{ => core}/experimental/acceleration/mini_benchmark/c/c_api_test.cc (99%) rename tensorflow/lite/{ => core}/experimental/acceleration/mini_benchmark/c/c_api_types.h (91%) diff --git a/tensorflow/lite/core/experimental/acceleration/mini_benchmark/c/BUILD b/tensorflow/lite/core/experimental/acceleration/mini_benchmark/c/BUILD new file mode 100644 index 00000000000000..2f9af85ed4d3cc --- /dev/null +++ b/tensorflow/lite/core/experimental/acceleration/mini_benchmark/c/BUILD @@ -0,0 +1,68 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +load("//tensorflow/lite/experimental/acceleration/mini_benchmark:build_defs.bzl", "cc_library_with_forced_in_process_benchmark_variant") +load("//tensorflow/lite/experimental/acceleration/mini_benchmark:special_rules.bzl", "libjpeg_handle_deps", "minibenchmark_visibility_allowlist") + +default_visibility_group = [ + "//tensorflow/lite/experimental/acceleration/mini_benchmark:__subpackages__", +] + minibenchmark_visibility_allowlist() + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = default_visibility_group, + licenses = ["notice"], +) + +cc_library_with_forced_in_process_benchmark_variant( + name = "c_api", + srcs = ["c_api.cc"], + hdrs = ["c_api.h"], + in_process_deps = [ + "//tensorflow/lite/experimental/acceleration/mini_benchmark:blocking_validator_runner", + ], + deps = [ + ":c_api_types", + "//tensorflow/lite/acceleration/configuration:configuration_fbs", + "//tensorflow/lite/acceleration/configuration/c:delegate_plugin", + "//tensorflow/lite/experimental/acceleration/mini_benchmark:benchmark_result_evaluator", + "//tensorflow/lite/experimental/acceleration/mini_benchmark:status_codes", + "//tensorflow/lite/experimental/acceleration/mini_benchmark:validator_runner_entrypoint", + "//tensorflow/lite/experimental/acceleration/mini_benchmark:validator_runner_options", + "@flatbuffers", + ], +) + +cc_test( + name = "c_api_test", + srcs = ["c_api_test.cc"], + deps = [ + ":c_api", + "//tensorflow/lite/acceleration/configuration:configuration_fbs", + "//tensorflow/lite/experimental/acceleration/mini_benchmark:embedded_mobilenet_model", + "//tensorflow/lite/experimental/acceleration/mini_benchmark:embedded_mobilenet_validation_model", + "//tensorflow/lite/experimental/acceleration/mini_benchmark:embedded_simple_addition_model", + "//tensorflow/lite/experimental/acceleration/mini_benchmark:mini_benchmark_test_helper", + "//tensorflow/lite/experimental/acceleration/mini_benchmark:status_codes", + "@com_google_googletest//:gtest_main", + "@flatbuffers", + "@flatbuffers//:runtime_cc", + ] + libjpeg_handle_deps(), +) + +cc_library( + name = "c_api_types", + hdrs = ["c_api_types.h"], + visibility = ["//visibility:private"], +) diff --git a/tensorflow/lite/experimental/acceleration/mini_benchmark/c/c_api.cc b/tensorflow/lite/core/experimental/acceleration/mini_benchmark/c/c_api.cc similarity index 97% rename from tensorflow/lite/experimental/acceleration/mini_benchmark/c/c_api.cc rename to tensorflow/lite/core/experimental/acceleration/mini_benchmark/c/c_api.cc index 95c32f45bd6ff5..51c927836135dc 100644 --- a/tensorflow/lite/experimental/acceleration/mini_benchmark/c/c_api.cc +++ b/tensorflow/lite/core/experimental/acceleration/mini_benchmark/c/c_api.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/experimental/acceleration/mini_benchmark/c/c_api.h" +#include "tensorflow/lite/core/experimental/acceleration/mini_benchmark/c/c_api.h" #include #include @@ -27,7 +27,7 @@ limitations under the License. #include "tensorflow/lite/acceleration/configuration/configuration_generated.h" #include "tensorflow/lite/experimental/acceleration/mini_benchmark/benchmark_result_evaluator.h" #include "tensorflow/lite/experimental/acceleration/mini_benchmark/blocking_validator_runner.h" -#include "tensorflow/lite/experimental/acceleration/mini_benchmark/c/c_api_types.h" +#include "tensorflow/lite/core/experimental/acceleration/mini_benchmark/c/c_api_types.h" #include "tensorflow/lite/experimental/acceleration/mini_benchmark/status_codes.h" #include "tensorflow/lite/experimental/acceleration/mini_benchmark/validator_runner_options.h" diff --git a/tensorflow/lite/core/experimental/acceleration/mini_benchmark/c/c_api.h b/tensorflow/lite/core/experimental/acceleration/mini_benchmark/c/c_api.h new file mode 100644 index 00000000000000..ed5d17d62beba6 --- /dev/null +++ b/tensorflow/lite/core/experimental/acceleration/mini_benchmark/c/c_api.h @@ -0,0 +1,81 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_CORE_EXPERIMENTAL_ACCELERATION_MINI_BENCHMARK_C_C_API_H_ +#define TENSORFLOW_LITE_CORE_EXPERIMENTAL_ACCELERATION_MINI_BENCHMARK_C_C_API_H_ + +#include +#include +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +// APIs of TfLiteMiniBenchmarkResult. +typedef struct TfLiteMiniBenchmarkResult TfLiteMiniBenchmarkResult; +int TfLiteMiniBenchmarkResultInitStatus(TfLiteMiniBenchmarkResult* result); +uint8_t* TfLiteMiniBenchmarkResultFlatBufferData( + TfLiteMiniBenchmarkResult* result); +size_t TfLiteMiniBenchmarkResultFlatBufferDataSize( + TfLiteMiniBenchmarkResult* result); +// Free memory allocated with `result`. +void TfLiteMiniBenchmarkResultFree(TfLiteMiniBenchmarkResult* result); + +// APIs of TfLiteMiniBenchmarkCustomValidationInfo. +typedef struct TfLiteMiniBenchmarkCustomValidationInfo + TfLiteMiniBenchmarkCustomValidationInfo; +void TfLiteMiniBenchmarkCustomValidationInfoSetBuffer( + TfLiteMiniBenchmarkCustomValidationInfo* custom_validation, int batch_size, + uint8_t* buffer, size_t* buffer_dim, int buffer_dim_size); +void TfLiteMiniBenchmarkCustomValidationInfoSetAccuracyValidator( + TfLiteMiniBenchmarkCustomValidationInfo* custom_validation, + void* accuracy_validator_user_data, + bool (*accuracy_validator_func)(void* user_data, + uint8_t* benchmark_result_data, + int benchmark_result_data_size)); + +// APIs of TfLiteMiniBenchmarkSettings. +typedef struct TfLiteMiniBenchmarkSettings TfLiteMiniBenchmarkSettings; +TfLiteMiniBenchmarkSettings* TfLiteMiniBenchmarkSettingsCreate(); +TfLiteMiniBenchmarkCustomValidationInfo* +TfLiteMiniBenchmarkSettingsCustomValidationInfo( + TfLiteMiniBenchmarkSettings* settings); +void TfLiteMiniBenchmarkSettingsSetFlatBufferData( + TfLiteMiniBenchmarkSettings* settings, uint8_t* flatbuffer_data, + size_t flatbuffer_data_size); +void TfLiteMiniBenchmarkSettingsSetErrorReporter( + TfLiteMiniBenchmarkSettings* settings, void* error_reporter_user_data, + int (*error_reporter_func)(void* user_data, const char* format, + va_list args)); +void TfLiteMiniBenchmarkSettingsFree(TfLiteMiniBenchmarkSettings* settings); + +// Others. +// Trigger validation for `settings` and return the validation result. +// This returns a pointer, that you must free using +// TfLiteMiniBenchmarkResultFree(). +TfLiteMiniBenchmarkResult* TfLiteBlockingValidatorRunnerTriggerValidation( + TfLiteMiniBenchmarkSettings* settings); + +// This function is a private function that shouldn't be considered as part of +// the APIs. +// TODO: b/290615172 - Remove the function from this header. +void TfLiteMiniBenchmarkSettingsSetGpuPluginHandle( + TfLiteMiniBenchmarkSettings* settings, void* gpu_plugin_handle); + +#ifdef __cplusplus +} // extern "C". +#endif +#endif // TENSORFLOW_LITE_CORE_EXPERIMENTAL_ACCELERATION_MINI_BENCHMARK_C_C_API_H_ diff --git a/tensorflow/lite/experimental/acceleration/mini_benchmark/c/c_api_test.cc b/tensorflow/lite/core/experimental/acceleration/mini_benchmark/c/c_api_test.cc similarity index 99% rename from tensorflow/lite/experimental/acceleration/mini_benchmark/c/c_api_test.cc rename to tensorflow/lite/core/experimental/acceleration/mini_benchmark/c/c_api_test.cc index 7e251af1d168b2..14ed33a048cdb4 100644 --- a/tensorflow/lite/experimental/acceleration/mini_benchmark/c/c_api_test.cc +++ b/tensorflow/lite/core/experimental/acceleration/mini_benchmark/c/c_api_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/experimental/acceleration/mini_benchmark/c/c_api.h" +#include "tensorflow/lite/core/experimental/acceleration/mini_benchmark/c/c_api.h" #include diff --git a/tensorflow/lite/experimental/acceleration/mini_benchmark/c/c_api_types.h b/tensorflow/lite/core/experimental/acceleration/mini_benchmark/c/c_api_types.h similarity index 91% rename from tensorflow/lite/experimental/acceleration/mini_benchmark/c/c_api_types.h rename to tensorflow/lite/core/experimental/acceleration/mini_benchmark/c/c_api_types.h index bdcef41b9752cf..adace96a21ff94 100644 --- a/tensorflow/lite/experimental/acceleration/mini_benchmark/c/c_api_types.h +++ b/tensorflow/lite/core/experimental/acceleration/mini_benchmark/c/c_api_types.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_MINI_BENCHMARK_C_C_API_TYPES_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_MINI_BENCHMARK_C_C_API_TYPES_H_ +#ifndef TENSORFLOW_LITE_CORE_EXPERIMENTAL_ACCELERATION_MINI_BENCHMARK_C_C_API_TYPES_H_ +#define TENSORFLOW_LITE_CORE_EXPERIMENTAL_ACCELERATION_MINI_BENCHMARK_C_C_API_TYPES_H_ #include #include @@ -87,4 +87,4 @@ struct TfLiteMiniBenchmarkSettings { } // extern "C". #endif -#endif // TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_MINI_BENCHMARK_C_C_API_TYPES_H_ +#endif // TENSORFLOW_LITE_CORE_EXPERIMENTAL_ACCELERATION_MINI_BENCHMARK_C_C_API_TYPES_H_ diff --git a/tensorflow/lite/experimental/acceleration/mini_benchmark/BUILD b/tensorflow/lite/experimental/acceleration/mini_benchmark/BUILD index 3dcaefcb2f4846..1a6282228fd76e 100644 --- a/tensorflow/lite/experimental/acceleration/mini_benchmark/BUILD +++ b/tensorflow/lite/experimental/acceleration/mini_benchmark/BUILD @@ -150,6 +150,9 @@ cc_library( ":libjpeg_handle_hdr", "//tensorflow/lite/core/c:c_api_types", ] + libjpeg_deps(), + # Some targets only have an implicit dependency on LibjpegHandle. + # This avoids warnings about backwards references when linking. + alwayslink = True, ) cc_library( diff --git a/tensorflow/lite/experimental/acceleration/mini_benchmark/build_defs.bzl b/tensorflow/lite/experimental/acceleration/mini_benchmark/build_defs.bzl index 446bbef45e1b2d..1c2a1a3f8cb561 100644 --- a/tensorflow/lite/experimental/acceleration/mini_benchmark/build_defs.bzl +++ b/tensorflow/lite/experimental/acceleration/mini_benchmark/build_defs.bzl @@ -21,6 +21,16 @@ load( load("//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "add_suffix") load("//tensorflow/lite/experimental/acceleration/mini_benchmark:special_rules.bzl", "libjpeg_handle_deps") +def _concat(lists): + """Concatenate a list of lists, without requiring the inner lists to be iterable. + + This allows the inner lists to be obtained by calls to select(). + """ + result = [] + for selected_list in lists: + result = result + selected_list + return result + def embedded_binary(name, binary, array_variable_name, testonly = False, exec_properties = None): """Create a cc_library that embeds a binary as constant data. @@ -180,28 +190,60 @@ def validation_test(name, validation_model, tags = [], copts = [], deps = []): def cc_library_with_forced_in_process_benchmark_variant( name, deps = [], + forced_in_process_deps = [], in_process_deps = [], + non_in_process_deps_selects = [], **kwargs): """Defines a cc_library that optionally forces benchmark runs in process. This generates two cc_library target. The first one runs the benchmark in a separate process on Android, while it runs the benchmark in process on all - other platforms. The second one, which has "_in_process" appended to the - name, forces benchmark runs in process. + other platforms. It doesn't have TFLITE_ACCELERATION_BENCHMARK_IN_PROCESS + defined. + The second one, which has "_in_process" appended to the name, forces + benchmark runs in process on all platforms. It has + TFLITE_ACCELERATION_BENCHMARK_IN_PROCESS defined. + + The default option for MiniBenchmark is to run the benchmark in a separate + process on Android, as this is safer than running the benchmark in the app + process. However, forcing the benchmark to run in-process on Android allows + the benchmark to reuse the same TF Lite runtime that is initialized in the + application process. These two variants may use different dependencies. + For example, the in-process variant uses the statically linked libjpeg + handle, while the other variant uses the dynamically linked libjpeg handle + on Android to minimize binary size. + + This build rule ensures that the dependencies listed in + "forced_in_process_deps" are added only when + TFLITE_ACCELERATION_BENCHMARK_IN_PROCESS is defined, that the dependencies + listed in "non_in_process_deps_selects" are added only when + TFLITE_ACCELERATION_BENCHMARK_IN_PROCESS is NOT defined, and that + TFLITE_ACCELERATION_BENCHMARK_IN_PROCESS is defined automatically when + using the "_in_process" target. + Args: name: determines the name used for the generated cc_library targets. + forced_in_process_deps: dependencies that will be enabled only when the + benchmark is forced to run in-process on all platforms. This should be + used for dependencies arising from code inside + '#ifdef TFLITE_ACCELERATION_BENCHMARK_IN_PROCESS'. deps: dependencies that will be unconditionally included in the deps of the generated cc_library targets. in_process_deps: dependencies on rules that are themselves defined using 'cc_library_with_forced_in_process_benchmark_variant'. Must be iterable, so cannot be computed by calling 'select'. + non_in_process_deps_selects: A list of dictionaries that will be + converted to dependencies with select on rules. The dependencies will + be enabled only when the benchmark runs in a separate process on + Android. This should be used for dependencies arising from code inside + '#ifndef TFLITE_ACCELERATION_BENCHMARK_IN_PROCESS'. **kwargs: Additional cc_library parameters. """ native.cc_library( name = name, - deps = deps + in_process_deps + [ + deps = deps + in_process_deps + _concat([select(map) for map in non_in_process_deps_selects]) + [ clean_dep("//tensorflow/lite/experimental/acceleration/mini_benchmark:tflite_acceleration_in_process_default"), ], **kwargs @@ -210,7 +252,7 @@ def cc_library_with_forced_in_process_benchmark_variant( in_process_deps_renamed = [add_suffix(in_process_dep, "_in_process") for in_process_dep in in_process_deps] native.cc_library( name = name + "_in_process", - deps = deps + in_process_deps_renamed + [ + deps = deps + in_process_deps_renamed + forced_in_process_deps + [ clean_dep("//tensorflow/lite/experimental/acceleration/mini_benchmark:tflite_acceleration_in_process_enable"), ], **kwargs diff --git a/tensorflow/lite/experimental/acceleration/mini_benchmark/c/BUILD b/tensorflow/lite/experimental/acceleration/mini_benchmark/c/BUILD index b6e36c51f11a96..e55041460e02b5 100644 --- a/tensorflow/lite/experimental/acceleration/mini_benchmark/c/BUILD +++ b/tensorflow/lite/experimental/acceleration/mini_benchmark/c/BUILD @@ -1,4 +1,4 @@ -# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -load("//tensorflow/lite/experimental/acceleration/mini_benchmark:build_defs.bzl", "cc_library_with_forced_in_process_benchmark_variant") -load("//tensorflow/lite/experimental/acceleration/mini_benchmark:special_rules.bzl", "libjpeg_handle_deps", "minibenchmark_visibility_allowlist") +load("//tensorflow/lite:build_def.bzl", "tflite_cc_library_with_c_headers_test") +load("//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite_with_c_headers_test") +load("//tensorflow/lite/experimental/acceleration/mini_benchmark:special_rules.bzl", "minibenchmark_visibility_allowlist") default_visibility_group = [ "//tensorflow/lite/experimental/acceleration/mini_benchmark:__subpackages__", @@ -25,44 +26,19 @@ package( licenses = ["notice"], ) -cc_library_with_forced_in_process_benchmark_variant( +# This target runs MiniBenchmark in a separate processon Android, while it runs MiniBenchmark +# in-process on all other platforms. +cc_library_with_tflite_with_c_headers_test( name = "c_api", - srcs = ["c_api.cc"], hdrs = ["c_api.h"], - in_process_deps = [ - "//tensorflow/lite/experimental/acceleration/mini_benchmark:blocking_validator_runner", - ], - deps = [ - ":c_api_types", - "//tensorflow/lite/acceleration/configuration:configuration_fbs", - "//tensorflow/lite/acceleration/configuration/c:delegate_plugin", - "//tensorflow/lite/experimental/acceleration/mini_benchmark:benchmark_result_evaluator", - "//tensorflow/lite/experimental/acceleration/mini_benchmark:status_codes", - "//tensorflow/lite/experimental/acceleration/mini_benchmark:validator_runner_entrypoint", - "//tensorflow/lite/experimental/acceleration/mini_benchmark:validator_runner_options", - "@flatbuffers", - ], + deps = ["//tensorflow/lite/core/experimental/acceleration/mini_benchmark/c:c_api"], ) -cc_test( - name = "c_api_test", - srcs = ["c_api_test.cc"], +# This target forces MiniBenchmark to run in-process on all platforms including Android. +tflite_cc_library_with_c_headers_test( + name = "c_api_in_process", + hdrs = ["c_api.h"], deps = [ - ":c_api", - "//tensorflow/lite/acceleration/configuration:configuration_fbs", - "//tensorflow/lite/experimental/acceleration/mini_benchmark:embedded_mobilenet_model", - "//tensorflow/lite/experimental/acceleration/mini_benchmark:embedded_mobilenet_validation_model", - "//tensorflow/lite/experimental/acceleration/mini_benchmark:embedded_simple_addition_model", - "//tensorflow/lite/experimental/acceleration/mini_benchmark:mini_benchmark_test_helper", - "//tensorflow/lite/experimental/acceleration/mini_benchmark:status_codes", - "@com_google_googletest//:gtest_main", - "@flatbuffers", - "@flatbuffers//:runtime_cc", - ] + libjpeg_handle_deps(), -) - -cc_library( - name = "c_api_types", - hdrs = ["c_api_types.h"], - visibility = ["//visibility:private"], + "//tensorflow/lite/core/experimental/acceleration/mini_benchmark/c:c_api_in_process", + ], ) diff --git a/tensorflow/lite/experimental/acceleration/mini_benchmark/c/c_api.h b/tensorflow/lite/experimental/acceleration/mini_benchmark/c/c_api.h index 2d68200d457461..e62b599d7e5294 100644 --- a/tensorflow/lite/experimental/acceleration/mini_benchmark/c/c_api.h +++ b/tensorflow/lite/experimental/acceleration/mini_benchmark/c/c_api.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,62 +15,6 @@ limitations under the License. #ifndef TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_MINI_BENCHMARK_C_C_API_H_ #define TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_MINI_BENCHMARK_C_C_API_H_ -#include -#include -#include +#include "tensorflow/lite/core/experimental/acceleration/mini_benchmark/c/c_api.h" // IWYU pragma: export -#ifdef __cplusplus -extern "C" { -#endif - -// APIs of TfLiteMiniBenchmarkResult. -typedef struct TfLiteMiniBenchmarkResult TfLiteMiniBenchmarkResult; -int TfLiteMiniBenchmarkResultInitStatus(TfLiteMiniBenchmarkResult* result); -uint8_t* TfLiteMiniBenchmarkResultFlatBufferData( - TfLiteMiniBenchmarkResult* result); -size_t TfLiteMiniBenchmarkResultFlatBufferDataSize( - TfLiteMiniBenchmarkResult* result); -// Free memory allocated with `result`. -void TfLiteMiniBenchmarkResultFree(TfLiteMiniBenchmarkResult* result); - -// APIs of TfLiteMiniBenchmarkCustomValidationInfo. -typedef struct TfLiteMiniBenchmarkCustomValidationInfo - TfLiteMiniBenchmarkCustomValidationInfo; -void TfLiteMiniBenchmarkCustomValidationInfoSetBuffer( - TfLiteMiniBenchmarkCustomValidationInfo* custom_validation, int batch_size, - uint8_t* buffer, size_t* buffer_dim, int buffer_dim_size); -void TfLiteMiniBenchmarkCustomValidationInfoSetAccuracyValidator( - TfLiteMiniBenchmarkCustomValidationInfo* custom_validation, - void* accuracy_validator_user_data, - bool (*accuracy_validator_func)(void* user_data, - uint8_t* benchmark_result_data, - int benchmark_result_data_size)); - -// APIs of TfLiteMiniBenchmarkSettings. -typedef struct TfLiteMiniBenchmarkSettings TfLiteMiniBenchmarkSettings; -TfLiteMiniBenchmarkSettings* TfLiteMiniBenchmarkSettingsCreate(); -TfLiteMiniBenchmarkCustomValidationInfo* -TfLiteMiniBenchmarkSettingsCustomValidationInfo( - TfLiteMiniBenchmarkSettings* settings); -void TfLiteMiniBenchmarkSettingsSetFlatBufferData( - TfLiteMiniBenchmarkSettings* settings, uint8_t* flatbuffer_data, - size_t flatbuffer_data_size); -void TfLiteMiniBenchmarkSettingsSetErrorReporter( - TfLiteMiniBenchmarkSettings* settings, void* error_reporter_user_data, - int (*error_reporter_func)(void* user_data, const char* format, - va_list args)); -void TfLiteMiniBenchmarkSettingsSetGpuPluginHandle( - TfLiteMiniBenchmarkSettings* settings, void* gpu_plugin_handle); -void TfLiteMiniBenchmarkSettingsFree(TfLiteMiniBenchmarkSettings* settings); - -// Others. -// Trigger validation for `settings` and return the validation result. -// This returns a pointer, that you must free using -// TfLiteMiniBenchmarkResultFree(). -TfLiteMiniBenchmarkResult* TfLiteBlockingValidatorRunnerTriggerValidation( - TfLiteMiniBenchmarkSettings* settings); - -#ifdef __cplusplus -} // extern "C". -#endif #endif // TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_MINI_BENCHMARK_C_C_API_H_ diff --git a/tensorflow/lite/experimental/acceleration/mini_benchmark/special_rules.bzl b/tensorflow/lite/experimental/acceleration/mini_benchmark/special_rules.bzl index 522f3d95bec451..aa16873972d0a1 100644 --- a/tensorflow/lite/experimental/acceleration/mini_benchmark/special_rules.bzl +++ b/tensorflow/lite/experimental/acceleration/mini_benchmark/special_rules.bzl @@ -31,6 +31,7 @@ def libjpeg_handle_deps(): def minibenchmark_visibility_allowlist(): """Returns a list of packages that can depend on mini_benchmark.""" return [ + "//tensorflow/lite/core/experimental/acceleration/mini_benchmark/c:__subpackages__", "//tensorflow/lite/tools/benchmark/experimental/delegate_performance:__subpackages__", ] diff --git a/tensorflow/opensource_only.files b/tensorflow/opensource_only.files index aed2a433199a7e..9ccc484ca3a247 100644 --- a/tensorflow/opensource_only.files +++ b/tensorflow/opensource_only.files @@ -99,6 +99,7 @@ tensorflow/lite/delegates/hexagon/hexagon_nn/BUILD: tensorflow/lite/delegates/utils/experimental/sample_stable_delegate/BUILD: tensorflow/lite/delegates/utils/experimental/sample_stable_delegate/sample_stable_delegate_external.cc: tensorflow/lite/experimental/acceleration/configuration/configuration_generated.h: +tensorflow/lite/experimental/acceleration/mini_benchmark/c/c_api.h: tensorflow/lite/experimental/acceleration/mini_benchmark/libjpeg.h: tensorflow/lite/experimental/acceleration/mini_benchmark/special_rules.bzl: tensorflow/lite/interpreter.h: From 1e8ed6db907a2deaf4099462bd2b8b9ebc9b6c1c Mon Sep 17 00:00:00 2001 From: Doyeon Kim Date: Wed, 12 Jul 2023 03:54:12 -0700 Subject: [PATCH 181/376] Replace quantfork.stats to quantfork.q/dq instead of stablehlo.q/dq PiperOrigin-RevId: 547450971 --- .../compiler/mlir/quantization/stablehlo/passes/passes.td | 2 +- .../quantization/stablehlo/passes/prepare_srq_quantize.cc | 4 ++-- .../stablehlo/tests/prepare_srq_quantize.mlir | 8 ++++---- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td b/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td index cece788381b0b6..896aa1ba833395 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td @@ -24,6 +24,6 @@ def QuantizeWeightPass : Pass<"stablehlo-quantize-weight", "mlir::func::FuncOp"> def PrepareSrqQuantizePass : Pass<"stablehlo-prepare-srq-quantize", "mlir::func::FuncOp"> { let summary = "Prepare StableHLO dialect for static range quantization."; let constructor = "CreatePrepareSrqQuantizePass()"; - let dependentDialects = ["stablehlo::StablehloDialect", "quant::QuantizationDialect"]; + let dependentDialects = ["stablehlo::StablehloDialect", "quant::QuantizationDialect", "quantfork::QuantizationForkDialect"]; } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/prepare_srq_quantize.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/prepare_srq_quantize.cc index 78e056efef8d2b..12ccddcce58ba7 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/prepare_srq_quantize.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/prepare_srq_quantize.cc @@ -71,8 +71,8 @@ class PrepareSrqQuantizePass }; using ReplaceStatsWithQDQs = - quant::ConvertStatsToQDQs; + quant::ConvertStatsToQDQs; void PrepareSrqQuantizePass::runOnOperation() { func::FuncOp func = getOperation(); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/prepare_srq_quantize.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/prepare_srq_quantize.mlir index 67e57d23c200b4..4c7d909094fe27 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/prepare_srq_quantize.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/prepare_srq_quantize.mlir @@ -9,13 +9,13 @@ func.func @main(%arg0: tensor) -> tensor { } // CHECK: %[[cst:.*]] = stablehlo.constant -// CHECK: %[[q1:.*]] = stablehlo.uniform_quantize %arg0 +// CHECK: %[[q1:.*]] = "quantfork.qcast"(%arg0) // CHECK-SAME: quant.uniform -// CHECK: %[[dq1:.*]] = stablehlo.uniform_dequantize %[[q1]] +// CHECK: %[[dq1:.*]] = "quantfork.dcast"(%[[q1]]) // CHECK-SAME: quant.uniform // CHECK: %[[dot:.*]] = stablehlo.dot %[[dq1]], %[[cst]] -// CHECK: %[[q2:.*]] = stablehlo.uniform_quantize %[[dot]] +// CHECK: %[[q2:.*]] = "quantfork.qcast"(%[[dot]]) // CHECK-SAME: quant.uniform> -// CHECK: %[[dq2:.*]] = stablehlo.uniform_dequantize %[[q2]] +// CHECK: %[[dq2:.*]] = "quantfork.dcast"(%[[q2]]) // CHECK-SAME: quant.uniform> // CHECK: return %[[dq2]] From 43e7f2664d852172002afec82b191412ebc932c5 Mon Sep 17 00:00:00 2001 From: Shiqing Yan Date: Wed, 12 Jul 2023 03:54:24 -0700 Subject: [PATCH 182/376] [DelegatePerformance] Updated the accuracy benchmark flow. PiperOrigin-RevId: 547451008 --- .../core/shims/cc_library_with_tflite.bzl | 30 ++++++- .../delegate_performance/android/BUILD | 37 +++++---- .../BenchmarkAccuracy.java | 30 +++++++ .../BenchmarkAccuracyActivity.java | 10 +-- .../BenchmarkAccuracyImpl.java | 83 ++++++++++--------- .../android/src/main/native/BUILD | 24 ++---- .../delegate_performance_benchmark_jni.cc | 16 ++-- 7 files changed, 149 insertions(+), 81 deletions(-) create mode 100644 tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/java/org/tensorflow/lite/benchmark/delegateperformance/BenchmarkAccuracy.java diff --git a/tensorflow/lite/core/shims/cc_library_with_tflite.bzl b/tensorflow/lite/core/shims/cc_library_with_tflite.bzl index 5133db9b4134da..7ed021d9d5c9d6 100644 --- a/tensorflow/lite/core/shims/cc_library_with_tflite.bzl +++ b/tensorflow/lite/core/shims/cc_library_with_tflite.bzl @@ -7,7 +7,7 @@ load( "tflite_custom_c_library", "tflite_jni_binary", ) -load("@build_bazel_rules_android//android:rules.bzl", "android_library") +load("@build_bazel_rules_android//android:rules.bzl", "android_binary", "android_library") load("@bazel_skylib//rules:build_test.bzl", "build_test") def _concat(lists): @@ -93,6 +93,34 @@ def android_library_with_tflite( **kwargs ) +def android_binary_with_tflite( + name, + deps = [], + tflite_deps = [], + **kwargs): + """Defines an android_binary that uses the TFLite shims. + + This is a hook to allow applying different build flags (etc.) + for targets that use the TFLite shims. + + Note that this build rule doesn't itself add any dependencies on + TF Lite; this macro should normally be used in conjunction with a + direct or indirect 'tflite_deps' dependency on one of the "shim" + library targets from //tensorflow/lite/core/shims:*. + + Args: + name: as for android_binary. + deps: as for android_binary. + tflite_deps: dependencies on rules that are themselves defined using + 'cc_library_with_tflite' / 'android_library_with_tflite'. + **kwargs: Additional android_binary parameters. + """ + android_binary( + name = name, + deps = deps + tflite_deps, + **kwargs + ) + def cc_library_with_tflite( name, srcs = [], diff --git a/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/BUILD b/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/BUILD index 0ef82065e3cf93..163907fe32df63 100644 --- a/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/BUILD +++ b/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/BUILD @@ -2,7 +2,8 @@ # Delegate Performance Benchmark (DPB) Android app. # This provides model-level latency & accuracy testings for delegates, on Android. -load("@build_bazel_rules_android//android:rules.bzl", "android_binary", "android_library") +load("@build_bazel_rules_android//android:rules.bzl", "android_library") +load("//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "android_binary_with_tflite", "android_library_with_tflite") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -20,6 +21,7 @@ android_library( name = "benchmark_accuracy_impl", srcs = ["src/main/java/org/tensorflow/lite/benchmark/delegateperformance/BenchmarkAccuracyImpl.java"], deps = [ + ":benchmark_accuracy", ":benchmark_report", ":csv_writer", ":delegate_performance_benchmark_utils", @@ -32,6 +34,11 @@ android_library( ], ) +android_library( + name = "benchmark_accuracy", + srcs = ["src/main/java/org/tensorflow/lite/benchmark/delegateperformance/BenchmarkAccuracy.java"], +) + android_library( name = "benchmark_latency_activity", srcs = ["src/main/java/org/tensorflow/lite/benchmark/delegateperformance/BenchmarkLatencyActivity.java"], @@ -56,7 +63,7 @@ android_library( ], ) -android_library( +android_library_with_tflite( name = "benchmark_report", srcs = [ "src/main/java/org/tensorflow/lite/benchmark/delegateperformance/BenchmarkReport.java", @@ -74,11 +81,11 @@ android_library( srcs = ["src/main/java/org/tensorflow/lite/benchmark/delegateperformance/BenchmarkResultType.java"], ) -android_library( +android_library_with_tflite( name = "csv_writer", srcs = ["src/main/java/org/tensorflow/lite/benchmark/delegateperformance/CsvWriter.java"], + tflite_deps = [":benchmark_report"], deps = [ - ":benchmark_report", ":delegate_metrics_entry", ":metrics_entry", ":model_benchmark_report_interface", @@ -95,16 +102,18 @@ android_library( ], ) -android_library( +android_library_with_tflite( name = "delegate_performance_benchmark_lib", srcs = [ "src/main/java/org/tensorflow/lite/benchmark/delegateperformance/BenchmarkAccuracyActivity.java", "src/main/java/org/tensorflow/lite/benchmark/delegateperformance/BenchmarkLatencyActivity.java", ], + tflite_deps = [ + "//tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native:benchmark_native", + ], deps = [ ":benchmark_accuracy_impl", ":benchmark_latency_impl", - "//tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native:benchmark_native", ], ) @@ -123,11 +132,11 @@ android_library( ], ) -android_library( +android_library_with_tflite( name = "html_writer", srcs = ["src/main/java/org/tensorflow/lite/benchmark/delegateperformance/HtmlWriter.java"], + tflite_deps = [":benchmark_report"], deps = [ - ":benchmark_report", ":benchmark_result_type", ":delegate_metrics_entry", ":metrics_entry", @@ -136,10 +145,10 @@ android_library( ], ) -android_library( +android_library_with_tflite( name = "json_writer", srcs = ["src/main/java/org/tensorflow/lite/benchmark/delegateperformance/JsonWriter.java"], - deps = [":benchmark_report"], + tflite_deps = [":benchmark_report"], ) android_library( @@ -148,7 +157,7 @@ android_library( deps = [":benchmark_result_type"], ) -android_library( +android_library_with_tflite( name = "model_benchmark_report", srcs = [ "src/main/java/org/tensorflow/lite/benchmark/delegateperformance/AccuracyBenchmarkReport.java", @@ -204,7 +213,7 @@ android_library( ) # The main test app. -android_binary( +android_binary_with_tflite( name = "delegate_performance_benchmark", assets = [ "//tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/models:accuracy_models", @@ -221,8 +230,8 @@ android_binary( # can't be built. We need to prevent the build system from trying to # use the target in that case. tags = ["manual"], - visibility = ["//visibility:public"], - deps = [ + tflite_deps = [ ":delegate_performance_benchmark_lib", ], + visibility = ["//visibility:public"], ) diff --git a/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/java/org/tensorflow/lite/benchmark/delegateperformance/BenchmarkAccuracy.java b/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/java/org/tensorflow/lite/benchmark/delegateperformance/BenchmarkAccuracy.java new file mode 100644 index 00000000000000..90ef295cb09bc8 --- /dev/null +++ b/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/java/org/tensorflow/lite/benchmark/delegateperformance/BenchmarkAccuracy.java @@ -0,0 +1,30 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +package org.tensorflow.lite.benchmark.delegateperformance; + +import android.content.Context; + +/** Interface for Delegate Performance Accuracy Benchmark. */ +public interface BenchmarkAccuracy { + /** + * Initializes and runs the accuracy benchmark. + * + * @param context the context to use for finding the test models and exporting reports + * @param tfliteSettingsJsonFiles the list of paths to delegate JSON configurations + * @return {@code true} if the benchmark was successfully initialized and executed. Otherwise, + * returns {@code false}. + */ + boolean benchmark(Context context, String[] tfliteSettingsJsonFiles); +} diff --git a/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/java/org/tensorflow/lite/benchmark/delegateperformance/BenchmarkAccuracyActivity.java b/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/java/org/tensorflow/lite/benchmark/delegateperformance/BenchmarkAccuracyActivity.java index 9645a94cd1cfc4..93f8c92797bde3 100644 --- a/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/java/org/tensorflow/lite/benchmark/delegateperformance/BenchmarkAccuracyActivity.java +++ b/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/java/org/tensorflow/lite/benchmark/delegateperformance/BenchmarkAccuracyActivity.java @@ -41,14 +41,10 @@ public void onCreate(Bundle savedInstanceState) { Intent intent = getIntent(); Bundle bundle = intent.getExtras(); String[] tfliteSettingsJsonFiles = bundle.getStringArray(TFLITE_SETTINGS_FILES_INTENT_KEY_0); - BenchmarkAccuracyImpl impl = - new BenchmarkAccuracyImpl(getApplicationContext(), tfliteSettingsJsonFiles); - - if (impl.initialize()) { - impl.benchmark(); - } else { - Log.e(TAG, "Failed to initialize the accuracy benchmarking."); + if (!new BenchmarkAccuracyImpl().benchmark(getApplicationContext(), tfliteSettingsJsonFiles)) { + Log.i(TAG, "Accuracy benchmark failed."); } + finish(); } } diff --git a/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/java/org/tensorflow/lite/benchmark/delegateperformance/BenchmarkAccuracyImpl.java b/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/java/org/tensorflow/lite/benchmark/delegateperformance/BenchmarkAccuracyImpl.java index 997b959e6c756d..7ce74f13c4c39a 100644 --- a/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/java/org/tensorflow/lite/benchmark/delegateperformance/BenchmarkAccuracyImpl.java +++ b/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/java/org/tensorflow/lite/benchmark/delegateperformance/BenchmarkAccuracyImpl.java @@ -54,49 +54,25 @@ * configuration and relative performance differences as percentages in HTML. * */ -public class BenchmarkAccuracyImpl { +public class BenchmarkAccuracyImpl implements BenchmarkAccuracy { private static final String TAG = "TfLiteAccuracyImpl"; private static final String ACCURACY_FOLDER_NAME = "accuracy"; - private final Context context; - private final String[] tfliteSettingsJsonFiles; - private final BenchmarkReport report; + private Context context; + private String[] tfliteSettingsJsonFiles; + private BenchmarkReport report; - public BenchmarkAccuracyImpl(Context context, String[] tfliteSettingsJsonFiles) { - this.context = context; - this.tfliteSettingsJsonFiles = tfliteSettingsJsonFiles; - this.report = BenchmarkReport.create(); - } - - /** - * Initializes the test environment. Checks the validity of input arguments and creates the result - * folder. - * - *

Returns {@code true} if the initialization was successful. Otherwise, returns {@code false}. - */ - public boolean initialize() { - if (tfliteSettingsJsonFiles == null || tfliteSettingsJsonFiles.length == 0) { - Log.e(TAG, "No TFLiteSettings file provided."); - return false; - } - - try { - // Creates root result folder. - String resultFolderPath = - DelegatePerformanceBenchmark.createResultFolder( - context.getFilesDir(), ACCURACY_FOLDER_NAME); - report.addWriter(JsonWriter.create(resultFolderPath)); - report.addWriter(CsvWriter.create(resultFolderPath)); - report.addWriter(HtmlWriter.create(resultFolderPath)); - } catch (IOException e) { - Log.e(TAG, "Failed to create result folder", e); + @Override + public boolean benchmark(Context context, String[] tfliteSettingsJsonFiles) { + if (!initialize(context, tfliteSettingsJsonFiles)) { + Log.e(TAG, "Failed to initialize accuracy benchmark."); return false; } - return true; + return benchmarkDelegatesAndExportReport(); } - public void benchmark() { + private boolean benchmarkDelegatesAndExportReport() { Log.i( TAG, "Running accuracy benchmark with TFLiteSettings JSON files: " @@ -105,14 +81,14 @@ public void benchmark() { DelegatePerformanceBenchmark.loadTfLiteSettingsList(tfliteSettingsJsonFiles); if (tfliteSettingsList.size() < 2) { Log.e(TAG, "Failed to load the TFLiteSettings JSON file."); - return; + return false; } String[] assets; try { assets = context.getAssets().list(ACCURACY_FOLDER_NAME); } catch (IOException e) { Log.e(TAG, "Failed to list files from assets folder.", e); - return; + return false; } for (String asset : assets) { if (!asset.endsWith(".tflite")) { @@ -127,7 +103,7 @@ public void benchmark() { context.getFilesDir(), ACCURACY_FOLDER_NAME + "/" + modelName); } catch (IOException e) { Log.e(TAG, "Failed to create result folder for " + modelName + ". Exiting application.", e); - return; + return false; } try (AssetFileDescriptor modelFileDescriptor = context.getAssets().openFd(ACCURACY_FOLDER_NAME + "/" + asset)) { @@ -148,7 +124,7 @@ public void benchmark() { AccuracyBenchmarkReport.create(modelName, rawDelegateMetricsEntries)); } catch (IOException e) { Log.e(TAG, "Failed to open assets file " + asset, e); - return; + return false; } } // Computes the aggregated results and export the report to local files. @@ -158,5 +134,36 @@ public void benchmark() { TAG, String.format( "Accuracy benchmark result for %s: %s.", testTarget.filePath(), report.result())); + return true; + } + + /** + * Initializes the test environment. Checks the validity of input arguments and creates the result + * folder. + * + * @return {@code true} if the initialization was successful. Otherwise, returns {@code false}. + */ + private boolean initialize(Context context, String[] tfliteSettingsJsonFiles) { + if (tfliteSettingsJsonFiles == null || tfliteSettingsJsonFiles.length == 0) { + Log.e(TAG, "No TFLiteSettings file provided."); + return false; + } + this.context = context; + this.tfliteSettingsJsonFiles = tfliteSettingsJsonFiles; + report = BenchmarkReport.create(); + + try { + // Creates root result folder. + String resultFolderPath = + DelegatePerformanceBenchmark.createResultFolder( + context.getFilesDir(), ACCURACY_FOLDER_NAME); + report.addWriter(JsonWriter.create(resultFolderPath)); + report.addWriter(CsvWriter.create(resultFolderPath)); + report.addWriter(HtmlWriter.create(resultFolderPath)); + } catch (IOException e) { + Log.e(TAG, "Failed to create result folder", e); + return false; + } + return true; } } diff --git a/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native/BUILD b/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native/BUILD index ce2e1e300a64ee..5209526ab6712c 100644 --- a/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native/BUILD +++ b/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native/BUILD @@ -1,9 +1,8 @@ # Description: # Holds the native layer of the app. -load("//tensorflow/lite:build_def.bzl", "tflite_jni_binary") -load("//tensorflow:tensorflow.bzl", "clean_dep") load("//tensorflow/lite/tools/benchmark/experimental/delegate_performance/android:build_defs.bzl", "accuracy_benchmark_extra_deps", "latency_benchmark_extra_deps") +load("//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite", "jni_binary_with_tflite") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -11,11 +10,11 @@ package( licenses = ["notice"], ) -tflite_jni_binary( +jni_binary_with_tflite( name = "libdelegate_performance_benchmark.so", srcs = ["delegate_performance_benchmark_jni.cc"], + tflite_deps = [":accuracy_benchmark"], deps = [ - ":accuracy_benchmark", ":latency_benchmark", "//tensorflow/lite/acceleration/configuration:configuration_fbs", "//tensorflow/lite/delegates/utils/experimental/stable_delegate:tflite_settings_json_parser", @@ -46,7 +45,7 @@ cc_library( ] + latency_benchmark_extra_deps(), ) -cc_library( +cc_library_with_tflite( name = "accuracy_benchmark", srcs = ["accuracy_benchmark.cc"], hdrs = ["accuracy_benchmark.h"], @@ -54,6 +53,7 @@ cc_library( ":status_codes", "//tensorflow/lite:minimal_logging", "//tensorflow/lite/acceleration/configuration:configuration_fbs", + "//tensorflow/lite/acceleration/configuration:gpu_plugin", "//tensorflow/lite/acceleration/configuration:stable_delegate_plugin", "//tensorflow/lite/acceleration/configuration:xnnpack_plugin", "//tensorflow/lite/core/acceleration/configuration:nnapi_plugin", @@ -65,18 +65,10 @@ cc_library( "//tensorflow/lite/tools:command_line_flags", "//tensorflow/lite/tools:tool_params", "@flatbuffers", - ] + select({ - # On Android, as the validation runs in a separate process as a - # different binary, any TFLite delegates to be validated need to - # include corresponding delegate plugins. - clean_dep("//tensorflow:android"): [ - "//tensorflow/lite/acceleration/configuration:gpu_plugin", - ], - "//conditions:default": [], - }) + accuracy_benchmark_extra_deps(), + ] + accuracy_benchmark_extra_deps(), ) -cc_library( +cc_library_with_tflite( name = "benchmark_native", - srcs = ["libdelegate_performance_benchmark.so"], + tflite_jni_binaries = [":libdelegate_performance_benchmark.so"], ) diff --git a/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native/delegate_performance_benchmark_jni.cc b/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native/delegate_performance_benchmark_jni.cc index e538cc1688478d..87fe702e3dac19 100644 --- a/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native/delegate_performance_benchmark_jni.cc +++ b/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native/delegate_performance_benchmark_jni.cc @@ -24,7 +24,10 @@ limitations under the License. #include "tensorflow/lite/delegates/utils/experimental/stable_delegate/tflite_settings_json_parser.h" #include "tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/proto/delegate_performance.pb.h" #include "tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native/accuracy_benchmark.h" + +#ifndef TFLITE_WITH_STABLE_ABI #include "tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native/latency_benchmark.h" +#endif // !TFLITE_WITH_STABLE_ABI namespace { @@ -67,6 +70,10 @@ Java_org_tensorflow_lite_benchmark_delegateperformance_DelegatePerformanceBenchm JNIEnv* env, jclass clazz, jobjectArray args_obj, jbyteArray tflite_settings_byte_array, jstring tflite_settings_path_obj, jint model_fd, jlong model_offset, jlong model_size) { + tflite::proto::benchmark::LatencyResults results; + +// The latency benchmark doesn't support TF Lite with the stable ABI path. +#ifndef TFLITE_WITH_STABLE_ABI std::vector args = toStringVector(env, args_obj); const char* tflite_settings_path_chars = env->GetStringUTFChars(tflite_settings_path_obj, nullptr); @@ -76,16 +83,15 @@ Java_org_tensorflow_lite_benchmark_delegateperformance_DelegatePerformanceBenchm flatbuffers::GetRoot( reinterpret_cast(tflite_settings_bytes)); - tflite::proto::benchmark::LatencyResults results = - tflite::benchmark::latency::Benchmark( - *tflite_settings, tflite_settings_path_chars, - static_cast(model_fd), static_cast(model_offset), - static_cast(model_size), args); + results = tflite::benchmark::latency::Benchmark( + *tflite_settings, tflite_settings_path_chars, static_cast(model_fd), + static_cast(model_offset), static_cast(model_size), args); env->ReleaseByteArrayElements(tflite_settings_byte_array, tflite_settings_bytes, JNI_ABORT); env->ReleaseStringUTFChars(tflite_settings_path_obj, tflite_settings_path_chars); +#endif // !TFLITE_WITH_STABLE_ABI return CppProtoToBytes(env, results); } From 22abfba9ecb1ff8ef144c4c7dff096e110289a60 Mon Sep 17 00:00:00 2001 From: George Karpenkov Date: Wed, 12 Jul 2023 04:12:38 -0700 Subject: [PATCH 183/376] [XLA:GPU] [NFC] Generalize autotuner_compile_util to support multiple outputs + Some minor refactorings PiperOrigin-RevId: 547454713 --- .../xla/service/gpu/autotuner_compile_util.cc | 100 ++++++++++-------- .../xla/service/gpu/autotuner_compile_util.h | 57 +++++----- .../xla/service/gpu/triton_autotuner.cc | 24 +++-- 3 files changed, 99 insertions(+), 82 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.cc b/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.cc index d799987126e681..d78454ca7961ba 100644 --- a/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.cc +++ b/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.cc @@ -103,15 +103,34 @@ std::vector ExecutionInputsFromBuffers( } // namespace +AutotunerCompileUtil::AutotunerCompileUtil(Compiler* compiler, + se::StreamExecutor& stream_executor, + se::Stream& stream, + se::DeviceMemoryAllocator& allocator, + const DebugOptions& opts) + : compiler_(compiler), + stream_executor_(stream_executor), + stream_(stream), + allocator_(allocator), + opts_(opts) { + // Avoid dumping compilation steps. + opts_.set_xla_dump_to(""); + opts_.set_xla_gpu_dump_autotune_results_to(""); + opts_.set_xla_gpu_load_autotune_results_from(""); + opts_.set_xla_gpu_dump_llvmir(false); + // Avoid using another thread pool. + opts_.set_xla_gpu_force_compilation_parallelism(1); + // Avoid using GPU graphs as we don't want to measure graph construction time. + opts_.set_xla_gpu_cuda_graph_level(0); +} + StatusOr> AutotunerCompileUtil::GenerateAndProfileExecutable( - const HloComputation& hlo_computation, const AutotuneResult& config, - const AutotuneCacheKey& cache_key, se::Stream* stream, - absl::Span input_buffers, - se::DeviceMemoryBase output_buffer, ExtractModuleFn extractor) { - TF_ASSIGN_OR_RETURN( - Executable * executable, - Compile(hlo_computation, config, cache_key, std::move(extractor))); + const AutotuneResult& config, const AutotuneCacheKey& cache_key, + se::Stream* stream, absl::Span input_buffers, + ShapedBuffer output_buffer, GenerateModuleFn extractor) { + TF_ASSIGN_OR_RETURN(Executable * executable, + Compile(config, cache_key, std::move(extractor))); if (!executable) { return {std::nullopt}; @@ -134,17 +153,29 @@ AutotunerCompileUtil::GenerateAndProfileExecutable( TF_ASSIGN_OR_RETURN(absl::Duration timer_duration, timer.GetElapsedDuration()); ScopedShapedBuffer result = execution_output.ConsumeResult(); - TF_RET_CHECK(output_buffer.size() == result.root_buffer().size()); + // TODO(cheshire): Copying should not be required. Instead, we can add a new // aliased parameter. - stream->ThenMemcpy(&output_buffer, result.root_buffer(), - result.root_buffer().size()); + Shape shape = result.on_device_shape(); + TF_RET_CHECK(shape == output_buffer.on_device_shape()); + if (shape.IsTuple()) { + for (int64_t i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + TF_RET_CHECK(!shape.tuple_shapes(i).IsTuple()); + stream->ThenMemcpy(output_buffer.buffers().mutable_element(ShapeIndex{i}), + result.buffer(ShapeIndex{i}), + ShapeUtil::ByteSizeOf(shape.tuple_shapes(i))); + } + } else { + stream->ThenMemcpy(output_buffer.buffers().mutable_element(ShapeIndex{}), + result.buffer(ShapeIndex{}), + ShapeUtil::ByteSizeOf(shape)); + } return std::make_optional(timer_duration); } StatusOr AutotunerCompileUtil::Compile( - const HloComputation& hlo_computation, const AutotuneResult& res, - const AutotuneCacheKey& cache_key, ExtractModuleFn extractor) { + const AutotuneResult& res, const AutotuneCacheKey& cache_key, + GenerateModuleFn extractor) { CompilationKey key{cache_key, res}; { absl::MutexLock lock(&executable_cache_mutex); @@ -156,15 +187,14 @@ StatusOr AutotunerCompileUtil::Compile( } TF_ASSIGN_OR_RETURN(std::unique_ptr executable, - CompileNoCache(hlo_computation, std::move(extractor))); + CompileNoCache(std::move(extractor))); absl::MutexLock lock(&executable_cache_mutex); auto [it, inserted] = executable_cache.emplace(key, std::move(executable)); return it->second.get(); } StatusOr> AutotunerCompileUtil::CompileNoCache( - const HloComputation& original_computation, - ExtractModuleFn module_extractor) { + GenerateModuleFn module_extractor) { StatusOr> new_hlo_module = module_extractor(); if (new_hlo_module.status().GetPayload(kUncompilableFusion).has_value()) { // Incompatible value of split-k is an expected failure. @@ -172,15 +202,25 @@ StatusOr> AutotunerCompileUtil::CompileNoCache( } else if (!new_hlo_module.status().ok()) { return new_hlo_module.status(); } - return RunBackend(original_computation, std::move(*new_hlo_module)); + (*new_hlo_module)->config().set_debug_options(opts_); + + StatusOr> out = compiler_->RunBackend( + std::move(*new_hlo_module), &stream_executor_, &allocator_); + if (out.status().code() == absl::StatusCode::kResourceExhausted) { + // Being out of shared memory budget is an expected failure. + return std::unique_ptr(); + } + return out; } /*static*/ StatusOr AutotunerCompileUtil::Create( - se::Stream& stream, se::DeviceMemoryAllocator& allocator) { + se::Stream& stream, se::DeviceMemoryAllocator& allocator, + const DebugOptions& opts) { se::StreamExecutor& stream_executor = *stream.parent(); TF_ASSIGN_OR_RETURN(Compiler * compiler, Compiler::GetForPlatform(stream_executor.platform())); - return AutotunerCompileUtil(compiler, stream_executor, stream, allocator); + return AutotunerCompileUtil(compiler, stream_executor, stream, allocator, + opts); } StatusOr AutotunerCompileUtil::Execute( @@ -202,30 +242,6 @@ StatusOr AutotunerCompileUtil::Execute( return std::move(output); } -StatusOr> AutotunerCompileUtil::RunBackend( - const HloComputation& original_computation, - std::unique_ptr module) { - DebugOptions options = - original_computation.parent()->config().debug_options(); - // Avoid dumping compilation steps. - options.set_xla_dump_to(""); - options.set_xla_gpu_dump_autotune_results_to(""); - options.set_xla_gpu_load_autotune_results_from(""); - options.set_xla_gpu_dump_llvmir(false); - // Avoid using Gpu graphs as we don't want to measure graph construction time. - options.set_xla_gpu_cuda_graph_level(0); - // Avoid using another thread pool. - options.set_xla_gpu_force_compilation_parallelism(1); - module->config().set_debug_options(options); - StatusOr> out = - compiler_->RunBackend(std::move(module), &stream_executor_, &allocator_); - if (out.status().code() == absl::StatusCode::kResourceExhausted) { - // Being out of shared memory budget is an expected failure. - return std::unique_ptr(); - } - return out; -} - /*static*/ void AutotunerCompileUtil::ClearCompilationCache() { absl::MutexLock lock(&executable_cache_mutex); executable_cache.clear(); diff --git a/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.h b/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.h index da3f56ab774bde..c9484bb61faa21 100644 --- a/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.h +++ b/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.h @@ -47,55 +47,51 @@ namespace gpu { // Autotuning utils which require compiling fusions separately. Requires a // separate target, as runtime autotuning cannot perform compilation. +// +// Uses a global cache, *not* unique per instance. class AutotunerCompileUtil { public: - using ExtractModuleFn = + using GenerateModuleFn = absl::AnyInvocable>()>; + // Generates a compile util for a platform associated with the `stream`. static StatusOr Create( - se::Stream& stream, se::DeviceMemoryAllocator& allocator); + se::Stream& stream, se::DeviceMemoryAllocator& allocator, + const DebugOptions& opts); - AutotunerCompileUtil(Compiler* compiler, se::StreamExecutor& stream_executor, - se::Stream& stream, se::DeviceMemoryAllocator& allocator) - : compiler_(compiler), - stream_executor_(stream_executor), - stream_(stream), - allocator_(allocator) {} - - // Runs the compiled executable with the given extractor, cached with - // . Returns std::nullopt on expected failure, bad Status - // otherwise. - // Uses a global cache, *not* unique per instance. + // Generates an executable first, given the module generator function in + // `extractor`. + // + // Runs the resulting executable with the given extractor, cached with + // `(cache_key, config)`. Returns `std::nullopt` on expected failure, bad + // `Status` otherwise. StatusOr> GenerateAndProfileExecutable( - const HloComputation& hlo_computation, const AutotuneResult& config, - const AutotuneCacheKey& cache_key, se::Stream* stream, - absl::Span input_buffers, - se::DeviceMemoryBase output_buffer, ExtractModuleFn extractor); + const AutotuneResult& config, const AutotuneCacheKey& cache_key, + se::Stream* stream, absl::Span input_buffers, + ShapedBuffer output_buffer, GenerateModuleFn extractor); - // Generic method to compile a given computation in isolation using a given - // pipeline, cached on AutotuneResult and AutotuneCacheKey. + // Generic method to compile a generated module from `extractor` in isolation. // // On *expected* failures we will store an empty unique_ptr in cache. // // Returns: - // - on *expected* failure - // - Executable if everything goes fine. - // - Status on *unexpected* failure. + // - `nullptr` on *expected* failure + // - `Executable` if everything goes fine. + // - `Status` on *unexpected* failure. StatusOr Compile( - const HloComputation& hlo_computation, const AutotuneResult& res, - const AutotuneCacheKey& cache_key, - AutotunerCompileUtil::ExtractModuleFn extractor); + const AutotuneResult& res, const AutotuneCacheKey& cache_key, + AutotunerCompileUtil::GenerateModuleFn extractor); + // Clears the global compilation cache. static void ClearCompilationCache(); private: - StatusOr> RunBackend( - const HloComputation& original_computation, - std::unique_ptr module); + AutotunerCompileUtil(Compiler* compiler, se::StreamExecutor& stream_executor, + se::Stream& stream, se::DeviceMemoryAllocator& allocator, + const DebugOptions& opts); StatusOr> CompileNoCache( - const HloComputation& original_computation, - AutotunerCompileUtil::ExtractModuleFn module_extractor); + AutotunerCompileUtil::GenerateModuleFn module_extractor); StatusOr Execute(Executable& executable, std::vector arguments); @@ -104,6 +100,7 @@ class AutotunerCompileUtil { se::StreamExecutor& stream_executor_; se::Stream& stream_; se::DeviceMemoryAllocator& allocator_; + DebugOptions opts_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/triton_autotuner.cc b/tensorflow/compiler/xla/service/gpu/triton_autotuner.cc index 440a9611a8fe27..bb8468b0782193 100644 --- a/tensorflow/compiler/xla/service/gpu/triton_autotuner.cc +++ b/tensorflow/compiler/xla/service/gpu/triton_autotuner.cc @@ -180,7 +180,7 @@ class TritonAutotunerVisitor : public DfsHloRewriteVisitor { AutotuneResult config; *config.mutable_triton() = conf; StatusOr res = - autotuner_compile_util_->Compile(fusion, config, cache_key, [&] { + autotuner_compile_util_->Compile(config, cache_key, [&] { return TritonGemmAutotuneExtractor(conf, gpu_device_info, fusion.FusionInstruction()); }); @@ -280,11 +280,11 @@ class TritonAutotunerVisitor : public DfsHloRewriteVisitor { AutotuneResult config; *config.mutable_triton() = autotune_config; - std::vector used_buffers; - absl::c_copy(input_buffers, std::back_inserter(used_buffers)); + ShapedBuffer output(hlo_computation.root_instruction()->shape(), 0); + output.set_buffer(output_buffer, ShapeIndex{}); + return autotuner_compile_util_->GenerateAndProfileExecutable( - hlo_computation, config, cache_key, stream, used_buffers, output_buffer, - [&] { + config, cache_key, stream, input_buffers, std::move(output), [&] { return TritonGemmAutotuneExtractor( autotune_config, GetGpuDeviceInfo(config_.GetExecutor()), hlo_computation.FusionInstruction()); @@ -349,10 +349,13 @@ class TritonAutotunerVisitor : public DfsHloRewriteVisitor { gemm.set_algorithm(0); *res.mutable_gemm() = gemm; + ShapedBuffer output(original_computation.root_instruction()->shape(), 0); + output.set_buffer(output_buffer, ShapeIndex{}); + TF_ASSIGN_OR_RETURN(std::optional duration, autotuner_compile_util_->GenerateAndProfileExecutable( - original_computation, res, cache_key, stream, - input_buffers, output_buffer, [&] { + res, cache_key, stream, input_buffers, + std::move(output), [&] { return CublasGemmAutotuneExtractor( GetGpuDeviceInfo(config_.GetExecutor()), &original_computation); @@ -473,15 +476,16 @@ StatusOr TritonAutotuner::Run( std::optional autotuner_compile_util; if (!config_.IsDeviceless()) { - // TODO(cheshire): The ones below should not be needed. se::StreamExecutor* stream_exec = config_.GetExecutor(); se::DeviceMemoryAllocator* allocator = config_.GetAllocator() ? config_.GetAllocator() : stream_exec->GetAllocator(); TF_ASSIGN_OR_RETURN(se::Stream* const stream, allocator->GetStream(stream_exec->device_ordinal())); - TF_ASSIGN_OR_RETURN(AutotunerCompileUtil util, - AutotunerCompileUtil::Create(*stream, *allocator)); + TF_ASSIGN_OR_RETURN( + AutotunerCompileUtil util, + AutotunerCompileUtil::Create(*stream, *allocator, + module->config().debug_options())); autotuner_compile_util.emplace(util); } From bb0f461eee109e9a632162b62bd0fe76767d22b8 Mon Sep 17 00:00:00 2001 From: Alan Kelly Date: Wed, 12 Jul 2023 04:35:45 -0700 Subject: [PATCH 184/376] WHILE: Set subgraphs_prepared to false during lazy initialization PiperOrigin-RevId: 547458178 --- tensorflow/lite/core/subgraph.cc | 1 + tensorflow/lite/kernels/subgraph_test_util.cc | 20 ++++-- tensorflow/lite/kernels/while.cc | 26 +++---- tensorflow/lite/kernels/while_test.cc | 68 ++++++++++++------- 4 files changed, 75 insertions(+), 40 deletions(-) diff --git a/tensorflow/lite/core/subgraph.cc b/tensorflow/lite/core/subgraph.cc index 0de9243868f19a..0f21ec72e5908f 100644 --- a/tensorflow/lite/core/subgraph.cc +++ b/tensorflow/lite/core/subgraph.cc @@ -1407,6 +1407,7 @@ TfLiteStatus Subgraph::MayAllocateOpOutput(TfLiteNode* node) { if (ShouldOptimizeMemoryForLargeTensors()) { for (int i = 0; i < node->outputs->size; ++i) { int tensor_index = node->outputs->data[i]; + if (tensor_index == kTfLiteOptionalTensor) continue; TfLiteTensor* tensor = &context_.tensors[tensor_index]; if (tensor->data.raw == nullptr && tensor->allocation_type == kTfLiteDynamic) { diff --git a/tensorflow/lite/kernels/subgraph_test_util.cc b/tensorflow/lite/kernels/subgraph_test_util.cc index f961dae3124ace..809c185621e015 100644 --- a/tensorflow/lite/kernels/subgraph_test_util.cc +++ b/tensorflow/lite/kernels/subgraph_test_util.cc @@ -317,20 +317,32 @@ void SubgraphBuilder::BuildXNNPACKSubgraph(Subgraph* subgraph) { } void SubgraphBuilder::BuildInputIsOutputSubgraph(Subgraph* subgraph) { - enum { kInputCounter, kInputValue, kOutputCounter, kTensorCount }; + enum { + kInputCounter, + kInputValue0, + kInputOutput, + kOutputCounter, + kOutputValue0, + kConstRhs, + kTensorCount + }; int first_new_tensor_index; ASSERT_EQ(subgraph->AddTensors(kTensorCount, &first_new_tensor_index), kTfLiteOk); ASSERT_EQ(first_new_tensor_index, 0); - ASSERT_EQ(subgraph->SetInputs({kInputCounter, kInputValue}), kTfLiteOk); - ASSERT_EQ(subgraph->SetOutputs({kOutputCounter, kInputValue}), kTfLiteOk); + ASSERT_EQ(subgraph->SetInputs({kInputCounter, kInputValue0, kInputOutput}), + kTfLiteOk); + ASSERT_EQ(subgraph->SetOutputs({kOutputCounter, kOutputValue0, kInputOutput}), + kTfLiteOk); for (int i = 0; i < kTensorCount; ++i) { SetupTensor(subgraph, i, kTfLiteInt32); } + CreateConstantTensor(subgraph, kConstRhs, {1}, {1}); - AddAddNode(subgraph, kInputCounter, kInputValue, kOutputCounter); + AddAddNode(subgraph, kInputCounter, kConstRhs, kOutputCounter); + AddAddNode(subgraph, kInputValue0, kInputOutput, kOutputValue0); } void SubgraphBuilder::BuildInputIsDifferentOutputSubgraph(Subgraph* subgraph) { diff --git a/tensorflow/lite/kernels/while.cc b/tensorflow/lite/kernels/while.cc index 35e0305dda8dd4..fd2251f10cfc6f 100644 --- a/tensorflow/lite/kernels/while.cc +++ b/tensorflow/lite/kernels/while.cc @@ -251,6 +251,12 @@ TfLiteStatus Prepare_impl(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_STATUS(CheckCondOutput(context, cond_output)); } + // Prepare and check the body subgraph. + TF_LITE_ENSURE_OK( + context, CopyTensorsShapeAndType( + context, this_subgraph, TfLiteIntArrayView(node->inputs), + body_subgraph, body_subgraph->inputs(), true)); + // Detect when a WHILE input is read only. const std::vector input_tensors_count = this_subgraph->GetInputTensorsCount(); @@ -265,18 +271,15 @@ TfLiteStatus Prepare_impl(TfLiteContext* context, TfLiteNode* node) { body_subgraph->tensor(body_subgraph->inputs()[i]); if (body_input->type == kTfLiteString) continue; if (IsResourceOrVariant(body_input)) continue; + TfLiteTensor* this_output = + this_subgraph->tensor(node->outputs->data[i]); + TfLiteTensorDataFree(this_output); node->outputs->data[i] = kTfLiteOptionalTensor; body_input->allocation_type = kTfLiteCustom; } } } - // Prepare and check the body subgraph. - TF_LITE_ENSURE_OK( - context, CopyTensorsShapeAndType( - context, this_subgraph, TfLiteIntArrayView(node->inputs), - body_subgraph, body_subgraph->inputs(), true)); - for (int i = 0; i < num_inputs; ++i) { TfLiteTensor* body_input = body_subgraph->tensor(body_subgraph->inputs()[i]); @@ -300,7 +303,7 @@ TfLiteStatus Prepare_impl(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE(context, !IsDynamicTensor(body_output)); if (!TfLiteIntArrayEqual(body_input->dims, body_output->dims)) { // Don't unnecessarily set an output to dynamic when one of input/output - // is a scalar and the other an tensor of size 1. + // is a scalar and the other a tensor of size 1. // If both tensors are scalars or both tensors have shape [1], then // TfLiteIntArrayEqual would return true. We want to detect when one // tensor is a scalar and the other has shape [1], so the total number @@ -340,6 +343,9 @@ TfLiteStatus Prepare_impl(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { Subgraph* this_subgraph = reinterpret_cast(context->impl_); if (this_subgraph->ShouldOptimizeMemoryForLargeTensors()) { + OpData* op_data = reinterpret_cast(node->user_data); + // Call Prepare to ensure input shapes are propagated to the body subgraph. + op_data->subgraphs_prepared = false; // Apply lazy initialization of WHILE kernel. // Just make node output tensors dynamic. int num_outputs = node->outputs->size; @@ -354,10 +360,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { return Prepare_impl(context, node); } -TfLiteStatus Prepare_lazy(TfLiteContext* context, TfLiteNode* node) { - return Prepare_impl(context, node); -} - // Evaluate cond subgraph and set the result. TfLiteStatus Eval_cond_subgraph(TfLiteContext* context, Subgraph* cond_subgraph, bool cond_has_dynamic_output_tensors, @@ -584,7 +586,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { Subgraph* body_subgraph = (*subgraphs)[op_data->body_subgraph_index].get(); if (op_data->subgraphs_prepared == false) { - TF_LITE_ENSURE_OK(context, Prepare_lazy(context, node)); + TF_LITE_ENSURE_OK(context, Prepare_impl(context, node)); } else { TF_LITE_ENSURE_OK(context, cond_subgraph->AllocateTensors()); TF_LITE_ENSURE_OK(context, body_subgraph->AllocateTensors()); diff --git a/tensorflow/lite/kernels/while_test.cc b/tensorflow/lite/kernels/while_test.cc index f4e21513abbc20..8cc2df70233e08 100644 --- a/tensorflow/lite/kernels/while_test.cc +++ b/tensorflow/lite/kernels/while_test.cc @@ -77,24 +77,27 @@ TEST_F(WhileTest, TestWithXNNPACK) { TEST_F(WhileTest, TestInputIsOutput) { interpreter_ = std::make_unique(); AddSubgraphs(2); - builder_->BuildLargeLessEqualCondSubgraph(interpreter_->subgraph(1), 3, 2); + builder_->BuildLargeLessEqualCondSubgraph(interpreter_->subgraph(1), 3, 3); builder_->BuildInputIsOutputSubgraph(interpreter_->subgraph(2)); - builder_->BuildMultiInputWhileSubgraph(&interpreter_->primary_subgraph(), 2); + builder_->BuildMultiInputWhileSubgraph(&interpreter_->primary_subgraph(), 3); ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1}), kTfLiteOk); ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {1}), kTfLiteOk); + ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[2], {1}), + kTfLiteOk); ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); FillIntTensor(interpreter_->tensor(interpreter_->inputs()[0]), {1}); FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), {1}); + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[2]), {1}); ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); TfLiteTensor* output0 = interpreter_->tensor(interpreter_->outputs()[0]); CheckIntTensor(output0, {1}, {4}); TfLiteTensor* output1 = interpreter_->tensor(interpreter_->outputs()[1]); - CheckIntTensor(output1, {1}, {1}); + CheckIntTensor(output1, {1}, {4}); ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); @@ -216,30 +219,47 @@ TEST_F(WhileTest, TestAllCases) { } TEST_F(WhileTest, TestStaticUnconsumedOutputs) { - interpreter_ = std::make_unique(); - AddSubgraphs(2); - builder_->BuildLargeLessEqualCondSubgraph(interpreter_->subgraph(1), 3, 2); - builder_->BuildInputIsOutputSubgraph(interpreter_->subgraph(2)); - builder_->BuildMultiInputWhileSubgraphWithUnconsumedOutput( - &interpreter_->primary_subgraph(), 2); + for (bool dynamic_tensors : {true, false}) { + interpreter_ = std::make_unique(); + AddSubgraphs(2); + builder_->BuildLargeLessEqualCondSubgraph(interpreter_->subgraph(1), 3, 3); + builder_->BuildInputIsOutputSubgraph(interpreter_->subgraph(2)); + builder_->BuildMultiInputWhileSubgraphWithUnconsumedOutput( + &interpreter_->primary_subgraph(), 3); - ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1}), - kTfLiteOk); - ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {1}), - kTfLiteOk); - ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); - FillIntTensor(interpreter_->tensor(interpreter_->inputs()[0]), {1}); - FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), {2}); + InterpreterOptions options; + if (dynamic_tensors) { + options.OptimizeMemoryForLargeTensors(1); + interpreter_->ApplyOptions(&options); + } - ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); - TfLiteTensor* output0 = interpreter_->tensor(interpreter_->outputs()[0]); - CheckIntTensor(output0, {1}, {5}); + ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1}), + kTfLiteOk); + ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {1}), + kTfLiteOk); + ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[2], {1}), + kTfLiteOk); + ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[0]), {1}); + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), {2}); + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[2]), {2}); - ASSERT_EQ(interpreter_->subgraph(2)->tensor(1)->data.data, - interpreter_->tensor(interpreter_->inputs()[1])->data.data); - ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); - ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); - ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + TfLiteTensor* output0 = interpreter_->tensor(interpreter_->outputs()[0]); + CheckIntTensor(output0, {1}, {4}); + TfLiteTensor* output1 = interpreter_->tensor(interpreter_->outputs()[1]); + CheckIntTensor(output1, {1}, {8}); + + ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[2], {2}), + kTfLiteOk); + ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[2]), {2, 2}); + + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + CheckIntTensor(output1, {2}, {8, 8}); + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + } } // Test a body subgraph which triggers the reallocation of an inplace output From a2c7600786b04f793834c97a7cb99b495bcc9272 Mon Sep 17 00:00:00 2001 From: Aliia Khasanova Date: Wed, 12 Jul 2023 05:15:10 -0700 Subject: [PATCH 185/376] [XLA:GPU] Enable --xla_gpu_enable_experimental_block_size by default. PiperOrigin-RevId: 547465243 --- tensorflow/compiler/xla/debug_options_flags.cc | 2 +- .../xla/service/gpu/tests/dynamic_update_slice_inplace.hlo | 2 +- tensorflow/compiler/xla/service/gpu/tests/fused_slice.hlo | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/compiler/xla/debug_options_flags.cc b/tensorflow/compiler/xla/debug_options_flags.cc index 299635b30746e0..c3f9d75942da09 100644 --- a/tensorflow/compiler/xla/debug_options_flags.cc +++ b/tensorflow/compiler/xla/debug_options_flags.cc @@ -145,7 +145,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_collective_inflation_factor(1); - opts.set_xla_gpu_enable_experimental_block_size(false); + opts.set_xla_gpu_enable_experimental_block_size(true); opts.set_xla_gpu_exhaustive_tiling_search(false); opts.set_xla_gpu_enable_priority_fusion(false); diff --git a/tensorflow/compiler/xla/service/gpu/tests/dynamic_update_slice_inplace.hlo b/tensorflow/compiler/xla/service/gpu/tests/dynamic_update_slice_inplace.hlo index 660ae9b3f7eee3..71fe8a85899746 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/dynamic_update_slice_inplace.hlo +++ b/tensorflow/compiler/xla/service/gpu/tests/dynamic_update_slice_inplace.hlo @@ -22,7 +22,7 @@ // CHECK: %[[VAL_16:.*]] = zext i32 %[[VAL_15]] to i64 // CHECK: %[[VAL_17:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !3 // CHECK: %[[VAL_18:.*]] = zext i32 %[[VAL_17]] to i64 -// CHECK: %[[VAL_19:.*]] = mul nuw nsw i64 %[[VAL_16]], 1024 +// CHECK: %[[VAL_19:.*]] = mul nuw nsw i64 %[[VAL_16]], 128 // CHECK: %[[VAL_20:.*]] = add nuw nsw i64 %[[VAL_19]], %[[VAL_18]] // CHECK: %[[VAL_21:.*]] = icmp ult i64 %[[VAL_20]], 98304 // CHECK: call void @llvm.assume(i1 %[[VAL_21]]) diff --git a/tensorflow/compiler/xla/service/gpu/tests/fused_slice.hlo b/tensorflow/compiler/xla/service/gpu/tests/fused_slice.hlo index deaec101253fbe..71d571f965d4ce 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/fused_slice.hlo +++ b/tensorflow/compiler/xla/service/gpu/tests/fused_slice.hlo @@ -5,7 +5,7 @@ // CHECK-LABEL: entry: // CHECK: %[[VAL_0:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !2 // CHECK: %[[VAL_1:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !3 -// CHECK: %[[VAL_2:.*]] = mul nuw nsw i32 %[[VAL_0]], 1024 +// CHECK: %[[VAL_2:.*]] = mul nuw nsw i32 %[[VAL_0]], 128 // CHECK: %[[VAL_3:.*]] = add nuw nsw i32 %[[VAL_2]], %[[VAL_1]] // CHECK: %[[VAL_4:.*]] = icmp ult i32 %[[VAL_3]], 2048 // CHECK: call void @llvm.assume(i1 %[[VAL_4]]) From 7ef08c3d76daee4b6eb85d416077341258fdd8e0 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 12 Jul 2023 06:29:06 -0700 Subject: [PATCH 186/376] Don't depend on @llvm-project//mlir:ConversionPasses, instead depend only on specific passes. :ConversionPasses pulls in many MLIR dialects and passes that aren't used, which bloats binary size. PiperOrigin-RevId: 547478135 --- tensorflow/compiler/xla/mlir/runtime/transforms/BUILD | 2 +- .../xla/mlir/runtime/transforms/compilation_pipeline_cpu.cc | 2 +- tensorflow/compiler/xla/service/cpu/BUILD | 4 +++- .../compiler/xla/service/cpu/hlo_xla_runtime_pipeline.cc | 4 +++- 4 files changed, 8 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/xla/mlir/runtime/transforms/BUILD b/tensorflow/compiler/xla/mlir/runtime/transforms/BUILD index 2401fdd984d85b..dc9bd544f64fc2 100644 --- a/tensorflow/compiler/xla/mlir/runtime/transforms/BUILD +++ b/tensorflow/compiler/xla/mlir/runtime/transforms/BUILD @@ -120,10 +120,10 @@ cc_library( "@llvm-project//mlir:BuiltinToLLVMIRTranslation", "@llvm-project//mlir:ComplexToLLVM", "@llvm-project//mlir:ControlFlowDialect", - "@llvm-project//mlir:ConversionPasses", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:FuncExtensions", "@llvm-project//mlir:FuncToLLVM", + "@llvm-project//mlir:GPUToGPURuntimeTransforms", "@llvm-project//mlir:GPUTransforms", "@llvm-project//mlir:LLVMToLLVMIRTranslation", "@llvm-project//mlir:LinalgToLLVM", diff --git a/tensorflow/compiler/xla/mlir/runtime/transforms/compilation_pipeline_cpu.cc b/tensorflow/compiler/xla/mlir/runtime/transforms/compilation_pipeline_cpu.cc index 05fe3773e39365..0ff7e1b0d366da 100644 --- a/tensorflow/compiler/xla/mlir/runtime/transforms/compilation_pipeline_cpu.cc +++ b/tensorflow/compiler/xla/mlir/runtime/transforms/compilation_pipeline_cpu.cc @@ -21,11 +21,11 @@ limitations under the License. #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" // from @llvm-project #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" // from @llvm-project #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" // from @llvm-project +#include "mlir/Conversion/GPUCommon/GPUCommonPass.h" // from @llvm-project #include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" // from @llvm-project #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" // from @llvm-project #include "mlir/Conversion/MathToLibm/MathToLibm.h" // from @llvm-project #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" // from @llvm-project -#include "mlir/Conversion/Passes.h" // from @llvm-project #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" // from @llvm-project #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" // from @llvm-project #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" // from @llvm-project diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 0e04e5295c7351..c0d0f1908a03a8 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -410,15 +410,16 @@ cc_library( "@llvm-project//mlir:BufferizationToMemRef", "@llvm-project//mlir:BufferizationTransforms", "@llvm-project//mlir:ComplexToStandard", - "@llvm-project//mlir:ConversionPasses", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:FuncTransforms", "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:GPUToNVVMTransforms", "@llvm-project//mlir:IR", "@llvm-project//mlir:LinalgTransforms", "@llvm-project//mlir:MemRefTransforms", "@llvm-project//mlir:Pass", "@llvm-project//mlir:ReconcileUnrealizedCasts", + "@llvm-project//mlir:SCFToControlFlow", "@llvm-project//mlir:SCFTransforms", "@llvm-project//mlir:ShapeToStandard", "@llvm-project//mlir:ShapeTransforms", @@ -426,6 +427,7 @@ cc_library( "@llvm-project//mlir:TensorToLinalg", "@llvm-project//mlir:TensorTransforms", "@llvm-project//mlir:Transforms", + "@llvm-project//mlir:VectorToLLVM", "@llvm-project//mlir:VectorToSCF", "@llvm-project//mlir:VectorTransforms", ], diff --git a/tensorflow/compiler/xla/service/cpu/hlo_xla_runtime_pipeline.cc b/tensorflow/compiler/xla/service/cpu/hlo_xla_runtime_pipeline.cc index dd9221df95e793..08496d0e72c8d3 100644 --- a/tensorflow/compiler/xla/service/cpu/hlo_xla_runtime_pipeline.cc +++ b/tensorflow/compiler/xla/service/cpu/hlo_xla_runtime_pipeline.cc @@ -20,10 +20,12 @@ limitations under the License. #include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" #include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h" // from @llvm-project #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h" // from @llvm-project -#include "mlir/Conversion/Passes.h" // from @llvm-project +#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" // from @llvm-project #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" // from @llvm-project +#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" // from @llvm-project #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" // from @llvm-project #include "mlir/Conversion/TensorToLinalg/TensorToLinalgPass.h" // from @llvm-project +#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" // from @llvm-project #include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h" // from @llvm-project #include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h" // from @llvm-project #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" // from @llvm-project From 4b81db961b41c66607aaae3850143c799e6f64b9 Mon Sep 17 00:00:00 2001 From: George Necula Date: Wed, 12 Jul 2023 06:52:29 -0700 Subject: [PATCH 187/376] Add support for shape assertions to RefinePolymorphicShapes Front-ends can use stablehlo.custom_call @shape_assertion to check that shape dimension sizes meet a constraint that depends only on dimension sizes, and can be evaluated to a constant once all the shapes in the module are known. This is needed for jax2tf, because the JAX lowering relies on constraints such as the batch dimensions of two inputs being the same. These constraints are lost when lowering to Stablehlo dynamic shapes, and it would be unsound to invoke the code on inputs that do not meet the constraints. We also add support for disabling this safety check if "shape_assertions" is included in the `disabled_checks`. PiperOrigin-RevId: 547482522 --- tensorflow/compiler/tests/BUILD | 30 +++ .../compiler/tests/xla_call_module_test.py | 197 +++++++++++++++ .../tf2xla/kernels/xla_call_module_loader.cc | 37 ++- .../tf2xla/kernels/xla_call_module_loader.h | 8 +- tensorflow/compiler/tf2xla/python/xla.py | 12 +- tensorflow/compiler/xla/python/BUILD | 1 - tensorflow/compiler/xla/python/mlir.cc | 7 +- .../xla/python/refine_polymorphic_shapes.cc | 227 +++++++++++++++++- .../xla/python/refine_polymorphic_shapes.h | 12 +- tensorflow/compiler/xla/python/xla_client.py | 2 +- .../xla/python/xla_extension/mlir.pyi | 3 +- 11 files changed, 504 insertions(+), 32 deletions(-) diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 6df3fad2a56cba..4a4c116fb9bdf2 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -2801,6 +2801,36 @@ tf_xla_py_strict_test( ], ) +tf_xla_py_strict_test( + name = "xla_call_module_no_shape_assertions_check_test", + size = "small", + srcs = ["xla_call_module_test.py"], + disabled_backends = ["cpu_ondemand"], # cpu_ondemand overrides the TF_XLA_FLAGS + enable_mlir_bridge = False, + env = {"TF_XLA_FLAGS": "--tf_xla_call_module_disabled_checks=shape_assertions"}, + main = "xla_call_module_test.py", + python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], + use_xla_device = False, # Uses tf.function(jit_compile=True) + deps = [ + ":xla_test", + "//tensorflow/compiler/mlir/stablehlo", + "//tensorflow/compiler/tf2xla/ops:gen_xla_ops", + "//tensorflow/compiler/tf2xla/python:xla", + "//tensorflow/python/eager:def_function", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:errors", + "//tensorflow/python/framework:function", + "//tensorflow/python/framework:ops", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/platform:client_testlib", + "//tensorflow/python/platform:test", + "//third_party/py/numpy", + ], +) + tf_xla_py_strict_test( name = "bincount_op_test", size = "small", diff --git a/tensorflow/compiler/tests/xla_call_module_test.py b/tensorflow/compiler/tests/xla_call_module_test.py index 1da783ffab533e..e8be9400e2a17c 100644 --- a/tensorflow/compiler/tests/xla_call_module_test.py +++ b/tensorflow/compiler/tests/xla_call_module_test.py @@ -432,6 +432,203 @@ def f(x): 'arguments, but it has only 1 total arguments'): self._assertOpOutputMatchesExpected(f, (x,), (x,)) + def test_shape_assertion_success(self): + x = np.ones((3, 5), dtype=np.int32) + res = np.int32(x.shape[0]) + + def f(x): # x: f32[b, 5] and b = 3 + # return x.shape[0] + module, version = serialize(""" +module @jit_fun.1 { + func.func public @main(%arg1: tensor) -> tensor { + %b = "stablehlo.get_dimension_size"(%arg1) {dimension = 0 : i64} : (tensor) -> tensor + %3 = stablehlo.constant dense<3> : tensor + %ok = stablehlo.compare EQ, %b, %3, SIGNED : (tensor, tensor) -> tensor + stablehlo.custom_call @shape_assertion(%ok) { + error_message = "The error message", + has_side_effect = true + } : (tensor) -> () + return %b : tensor + } + +} +""") + return xla.call_module([x,], version=version, + module=module, + Tout=[res.dtype], + Sout=[res.shape], + platforms=[self.testing_platform()],) + + self._assertOpOutputMatchesExpected(f, (x,), (res,)) + + def test_shape_assertion_failure(self): + x = np.ones((3, 5), dtype=np.int32) + res = np.int32(x.shape[0]) + + def f(x): # x: f32[b, 5] and b = 3, with a constraint b == 4. + # return x.shape[0] + module, version = serialize(""" +module @jit_fun.1 { + func.func public @main(%arg1: tensor) -> tensor { + %b = "stablehlo.get_dimension_size"(%arg1) {dimension = 0 : i64} : (tensor) -> tensor + %4 = stablehlo.constant dense<4> : tensor + %ok = stablehlo.compare EQ, %b, %4, SIGNED : (tensor, tensor) -> tensor + stablehlo.custom_call @shape_assertion(%ok, %b, %4) { + error_message = "Expecting {0} == {1}", + has_side_effect = true + } : (tensor, tensor, tensor) -> () + return %b : tensor + } +} +""") + return xla.call_module([x,], version=version, + module=module, + Tout=[res.dtype], + Sout=[res.shape], + platforms=[self.testing_platform()],) + + # This test runs as part of two targets, with and without + # disabling shape_assertions. + disabled_shape_assertions_check = ( + '--tf_xla_call_module_disabled_checks=shape_assertions' + in os.getenv('TF_XLA_FLAGS', '')) + if disabled_shape_assertions_check: + # No error even though the constraint is false. + self._assertOpOutputMatchesExpected(f, (x,), (res,)) + else: + with self.assertRaisesRegex( + errors.InvalidArgumentError, + 'Expecting 3 == 4'): + self._assertOpOutputMatchesExpected(f, (x,), (res,)) + + def test_invalid_shape_assertion(self): + arg_i1 = np.bool_(True) + arg_i32 = np.int32(2) + res = arg_i32 + + # This test runs as part of two targets, with and without + # disabling shape_assertions. + disabled_shape_assertions_check = ( + '--tf_xla_call_module_disabled_checks=shape_assertions' + in os.getenv('TF_XLA_FLAGS', '')) + if disabled_shape_assertions_check: + self.skipTest('Test is N/A when shape_assertions are disabled') + + subtest_count = 1 + def one_subtest(error_msg: str, module_str: str): + def f(*args): + module, version = serialize(module_str) + return xla.call_module( + list(args), + version=version, + module=module, + Tout=[res.dtype], + Sout=[res.shape], + platforms=[self.testing_platform()], + ) + + nonlocal subtest_count + subtest_count += 1 + with self.subTest(count=subtest_count, error_msg=error_msg): + with self.assertRaisesRegex(errors.InvalidArgumentError, error_msg): + self._assertOpOutputMatchesExpected(f, (arg_i1, arg_i32), (res,)) + + one_subtest( + 'expects assert_what .* to be a constant of type tensor', + """ +module @jit_fun.1 { + func.func public @main(%arg_i1: tensor, %arg_i32: tensor) -> tensor { + %ok = stablehlo.constant dense<0> : tensor + stablehlo.custom_call @shape_assertion(%ok) { + error_message = "Some error", + has_side_effect = true + } : (tensor) -> () + return %arg_i32 : tensor + } +} +""", + ) + + one_subtest( + 'expects static assert_what', + """ +module @jit_fun.1 { + func.func public @main(%arg_i1: tensor, %arg_i32: tensor) -> tensor { + stablehlo.custom_call @shape_assertion(%arg_i1) { + error_message = "Some error", + has_side_effect = true + } : (tensor) -> () + return %arg_i32 : tensor + } +} +""", + ) + + one_subtest( + 'expects has_side_effect=true', + """ +module @jit_fun.1 { + func.func public @main(%arg_i1: tensor, %arg_i32: tensor) -> tensor { + %ok = stablehlo.constant dense : tensor + stablehlo.custom_call @shape_assertion(%ok) { + error_message = "Some error", + has_side_effect = false + } : (tensor) -> () + return %arg_i32 : tensor + } +} +""", + ) + + one_subtest( + 'expects error_message .* Found specifier {0}', + """ +module @jit_fun.1 { + func.func public @main(%arg_i1: tensor, %arg_i32: tensor) -> tensor { + %ok = stablehlo.constant dense : tensor + stablehlo.custom_call @shape_assertion(%ok) { + error_message = "Some error {0}", + has_side_effect = true + } : (tensor) -> () + return %arg_i32 : tensor + } +} +""", + ) + + one_subtest( + 'expects static error_message_input', + """ +module @jit_fun.1 { + func.func public @main(%arg_i1: tensor, %arg_i32: tensor) -> tensor { + %ok = stablehlo.constant dense : tensor + stablehlo.custom_call @shape_assertion(%ok, %arg_i32) { + error_message = "Some error {0}", + has_side_effect = true + } : (tensor, tensor) -> () + return %arg_i32 : tensor + } +} +""", + ) + + one_subtest( + 'expects error_message_input .* to be a constant of type tensor', + """ +module @jit_fun.1 { + func.func public @main(%arg_i1: tensor, %arg_i32: tensor) -> tensor { + %ok = stablehlo.constant dense : tensor + %c = stablehlo.constant dense<2.0> : tensor + stablehlo.custom_call @shape_assertion(%ok, %c) { + error_message = "Some error {0}", + has_side_effect = true + } : (tensor, tensor) -> () + return %arg_i32 : tensor + } +} +""", + ) + def test_dynamic_iota(self): x = np.ones((3, 5), dtype=np.int32) res = np.arange(x.shape[0], dtype=np.int32) diff --git a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc index 0b84d5434776bd..98808caa60d551 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Block.h" // from @llvm-project @@ -85,16 +86,29 @@ constexpr int VERSION_START_SUPPORT_CALL_TF_GRAPH = 5; // mandates a non-empty `platforms` attribute. // Used in jax2tf since June 2023. constexpr int VERSION_START_SUPPORT_DISABLED_CHECKS = 6; +// Version 7 adds support for `stablehlo.shape_assertion` operations and +// for `shape_assertions` specified in `disabled_checks`. +// Used in JAX serialization since July 2023. +constexpr int VERSION_START_SUPPORT_SHAPE_ASSERTIONS = 7; constexpr int VERSION_MINIMUM_SUPPORTED = VERSION_START_STABLE_HLO_COMPATIBILITY; -constexpr int VERSION_MAXIMUM_SUPPORTED = VERSION_START_SUPPORT_DISABLED_CHECKS; +constexpr int VERSION_MAXIMUM_SUPPORTED = + VERSION_START_SUPPORT_SHAPE_ASSERTIONS; constexpr absl::string_view DISABLED_CHECK_PLATFORM = "platform"; bool IsPlatformCheckDisabled(absl::Span disabled_checks) { - return std::find(disabled_checks.begin(), disabled_checks.end(), - DISABLED_CHECK_PLATFORM) != disabled_checks.end(); + return llvm::is_contained(disabled_checks, DISABLED_CHECK_PLATFORM); +} + +constexpr absl::string_view DISABLED_CHECK_SHAPE_ASSERTIONS = + "shape_assertions"; + +bool IsShapeAssertionsCheckDisabled( + absl::Span loading_disabled_checks) { + return llvm::is_contained(loading_disabled_checks, + DISABLED_CHECK_SHAPE_ASSERTIONS); } // Computes a dimension value from the dim_arg specification. @@ -396,8 +410,11 @@ tsl::Status XlaCallModuleLoader::RefineDynamicShapes( if (VLOG_IS_ON(5)) { DumpMlirOpToFile("xla_call_module.after_refined_input_types", *module_); } - - TF_RETURN_IF_ERROR(xla::RefinePolymorphicShapes(*module_)); + bool enable_shape_assertions = + (version_ >= VERSION_START_SUPPORT_SHAPE_ASSERTIONS && + !IsShapeAssertionsCheckDisabled(loading_disabled_checks_)); + TF_RETURN_IF_ERROR( + xla::RefinePolymorphicShapes(*module_, enable_shape_assertions)); if (VLOG_IS_ON(3)) { DumpMlirOpToFile("xla_call_module.after_shape_refinement", *module_); @@ -436,9 +453,9 @@ tsl::Status XlaCallModuleLoader::LoadAndPreprocessModule( module_ = mlir::parseSourceString(module_str, context_); } - std::vector loading_disabled_checks = disabled_checks; - loading_disabled_checks.insert( - loading_disabled_checks.end(), + loading_disabled_checks_ = disabled_checks; + loading_disabled_checks_.insert( + loading_disabled_checks_.end(), GetXlaCallModuleFlags()->disabled_checks.begin(), GetXlaCallModuleFlags()->disabled_checks.end()); if (!module_) { @@ -451,7 +468,7 @@ tsl::Status XlaCallModuleLoader::LoadAndPreprocessModule( << ", dim_args_spec = [" << absl::StrJoin(dim_args_spec_, ", ") << "], disabled_checks = [" << absl::StrJoin(disabled_checks, ", ") << "], loading_disabled_checks = [" - << absl::StrJoin(loading_disabled_checks, ", ") << "]), module = " + << absl::StrJoin(loading_disabled_checks_, ", ") << "]), module = " << DumpMlirOpToFile("xla_call_module.parsed", *module_); if (version < VERSION_MINIMUM_SUPPORTED) { @@ -471,7 +488,7 @@ tsl::Status XlaCallModuleLoader::LoadAndPreprocessModule( auto found_platform = std::find(platforms.begin(), platforms.end(), loading_platform); if (found_platform == platforms.end()) { - if (!IsPlatformCheckDisabled(loading_disabled_checks)) { + if (!IsPlatformCheckDisabled(loading_disabled_checks_)) { return absl::NotFoundError(absl::StrCat( "The current platform ", loading_platform, " is not among the platforms required by the module: [", diff --git a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.h b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.h index 54aaa6ae58f097..2eab7d25faa2ee 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.h +++ b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.h @@ -58,11 +58,6 @@ class XlaCallModuleLoader { // Validates that the module only contains ops from valid dialects. tsl::Status ValidateDialect(); - // Validates that the module represents a statically-shaped StableHLO program, - // otherwise all sorts of weirdness might happen in the HLO exporter which is - // much easier to detect here. - tsl::Status ValidateStaticShapes(); - // Lowers the StableHLO module to MHLO in place. absl::Status LowerModuleToMhlo(); @@ -97,6 +92,9 @@ class XlaCallModuleLoader { // a platform index arg. int platform_index_; std::vector dim_args_spec_; + // The disabled checks at loading time, including those from the + // disabled_checks attribute and the TF_XLA_FLAGS environment variable. + std::vector loading_disabled_checks_; mlir::func::FuncOp main_; }; diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py index d3638443234033..4cc4845b60b7e7 100644 --- a/tensorflow/compiler/tf2xla/python/xla.py +++ b/tensorflow/compiler/tf2xla/python/xla.py @@ -663,12 +663,16 @@ def call_module( return res -# pylint: enable=g-doc-args -# pylint: enable=g-doc-return-or-yield +def call_module_maximum_supported_version(): + """Maximum version of XlaCallModule op supported. + See versioning details documentation for the XlaCallModule op at: + https://github.com/search?q=repo%3Atensorflow%2Ftensorflow+path%3Axla_call_module+%22int+VERSION_MAXIMUM_SUPPORTED%22&type=code + """ + return 7 -def call_module_maximum_supported_version(): - return 6 +# pylint: enable=g-doc-args +# pylint: enable=g-doc-return-or-yield def call_module_disable_check_platform(): diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index cd68ae2423b4f5..f7614d7ae9adc4 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -748,7 +748,6 @@ cc_library( deps = [ "//tensorflow/compiler/xla/mlir/utils:error_util", "//tensorflow/tsl/platform:errors", - "//tensorflow/tsl/platform:logging", "@com_google_absl//absl/status", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", diff --git a/tensorflow/compiler/xla/python/mlir.cc b/tensorflow/compiler/xla/python/mlir.cc index f121daaaec0a29..4c6526aa51cb60 100644 --- a/tensorflow/compiler/xla/python/mlir.cc +++ b/tensorflow/compiler/xla/python/mlir.cc @@ -233,13 +233,14 @@ void BuildMlirSubmodule(py::module& m) { py::arg("mlir_module")); mlir_module.def( "refine_polymorphic_shapes", - [](std::string mlir_module) -> py::bytes { + [](std::string mlir_module, bool enable_shape_assertions) -> py::bytes { std::string buffer; llvm::raw_string_ostream os(buffer); - xla::ThrowIfError(RefinePolymorphicShapes(mlir_module, os)); + xla::ThrowIfError( + RefinePolymorphicShapes(mlir_module, os, enable_shape_assertions)); return py::bytes(buffer); }, - py::arg("mlir_module"), + py::arg("mlir_module"), py::arg("enable_shape_assertions") = true, R"(Refines the dynamic shapes for a module. The "main" function must have static shapes and all the intermediate dynamic shapes depend only on the input static diff --git a/tensorflow/compiler/xla/python/refine_polymorphic_shapes.cc b/tensorflow/compiler/xla/python/refine_polymorphic_shapes.cc index 91b16e3b6b0e4a..f3179e7b8d1f5f 100644 --- a/tensorflow/compiler/xla/python/refine_polymorphic_shapes.cc +++ b/tensorflow/compiler/xla/python/refine_polymorphic_shapes.cc @@ -12,14 +12,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ - #include "tensorflow/compiler/xla/python/refine_polymorphic_shapes.h" +#include +#include + #include "absl/status/status.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/Regex.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Bytecode/BytecodeWriter.h" // from @llvm-project #include "mlir/Dialect/Func/Extensions/AllExtensions.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/ValueRange.h" // from @llvm-project #include "mlir/IR/Verifier.h" // from @llvm-project #include "mlir/Parser/Parser.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project @@ -33,7 +39,203 @@ limitations under the License. namespace xla { -absl::Status RefinePolymorphicShapes(mlir::ModuleOp module) { +namespace { + +constexpr absl::string_view shapeAssertionName = "shape_assertion"; +constexpr absl::string_view errorMessageAttrName = "error_message"; +// We bound the number of error_message_inputs for using llvm::formatv +constexpr int maxErrorMessageInputs = 4; + +// This pass is needed when we have shape assertions. A shape assertion is +// represented via the `stablehlo.custom_call @shape_assertion` +// custom call, and represents an assertion that the first operand +// (`assert_what`) evaluates to `true`. The custom call also has an +// `error_message` string attribute, and a variadic number +// of integer scalar operands that may be used to format the error message. +// The `error_message` may contain format specifiers `{0}`, `{1}`, ..., that +// are replaced with the values of the error message inputs. The formatting is +// done with the `llvm::formatv` function +// (https://llvm.org/docs/ProgrammersManual.html#formatting-strings-the-formatv-function). +// +struct CheckShapeAssertionsPass + : public mlir::PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CheckShapeAssertionsPass) + + explicit CheckShapeAssertionsPass(bool enable_shape_assertions = true) + : PassWrapper() { + this->enable_shape_assertions = enable_shape_assertions; + } + + CheckShapeAssertionsPass(const CheckShapeAssertionsPass &pass) { + enable_shape_assertions = pass.enable_shape_assertions; + } + + private: + void runOnOperation() final { + mlir::func::FuncOp func_op = getOperation(); + func_op.walk([this](mlir::stablehlo::CustomCallOp op) { + if (!op.getCallTargetName().equals(shapeAssertionName)) return; + if (!enable_shape_assertions) { + op.erase(); + return; + } + // Check first for ill-formed assertions, rather than silently fail. + if (mlir::failed(verifyShapeAssertion(op))) { + signalPassFailure(); + return; + } + mlir::OperandRange inputs = op.getInputs(); + mlir::SmallVector assertWhat; + if (mlir::failed(mlir::hlo::matchInts(inputs[0], assertWhat))) { + op.emitError() << "expects static assert_what (operand #0)"; + signalPassFailure(); + return; + } + if (assertWhat[0] != 0) { + op.erase(); + return; + } + llvm::StringRef errorMessage = getErrorMessage(op); + mlir::SmallVector errorMessageInputs; + for (int i = 1; i < inputs.size(); ++i) { + mlir::SmallVector input; + if (failed(mlir::hlo::matchInts(inputs[i], input))) { + op.emitError() << "expects static error_message_input (operand #" << i + << ")"; + signalPassFailure(); + return; + } + errorMessageInputs.push_back(input[0]); + } + op.emitError(formatErrorMessage(errorMessage, errorMessageInputs)); + signalPassFailure(); + }); + } + + mlir::LogicalResult verifyShapeAssertion(mlir::stablehlo::CustomCallOp op) { + if (!(1 <= op->getNumOperands() && + op->getNumOperands() <= 1 + maxErrorMessageInputs)) + return op.emitError() << "expects 1 <= size(operands) <= " + << (1 + maxErrorMessageInputs); + int nrErrorMessageInputs = op.getNumOperands() - 1; + if (op->getNumResults() != 0) + return op.emitError("expects size(results) = 0"); + for (const auto &attr : op->getAttrs()) { + if (attr.getName() != "api_version" && + attr.getName() != "backend_config" && + attr.getName() != "call_target_name" && + attr.getName() != "error_message" && + attr.getName() != "has_side_effect") + return op.emitError() + << attr.getName() << " is not a supported attribute"; + } + if (!op.getBackendConfig().empty()) + return op.emitError() << "expects an empty backend_config"; + if (!op.getCallTargetName().equals(shapeAssertionName)) + return op.emitError() << "expects @shape_assertion"; + if (!op.getHasSideEffect()) + return op.emitError() << "expects has_side_effect=true"; + + // input[0] (assert_what) : tensor + auto assertWhatType = + op.getInputs()[0].getType().dyn_cast(); + if (!assertWhatType || !assertWhatType.hasRank() || + assertWhatType.getRank() != 0 || + !assertWhatType.getElementType().isSignlessInteger() || + assertWhatType.getElementTypeBitWidth() != 1) + return op.emitError() << "expects assert_what (operand #0) " + << "to be a constant of type tensor"; + + // input[1:] (error_message_inputs) : tensor or tensor + for (int i = 0; i < nrErrorMessageInputs; ++i) { + auto errorMessageInputType = + op.getInputs()[i + 1].getType().dyn_cast(); + if (!errorMessageInputType || !errorMessageInputType.hasRank() || + errorMessageInputType.getRank() != 0 || + !errorMessageInputType.getElementType().isSignlessInteger() || + (errorMessageInputType.getElementTypeBitWidth() != 32 && + errorMessageInputType.getElementTypeBitWidth() != 64)) + return op.emitError() + << "expects error_message_input (operand #" << (i + 1) << ") " + << "to be a constant of type tensor or tensor"; + } + + if (!op->hasAttr(errorMessageAttrName)) + return op.emitError() << "expects an error_message attribute"; + + // error_message contains valid format specifiers. + std::string errorMessage = getErrorMessage(op).data(); + // format specs: "{" index ["," layout] [":" format] "}" + llvm::Regex formatSpecifierRE = llvm::Regex("{([0-9]+)[,:}]"); + do { + mlir::SmallVector formatSpec; + if (!formatSpecifierRE.match(errorMessage, &formatSpec)) { + break; + } + int index = std::stoi(formatSpec[1].data()); + if (!(0 <= index && index < nrErrorMessageInputs)) { + return op.emitError() + << "expects error_message to contain format specifiers with " + << "error_message_input index less than " << nrErrorMessageInputs + << ". Found specifier " << formatSpec[0]; + } + errorMessage = formatSpecifierRE.sub("", errorMessage); + } while (true); + + return mlir::success(); + } + + llvm::StringRef getErrorMessage(mlir::stablehlo::CustomCallOp op) const { + return op->getAttr(errorMessageAttrName) + .cast() + .getValue(); + } + + std::string formatErrorMessage( + llvm::StringRef errorMessage, + const mlir::SmallVector &errorMessageInputs) const { + int nrErrorMessageInputs = errorMessageInputs.size(); + auto errorMessageFormat = errorMessage.data(); + switch (nrErrorMessageInputs) { + case 0: + return errorMessageFormat; + case 1: + return llvm::formatv(errorMessageFormat, errorMessageInputs[0]); + case 2: + return llvm::formatv(errorMessageFormat, errorMessageInputs[0], + errorMessageInputs[1]); + case 3: + return llvm::formatv(errorMessageFormat, errorMessageInputs[0], + errorMessageInputs[1], errorMessageInputs[2]); + case 4: + return llvm::formatv(errorMessageFormat, errorMessageInputs[0], + errorMessageInputs[1], errorMessageInputs[2], + errorMessageInputs[3]); + default: + return errorMessageFormat; + } + } + + mlir::StringRef getArgument() const override { + return "check-shape-assertions"; + } + + mlir::StringRef getDescription() const override { + return "Check stablehlo.custom_call @shape_assertion ops."; + } + + Option enable_shape_assertions{ + *this, "enable-shape-assertions", + llvm::cl::desc("Whether shape assertions may generate errors."), + llvm::cl::init(true)}; +}; + +} // namespace + +absl::Status RefinePolymorphicShapes(mlir::ModuleOp module, + bool enable_shape_assertions) { mlir::MLIRContext *context = module->getContext(); if (VLOG_IS_ON(3)) context->disableMultithreading(); @@ -62,6 +264,8 @@ absl::Status RefinePolymorphicShapes(mlir::ModuleOp module) { pm.addPass(mlir::stablehlo::createStablehloRefineShapesPass()); pm.addNestedPass( mlir::stablehlo::createStablehloCanonicalizeDynamismPass()); + pm.addNestedPass( + std::make_unique(enable_shape_assertions)); if (!mlir::succeeded(pm.run(module))) { return absl::InvalidArgumentError( absl::StrCat("Module shape refinement failed: ", @@ -71,7 +275,8 @@ absl::Status RefinePolymorphicShapes(mlir::ModuleOp module) { } absl::Status RefinePolymorphicShapes(llvm::StringRef module_str, - llvm::raw_ostream &os) { + llvm::raw_ostream &os, + bool enable_shape_assertions) { mlir::MLIRContext context; if (VLOG_IS_ON(3)) context.disableMultithreading(); context.loadDialect(); @@ -88,16 +293,18 @@ absl::Status RefinePolymorphicShapes(llvm::StringRef module_str, if (!module) { return absl::InvalidArgumentError("Cannot parse module."); } - TF_RETURN_IF_ERROR(RefinePolymorphicShapes(*module)); + TF_RETURN_IF_ERROR(RefinePolymorphicShapes(*module, enable_shape_assertions)); if (mlir::failed(mlir::writeBytecodeToFile(*module, os))) { return absl::InternalError("Cannot serialize module."); } + return absl::OkStatus(); } absl::Status ValidateStaticShapes(mlir::ModuleOp module) { mlir::BaseScopedDiagnosticHandler diag_handler(module->getContext()); bool moduleHasDynamicShapes = false; + bool moduleHasShapeAssertions = false; module->walk([&](mlir::Operation *op) { // It's sufficient to only check results because operands either come from @@ -116,6 +323,13 @@ absl::Status ValidateStaticShapes(mlir::ModuleOp module) { moduleHasDynamicShapes = true; op->emitOpError() << "has dynamic shapes"; } + + auto customCall = mlir::dyn_cast(op); + if (customCall && + customCall.getCallTargetName().equals(shapeAssertionName)) { + moduleHasShapeAssertions = true; + op->emitOpError() << "has residual shape assertions"; + } }); if (moduleHasDynamicShapes) { @@ -123,6 +337,11 @@ absl::Status ValidateStaticShapes(mlir::ModuleOp module) { absl::StrCat("Module has dynamic shapes: ", diag_handler.ConsumeStatus().ToString())); } + if (moduleHasShapeAssertions) { + return absl::InvalidArgumentError( + absl::StrCat("Module has residual shape assertions: ", + diag_handler.ConsumeStatus().ToString())); + } return absl::OkStatus(); } diff --git a/tensorflow/compiler/xla/python/refine_polymorphic_shapes.h b/tensorflow/compiler/xla/python/refine_polymorphic_shapes.h index 4f553d5f64c73d..726ce3aec07249 100644 --- a/tensorflow/compiler/xla/python/refine_polymorphic_shapes.h +++ b/tensorflow/compiler/xla/python/refine_polymorphic_shapes.h @@ -24,12 +24,18 @@ namespace xla { // Refines the dynamic shapes for a module whose "main" has static shapes // and all the intermediate dynamic shapes depend only on the input static -// shapes. -absl::Status RefinePolymorphicShapes(mlir::ModuleOp module); +// shapes. Upon refinement, validates that the module does not contain remaining +// dynamic shapes. +// If `enable_shape_assertions` is false, then the shape assertions +// are removed from the module, otherwise they are removed only if the +// assertions hold, and result in an error otherwise. +absl::Status RefinePolymorphicShapes(mlir::ModuleOp module, + bool enable_shape_assertions); // Like the above but with serialized input and output modules. absl::Status RefinePolymorphicShapes(llvm::StringRef module_str, - llvm::raw_ostream &os); + llvm::raw_ostream &os, + bool enable_shape_assertions); // Validates that the module has only static shapes. absl::Status ValidateStaticShapes(mlir::ModuleOp module); diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index 0d6f2f6c0c6c47..61a3cd8e0e7fd4 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -47,7 +47,7 @@ _version = 166 # Version number for MLIR:Python components. -mlir_api_version = 51 +mlir_api_version = 52 xla_platform_names = { 'cpu': 'Host', diff --git a/tensorflow/compiler/xla/python/xla_extension/mlir.pyi b/tensorflow/compiler/xla/python/xla_extension/mlir.pyi index 2e47b1746833e9..e3f073b1defbc9 100644 --- a/tensorflow/compiler/xla/python/xla_extension/mlir.pyi +++ b/tensorflow/compiler/xla/python/xla_extension/mlir.pyi @@ -24,4 +24,5 @@ def mhlo_to_stablehlo(mlir_module: Union[bytes, str]) -> str: ... def stablehlo_to_mhlo(mlir_module: Union[bytes, str]) -> str: ... def serialize_portable_artifact(mlir_module: str, target:str) -> bytes: ... def deserialize_portable_artifact(mlir_module: bytes) -> str: ... -def refine_polymorphic_shapes(mlir_module: Union[bytes, str]) -> bytes: ... +def refine_polymorphic_shapes(mlir_module: Union[bytes, str], + enable_shape_assertions: bool = ...) -> bytes: ... From ffe0b91a3a74077dbab294acf666a341fe7b40fa Mon Sep 17 00:00:00 2001 From: Christian Sigg Date: Wed, 12 Jul 2023 07:02:26 -0700 Subject: [PATCH 188/376] Import openai/triton from GitHub. PiperOrigin-RevId: 547484298 --- .../xla/service/gpu/ir_emitter_triton.cc | 1 + third_party/triton/cl545644269.patch | 180 ------------------ third_party/triton/workspace.bzl | 5 +- 3 files changed, 3 insertions(+), 183 deletions(-) delete mode 100644 third_party/triton/cl545644269.patch diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_triton.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_triton.cc index 709f3e40b52c3f..d1b570f06054b2 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_triton.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_triton.cc @@ -615,6 +615,7 @@ void CreateTritonPipeline(mlir::OpPassManager& pm, pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(mlir::createCSEPass()); pm.addPass(mlir::createSymbolDCEPass()); + // Note: translateTritonGPUToLLVMIR adds line info with LLVMDIScopePass. } // Extract additional attributes from an LLVM function that are not passed diff --git a/third_party/triton/cl545644269.patch b/third_party/triton/cl545644269.patch deleted file mode 100644 index 9f453888c70c19..00000000000000 --- a/third_party/triton/cl545644269.patch +++ /dev/null @@ -1,180 +0,0 @@ -diff --git a/BUILD b/BUILD -index a5a813485..c7f8aa5a6 100644 ---- a/BUILD -+++ b/BUILD -@@ -275,8 +275,7 @@ cc_library( - copts = _no_unused_variable, - includes = ["include"], - deps = [ -- ":TritonDialect", -- ":TritonGPUDialect", -+ ":TritonDialects", - ":TritonTools", - ":triton_gpu_attr_inc_gen", - "@llvm-project//llvm:Support", -@@ -291,44 +290,53 @@ cc_library( - ) - - cc_library( -- name = "TritonDialect", -- srcs = glob(["lib/Dialect/Triton/IR/*.cpp"]), -- hdrs = glob(["include/triton/Dialect/Triton/IR/*.h"]), -+ name = "TritonDialects", -+ srcs = glob([ -+ "lib/Dialect/Triton/IR/*.cpp", -+ "lib/Dialect/TritonGPU/IR/*.cpp", -+ ]) + [ -+ "include/triton/Analysis/Utility.h", # Avoid circular dependency. -+ ], -+ hdrs = glob([ -+ "include/triton/Dialect/Triton/IR/*.h", -+ "include/triton/Dialect/TritonGPU/IR/*.h", -+ ]), - copts = _no_unused_variable, - includes = ["include"], - deps = [ -- ":TritonGPUAttributes", - ":triton_dialect_inc_gen", -+ ":triton_gpu_attr_inc_gen", -+ ":triton_gpu_dialect_inc_gen", -+ ":triton_gpu_ops_inc_gen", -+ ":triton_gpu_transforms_inc_gen", - ":triton_interfaces_inc_gen", - ":triton_ops_inc_gen", - "@llvm-project//llvm:Support", -+ "@llvm-project//mlir:Analysis", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:ControlFlowDialect", - "@llvm-project//mlir:ControlFlowInterfaces", -+ "@llvm-project//mlir:DestinationStyleOpInterface", - "@llvm-project//mlir:FuncDialect", -+ "@llvm-project//mlir:GPUDialect", - "@llvm-project//mlir:IR", -+ "@llvm-project//mlir:LLVMDialect", - "@llvm-project//mlir:MathDialect", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:SCFDialect", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:TensorDialect", -+ "@llvm-project//mlir:Transforms", - ], - ) - --cc_library( -- name = "TritonGPUAttributes", -- hdrs = ["include/triton/Dialect/TritonGPU/IR/Attributes.h"], -- includes = ["include"], -- deps = ["triton_gpu_attr_inc_gen"], --) -- - cc_library( - name = "TritonTransforms", - srcs = glob(["lib/Dialect/Triton/Transforms/*.cpp"]), - hdrs = glob(["include/triton/Dialect/Triton/Transforms/*.h"]), - includes = ["include"], - deps = [ -- ":TritonDialect", -+ ":TritonDialects", - ":triton_combine_inc_gen", - ":triton_transforms_inc_gen", - "@llvm-project//llvm:Support", -@@ -347,36 +355,6 @@ cc_library( - alwayslink = True, # TritonDialect uses getCanonicalizationPatterns(). - ) - --cc_library( -- name = "TritonGPUDialect", -- srcs = glob(["lib/Dialect/TritonGPU/IR/*.cpp"]), -- hdrs = [ -- "include/triton/Analysis/Utility.h", # Avoid circular dependency. -- "include/triton/Dialect/TritonGPU/IR/Dialect.h", -- "include/triton/Dialect/TritonGPU/IR/Traits.h", -- ], -- copts = _no_unused_variable, -- includes = ["include"], -- deps = [ -- ":TritonDialect", -- ":TritonGPUAttributes", -- ":triton_gpu_attr_inc_gen", -- ":triton_gpu_dialect_inc_gen", -- ":triton_gpu_ops_inc_gen", -- ":triton_gpu_transforms_inc_gen", -- "@llvm-project//llvm:Support", -- "@llvm-project//mlir:Analysis", -- "@llvm-project//mlir:DestinationStyleOpInterface", -- "@llvm-project//mlir:GPUDialect", -- "@llvm-project//mlir:IR", -- "@llvm-project//mlir:LLVMDialect", -- "@llvm-project//mlir:Pass", -- "@llvm-project//mlir:Support", -- "@llvm-project//mlir:TensorDialect", -- "@llvm-project//mlir:Transforms", -- ], --) -- - cc_library( - name = "TritonGPUTransforms", - srcs = glob([ -@@ -388,9 +366,7 @@ cc_library( - includes = ["include"], - deps = [ - ":TritonAnalysis", -- ":TritonDialect", -- ":TritonGPUAttributes", -- ":TritonGPUDialect", -+ ":TritonDialects", - ":triton_gpu_transforms_inc_gen", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:Analysis", -@@ -428,8 +404,7 @@ cc_library( - ], - deps = [ - ":TritonAnalysis", -- ":TritonDialect", -- ":TritonGPUDialect", -+ ":TritonDialects", - ":triton_conversion_triton_gpu_to_llvm_passes_inc_gen", - ":triton_conversion_triton_to_triton_gpu_passes_inc_gen", - "@llvm-project//llvm:Support", -@@ -466,8 +441,8 @@ cc_library( - hdrs = glob(["include/triton/Conversion/TritonToTritonGPU/*.h"]), - includes = ["include"], - deps = [ -- ":TritonDialect", -- ":TritonGPUDialect", -+ ":TritonAnalysis", -+ ":TritonDialects", - ":TritonGPUTransforms", - ":triton_conversion_triton_gpu_to_llvm_passes_inc_gen", - ":triton_conversion_triton_to_triton_gpu_passes_inc_gen", -@@ -513,9 +488,7 @@ cc_library( - "@llvm-project//mlir:ROCDLToLLVMIRTranslation", - "@llvm-project//mlir:ToLLVMIRTranslation", - "@llvm-project//mlir:Transforms", -- # copybara:uncomment_begin -- # "//third_party/py/triton/google:find_cuda", -- # copybara:uncomment_end -+ # copybara:uncomment "//third_party/py/triton/google:find_cuda", - ], - ) - -@@ -579,8 +552,7 @@ cc_binary( - ], - includes = ["include"], - deps = [ -- ":TritonDialect", -- ":TritonGPUDialect", -+ ":TritonDialects", - ":TritonGPUToLLVM", - ":TritonGPUTransforms", - ":TritonToTritonGPU", -@@ -618,8 +590,7 @@ cc_binary( - ], - includes = ["include"], - deps = [ -- ":TritonDialect", -- ":TritonGPUDialect", -+ ":TritonDialects", - ":TritonGPUToLLVM", - ":TritonGPUTransforms", - ":TritonHSACO", diff --git a/third_party/triton/workspace.bzl b/third_party/triton/workspace.bzl index 1ccf24e7d75fc7..219c10b85e2e7f 100644 --- a/third_party/triton/workspace.bzl +++ b/third_party/triton/workspace.bzl @@ -5,8 +5,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): """Imports Triton.""" - TRITON_COMMIT = "cl545371535" - TRITON_SHA256 = "97e9af5aa986744b9d3807e8a473b2b2056c8bedc74842b607d40cf780e8ac5a" + TRITON_COMMIT = "cl546794996" + TRITON_SHA256 = "57d4b5f1e68bb4df93528bd5394ba3338bef7bf9c0afdc96b44371fba650c037" tf_http_archive( name = "triton", @@ -16,6 +16,5 @@ def repo(): # For temporary changes which haven't landed upstream yet. patch_file = [ "//third_party/triton:cl536931041.patch", - "//third_party/triton:cl545644269.patch", ], ) From bca2121bd0c740acb40d7d2d44cb435522740721 Mon Sep 17 00:00:00 2001 From: Michael Hudgins Date: Wed, 12 Jul 2023 10:48:29 -0400 Subject: [PATCH 189/376] Make upload condition more specific Only upload if the tag specifically starts with v2 --- .github/workflows/arm-cd.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/arm-cd.yml b/.github/workflows/arm-cd.yml index ad6bcb9edc0e28..f99e948571dbfe 100644 --- a/.github/workflows/arm-cd.yml +++ b/.github/workflows/arm-cd.yml @@ -68,6 +68,6 @@ jobs: CI_DOCKER_BUILD_EXTRA_PARAMS="--build-arg py_major_minor_version=${{ matrix.pyver }} --build-arg is_nightly=${is_nightly} --build-arg tf_project_name=${tf_project_name}" \ ./tensorflow/tools/ci_build/ci_build.sh cpu.arm64 bash tensorflow/tools/ci_build/rel/ubuntu/cpu_arm64_pip.sh - name: Upload pip wheel to PyPI - if: github.event_name == 'schedule' || (github.event_name == 'push' && contains(github.ref, 'refs/tags/')) # only if it is a scheduled nightly or tagged + if: github.event_name == 'schedule' || (github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v2')) # only if it is a scheduled nightly or tagged shell: bash run: python3 -m twine upload --verbose /home/ubuntu/actions-runner/_work/tensorflow/tensorflow/whl/* -u "__token__" -p ${{ secrets.AWS_PYPI_ACCOUNT_TOKEN }} From 395d9429994d6ed522a87dbfbbcec6e2563693cf Mon Sep 17 00:00:00 2001 From: George Karpenkov Date: Wed, 12 Jul 2023 08:13:24 -0700 Subject: [PATCH 190/376] [XLA:GPU] Minor refactoring of autotuning utils PiperOrigin-RevId: 547499313 --- tensorflow/compiler/xla/service/gpu/BUILD | 16 ++++++------ .../xla/service/gpu/amdgpu_compiler.cc | 2 +- .../xla/service/gpu/autotuner_compile_util.cc | 25 ++++++++++++------- .../xla/service/gpu/autotuner_compile_util.h | 14 +++++++---- .../compiler/xla/service/gpu/autotuner_util.h | 8 +++++- ...thm_picker.cc => conv_algorithm_picker.cc} | 15 +++-------- ...rithm_picker.h => conv_algorithm_picker.h} | 6 ++--- ..._test.cc => conv_algorithm_picker_test.cc} | 2 +- .../xla/service/gpu/gemm_algorithm_picker.cc | 7 +----- .../xla/service/gpu/nvptx_compiler.cc | 2 +- .../compiler/xla/service/gpu/runtime/BUILD | 2 +- .../compiler/xla/service/gpu/runtime/conv.cc | 2 +- .../xla/service/gpu/triton_autotuner.cc | 18 +++---------- 13 files changed, 55 insertions(+), 64 deletions(-) rename tensorflow/compiler/xla/service/gpu/{gpu_conv_algorithm_picker.cc => conv_algorithm_picker.cc} (98%) rename tensorflow/compiler/xla/service/gpu/{gpu_conv_algorithm_picker.h => conv_algorithm_picker.h} (96%) rename tensorflow/compiler/xla/service/gpu/{gpu_conv_algorithm_picker_test.cc => conv_algorithm_picker_test.cc} (98%) diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index ca2ee0cb3c66b7..4e654f9bf38446 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -1499,9 +1499,9 @@ xla_cc_test( ) cc_library( - name = "gpu_conv_algorithm_picker", - srcs = if_gpu_is_configured(["gpu_conv_algorithm_picker.cc"]), - hdrs = if_gpu_is_configured(["gpu_conv_algorithm_picker.h"]), + name = "conv_algorithm_picker", + srcs = if_gpu_is_configured(["conv_algorithm_picker.cc"]), + hdrs = if_gpu_is_configured(["conv_algorithm_picker.h"]), copts = if_cuda_is_configured(["-DGOOGLE_CUDA=1"]) + if_rocm_is_configured([ "-DTENSORFLOW_USE_ROCM=1", ]), @@ -1538,8 +1538,8 @@ cc_library( ) xla_cc_test( - name = "gpu_conv_algorithm_picker_test", - srcs = if_gpu_is_configured(["gpu_conv_algorithm_picker_test.cc"]), + name = "conv_algorithm_picker_test", + srcs = if_gpu_is_configured(["conv_algorithm_picker_test.cc"]), tags = [ "gpu", "no_oss", @@ -1548,7 +1548,7 @@ xla_cc_test( "requires-gpu-sm70", ], deps = [ - ":gpu_conv_algorithm_picker", + ":conv_algorithm_picker", ":gpu_conv_rewriter", "//tensorflow/compiler/xla/service:gpu_plugin", "//tensorflow/compiler/xla/service:pattern_matcher", @@ -2564,7 +2564,7 @@ cc_library( ":gemm_algorithm_picker", ":gpu_asm_opts_util", ":gpu_compiler", - ":gpu_conv_algorithm_picker", + ":conv_algorithm_picker", ":gpu_conv_padding_legalization", ":gpu_conv_rewriter", ":gpu_executable", @@ -2688,7 +2688,7 @@ cc_library( ":cusolver_rewriter", ":gemm_rewriter", ":gpu_compiler", - ":gpu_conv_algorithm_picker", + ":conv_algorithm_picker", ":gpu_conv_padding_legalization", ":gpu_conv_rewriter", ":gpu_layout_assignment", diff --git a/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc index 416027ad998c7e..6365652cb7d1c6 100644 --- a/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc @@ -20,9 +20,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" #include "tensorflow/compiler/xla/service/call_inliner.h" +#include "tensorflow/compiler/xla/service/gpu/conv_algorithm_picker.h" #include "tensorflow/compiler/xla/service/gpu/cusolver_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/gemm_rewriter.h" -#include "tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.h" #include "tensorflow/compiler/xla/service/gpu/gpu_conv_padding_legalization.h" #include "tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h" diff --git a/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.cc b/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.cc index d78454ca7961ba..7ee4c0b2fba5fd 100644 --- a/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.cc +++ b/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.cc @@ -103,12 +103,14 @@ std::vector ExecutionInputsFromBuffers( } // namespace -AutotunerCompileUtil::AutotunerCompileUtil(Compiler* compiler, +AutotunerCompileUtil::AutotunerCompileUtil(const AutotuneConfig& config, + Compiler* compiler, se::StreamExecutor& stream_executor, se::Stream& stream, se::DeviceMemoryAllocator& allocator, const DebugOptions& opts) - : compiler_(compiler), + : config_(config), + compiler_(compiler), stream_executor_(stream_executor), stream_(stream), allocator_(allocator), @@ -213,14 +215,19 @@ StatusOr> AutotunerCompileUtil::CompileNoCache( return out; } -/*static*/ StatusOr AutotunerCompileUtil::Create( - se::Stream& stream, se::DeviceMemoryAllocator& allocator, - const DebugOptions& opts) { - se::StreamExecutor& stream_executor = *stream.parent(); +/*static*/ StatusOr> +AutotunerCompileUtil::Create(const AutotuneConfig& config, + const DebugOptions& opts) { + if (config.IsDeviceless()) { + return std::nullopt; + } + se::StreamExecutor* stream_exec = config.GetExecutor(); + se::DeviceMemoryAllocator* allocator = config.GetAllocator(); + TF_ASSIGN_OR_RETURN(se::Stream* const stream, config.GetStream()); TF_ASSIGN_OR_RETURN(Compiler * compiler, - Compiler::GetForPlatform(stream_executor.platform())); - return AutotunerCompileUtil(compiler, stream_executor, stream, allocator, - opts); + Compiler::GetForPlatform(stream_exec->platform())); + return AutotunerCompileUtil(config, compiler, *stream_exec, *stream, + *allocator, opts); } StatusOr AutotunerCompileUtil::Execute( diff --git a/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.h b/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.h index c9484bb61faa21..7de054f6841221 100644 --- a/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.h +++ b/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.h @@ -55,9 +55,11 @@ class AutotunerCompileUtil { absl::AnyInvocable>()>; // Generates a compile util for a platform associated with the `stream`. - static StatusOr Create( - se::Stream& stream, se::DeviceMemoryAllocator& allocator, - const DebugOptions& opts); + // + // Returns an empty optional if the AutotuneConfig is deviceless, as + // autotuning is impossible in that case. + static StatusOr> Create( + const AutotuneConfig& config, const DebugOptions& opts); // Generates an executable first, given the module generator function in // `extractor`. @@ -86,8 +88,9 @@ class AutotunerCompileUtil { static void ClearCompilationCache(); private: - AutotunerCompileUtil(Compiler* compiler, se::StreamExecutor& stream_executor, - se::Stream& stream, se::DeviceMemoryAllocator& allocator, + AutotunerCompileUtil(const AutotuneConfig& config, Compiler* compiler, + se::StreamExecutor& stream_executor, se::Stream& stream, + se::DeviceMemoryAllocator& allocator, const DebugOptions& opts); StatusOr> CompileNoCache( @@ -96,6 +99,7 @@ class AutotunerCompileUtil { StatusOr Execute(Executable& executable, std::vector arguments); + AutotuneConfig config_; Compiler* compiler_; se::StreamExecutor& stream_executor_; se::Stream& stream_; diff --git a/tensorflow/compiler/xla/service/gpu/autotuner_util.h b/tensorflow/compiler/xla/service/gpu/autotuner_util.h index c50ad60e2c3601..d70c51429027c4 100644 --- a/tensorflow/compiler/xla/service/gpu/autotuner_util.h +++ b/tensorflow/compiler/xla/service/gpu/autotuner_util.h @@ -123,7 +123,13 @@ class AutotuneConfig { se::DeviceMemoryAllocator* GetAllocator() const { CHECK(std::holds_alternative(config_)); - return std::get(config_).allocator; + auto& cf = std::get(config_); + return cf.allocator ? cf.allocator : GetExecutor()->GetAllocator(); + } + + StatusOr GetStream() const { + CHECK(std::holds_alternative(config_)); + return GetAllocator()->GetStream(GetExecutor()->device_ordinal()); } se::CudaComputeCapability GetCudaComputeCapability() const { diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/conv_algorithm_picker.cc similarity index 98% rename from tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc rename to tensorflow/compiler/xla/service/gpu/conv_algorithm_picker.cc index fb2af10e7b3880..c7867e4359cd17 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/conv_algorithm_picker.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.h" +#include "tensorflow/compiler/xla/service/gpu/conv_algorithm_picker.h" #include #include @@ -360,18 +360,9 @@ StatusOr GpuConvAlgorithmPicker::PickBestAlgorithmNoCache( // allocator either points to this->allocator_ or, if that's null, to a // se::StreamExecutorMemoryAllocator for stream_exec. - se::DeviceMemoryAllocator* device_allocator = config_.GetAllocator(); - se::DeviceMemoryAllocator* allocator; - optional se_allocator; - if (device_allocator != nullptr) { - allocator = device_allocator; - } else { - se_allocator.emplace(stream_exec); - allocator = &*se_allocator; - } + se::DeviceMemoryAllocator* allocator = config_.GetAllocator(); - TF_ASSIGN_OR_RETURN(se::Stream* const stream, - allocator->GetStream(stream_exec->device_ordinal())); + TF_ASSIGN_OR_RETURN(se::Stream* const stream, config_.GetStream()); StatusOr result_or(InternalError("Unknown platform.")); // Check StreamExecutor on which platform it is. ROCm and Cuda implementation // have diverged. Specifically, we need to make sure redzone allocator related diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/conv_algorithm_picker.h similarity index 96% rename from tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.h rename to tensorflow/compiler/xla/service/gpu/conv_algorithm_picker.h index 4b903e228755c2..256318c6e84f06 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.h +++ b/tensorflow/compiler/xla/service/gpu/conv_algorithm_picker.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_CONV_ALGORITHM_PICKER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_CONV_ALGORITHM_PICKER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONV_ALGORITHM_PICKER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONV_ALGORITHM_PICKER_H_ #include #include @@ -162,4 +162,4 @@ class GpuConvAlgorithmPicker : public HloModulePass { } // namespace gpu } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_CONV_ALGORITHM_PICKER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONV_ALGORITHM_PICKER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker_test.cc b/tensorflow/compiler/xla/service/gpu/conv_algorithm_picker_test.cc similarity index 98% rename from tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker_test.cc rename to tensorflow/compiler/xla/service/gpu/conv_algorithm_picker_test.cc index 3a9c0ad4919258..8e352e5e428645 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker_test.cc +++ b/tensorflow/compiler/xla/service/gpu/conv_algorithm_picker_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.h" +#include "tensorflow/compiler/xla/service/gpu/conv_algorithm_picker.h" #include "tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.h" #include "tensorflow/compiler/xla/service/pattern_matcher.h" diff --git a/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc index 4f4fe23d379d4f..1720e4e5805913 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc @@ -244,13 +244,8 @@ StatusOr DoGemmAutotuneNoCache( } VLOG(3) << "Starting autotune of GemmThunk " << gemm->ToString(); - se::StreamExecutor* executor = autotune_config.GetExecutor(); se::DeviceMemoryAllocator* allocator = autotune_config.GetAllocator(); - if (allocator == nullptr) { - allocator = executor->GetAllocator(); - } - TF_ASSIGN_OR_RETURN(se::Stream* const stream, - allocator->GetStream(executor->device_ordinal())); + TF_ASSIGN_OR_RETURN(se::Stream* const stream, autotune_config.GetStream()); GemmBackendConfig gemm_config = gemm->backend_config().value(); const DebugOptions& debug_options = diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index ed050a66028ee8..458a46e5313b1f 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -38,6 +38,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/float_normalization.h" #include "tensorflow/compiler/xla/service/float_support.h" #include "tensorflow/compiler/xla/service/gpu/autotuner_util.h" +#include "tensorflow/compiler/xla/service/gpu/conv_algorithm_picker.h" #include "tensorflow/compiler/xla/service/gpu/cublas_cudnn.h" #include "tensorflow/compiler/xla/service/gpu/cublas_pad_for_gemms.h" #include "tensorflow/compiler/xla/service/gpu/cublas_padding_requirements.h" @@ -49,7 +50,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/cusolver_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.h" #include "tensorflow/compiler/xla/service/gpu/gpu_asm_opts_util.h" -#include "tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.h" #include "tensorflow/compiler/xla/service/gpu/gpu_conv_padding_legalization.h" #include "tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" diff --git a/tensorflow/compiler/xla/service/gpu/runtime/BUILD b/tensorflow/compiler/xla/service/gpu/runtime/BUILD index ebaba4c2b36b2d..8224b840eec663 100644 --- a/tensorflow/compiler/xla/service/gpu/runtime/BUILD +++ b/tensorflow/compiler/xla/service/gpu/runtime/BUILD @@ -81,7 +81,7 @@ cc_library( "@com_google_absl//absl/synchronization", "@llvm-project//llvm:Support", ] + if_cuda_is_configured([ - "//tensorflow/compiler/xla/service/gpu:gpu_conv_algorithm_picker", + "//tensorflow/compiler/xla/service/gpu:conv_algorithm_picker", ]), ) diff --git a/tensorflow/compiler/xla/service/gpu/runtime/conv.cc b/tensorflow/compiler/xla/service/gpu/runtime/conv.cc index 13a093c4831fa4..9663da9faef1db 100644 --- a/tensorflow/compiler/xla/service/gpu/runtime/conv.cc +++ b/tensorflow/compiler/xla/service/gpu/runtime/conv.cc @@ -38,7 +38,7 @@ limitations under the License. #if GOOGLE_CUDA #include "tensorflow/compiler/xla/service/gpu/autotuner_util.h" -#include "tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.h" +#include "tensorflow/compiler/xla/service/gpu/conv_algorithm_picker.h" #endif namespace xla { diff --git a/tensorflow/compiler/xla/service/gpu/triton_autotuner.cc b/tensorflow/compiler/xla/service/gpu/triton_autotuner.cc index bb8468b0782193..b1cbc2f5551f5c 100644 --- a/tensorflow/compiler/xla/service/gpu/triton_autotuner.cc +++ b/tensorflow/compiler/xla/service/gpu/triton_autotuner.cc @@ -474,21 +474,9 @@ StatusOr TritonAutotuner::Run( return false; } - std::optional autotuner_compile_util; - if (!config_.IsDeviceless()) { - se::StreamExecutor* stream_exec = config_.GetExecutor(); - se::DeviceMemoryAllocator* allocator = config_.GetAllocator() - ? config_.GetAllocator() - : stream_exec->GetAllocator(); - TF_ASSIGN_OR_RETURN(se::Stream* const stream, - allocator->GetStream(stream_exec->device_ordinal())); - TF_ASSIGN_OR_RETURN( - AutotunerCompileUtil util, - AutotunerCompileUtil::Create(*stream, *allocator, - module->config().debug_options())); - autotuner_compile_util.emplace(util); - } - + TF_ASSIGN_OR_RETURN( + std::optional autotuner_compile_util, + AutotunerCompileUtil::Create(config_, module->config().debug_options())); return TritonAutotunerVisitor{config_, thread_pool_, autotuner_compile_util} .RunOnModule(module, execution_threads); } From 3a5c13738f8e52a5b4235434e6649a37d1d1e318 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 12 Jul 2023 08:13:57 -0700 Subject: [PATCH 191/376] Do not run `concat_ops_test` in asan config due to timeouts. PiperOrigin-RevId: 547499437 --- tensorflow/compiler/tests/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 4a4c116fb9bdf2..5a887a5d238f35 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -536,6 +536,7 @@ tf_xla_py_strict_test( python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + "noasan", # Timed out on 2023-07-12 ], deps = [ ":xla_test", From ff8abca792d446603119bad04fc00e093e0cf20f Mon Sep 17 00:00:00 2001 From: George Karpenkov Date: Wed, 12 Jul 2023 08:50:38 -0700 Subject: [PATCH 192/376] [XLA:GPU] [NFC] Consolidate the logic to create RedzoneAllocator to a single function It was previously duplicated four times. Additionally, RedzoneAllocator is no longer CUDA-only. PiperOrigin-RevId: 547507918 --- .../xla/service/gpu/autotuner_util.cc | 16 ++++++ .../compiler/xla/service/gpu/autotuner_util.h | 8 ++- .../xla/service/gpu/conv_algorithm_picker.cc | 49 ++++++------------- .../xla/service/gpu/conv_algorithm_picker.h | 5 +- .../service/gpu/conv_algorithm_picker_test.cc | 3 +- .../xla/service/gpu/gemm_algorithm_picker.cc | 21 ++------ .../compiler/xla/service/gpu/runtime/conv.cc | 2 +- .../xla/service/gpu/triton_autotuner.cc | 9 ++-- .../stream_executor/gpu/redzone_allocator.cc | 8 ++- 9 files changed, 57 insertions(+), 64 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/autotuner_util.cc b/tensorflow/compiler/xla/service/gpu/autotuner_util.cc index 76037c8dd3b75f..69016104aa5b64 100644 --- a/tensorflow/compiler/xla/service/gpu/autotuner_util.cc +++ b/tensorflow/compiler/xla/service/gpu/autotuner_util.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/service/gpu/gpu_asm_opts_util.h" #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" namespace xla { @@ -259,5 +260,20 @@ AutotunerUtil::ExtractComputationIntoNewModule( return new_hlo_module; } +/*static*/ StatusOr AutotunerUtil::CreateRedzoneAllocator( + const AutotuneConfig& config, const DebugOptions& opts, + se::Stream* force_stream) { + se::Stream* stream = force_stream; + if (stream == nullptr) { + TF_ASSIGN_OR_RETURN(stream, config.GetStream()); + } + return se::RedzoneAllocator( + stream, config.GetAllocator(), PtxOptsFromDebugOptions(opts), + /*memory_limit=*/std::numeric_limits::max(), + /*redzone_size=*/config.should_check_correctness() + ? opts.xla_gpu_redzone_padding_bytes() + : 0); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/autotuner_util.h b/tensorflow/compiler/xla/service/gpu/autotuner_util.h index d70c51429027c4..f759e1d8f2ca71 100644 --- a/tensorflow/compiler/xla/service/gpu/autotuner_util.h +++ b/tensorflow/compiler/xla/service/gpu/autotuner_util.h @@ -43,7 +43,7 @@ struct DeviceConfig { // If the `allocator` parameter is not null, we will use it to allocate temp // memory while timing the various convolution algorithms. If it's null, // we'll use the default allocator on the StreamExecutor. - se::DeviceMemoryAllocator* allocator; // may be null + se::DeviceMemoryAllocator* allocator = nullptr; // may be null }; struct DevicelessConfig { @@ -165,6 +165,12 @@ struct AutotunerUtil { const HloInstruction* instr, const AutotuneConfig& config, const AutotuneNoCacheFn& autotune_fn); + // Creates a RedzoneAllocator from a given config. If `force_stream` is + // provided, than it is used for checking redzones. + static StatusOr CreateRedzoneAllocator( + const AutotuneConfig& config, const DebugOptions& opts, + se::Stream* force_stream = nullptr); + // Functions to save/load XLA's autotuning results. // // This is used for ahead-of-time autotuning. Specifically: diff --git a/tensorflow/compiler/xla/service/gpu/conv_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/conv_algorithm_picker.cc index c7867e4359cd17..b263969a5ee2a5 100644 --- a/tensorflow/compiler/xla/service/gpu/conv_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/conv_algorithm_picker.cc @@ -262,7 +262,6 @@ void PrintPlatformInfo(const se::Stream* stream) { } } -#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) // Returns true if the redzones in `allocator`'s allocations are unmodified. // // If the redzones are modified, logs an error, sets the appropriate failure @@ -306,7 +305,6 @@ StatusOr CheckRedzones(const se::RedzoneAllocator& allocator, PrintPlatformInfo(stream); return false; } -#endif } // anonymous namespace @@ -371,15 +369,10 @@ StatusOr GpuConvAlgorithmPicker::PickBestAlgorithmNoCache( result_or = PickBestAlgorithmNoCacheRocm(instr, allocator, stream); } else if (stream_exec->platform_kind() == se::PlatformKind::kCuda) { #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) - // Right now Redzone allocator is available in Cuda target only. - auto hlo_module_config = instr->GetModule()->config(); - se::RedzoneAllocator input_output_allocator( - stream, allocator, - PtxOptsFromDebugOptions(hlo_module_config.debug_options()), - /*memory_limit=*/std::numeric_limits::max(), - ShouldCheckConv(hlo_module_config) - ? hlo_module_config.debug_options().xla_gpu_redzone_padding_bytes() - : 0); + DebugOptions debug_opts = instr->GetModule()->config().debug_options(); + TF_ASSIGN_OR_RETURN( + se::RedzoneAllocator input_output_allocator, + AutotunerUtil::CreateRedzoneAllocator(config_, debug_opts)); TF_ASSIGN_OR_RETURN( AutotuneRuntimeArguments runtime_arguments, @@ -510,19 +503,12 @@ StatusOr GpuConvAlgorithmPicker::AutotuneOneConvRunner( "Disqualified for implicit RELU."); } - const int64_t rz_space_limit = - runtime_arguments.hlo_module_config.debug_options() - .xla_gpu_redzone_scratch_max_megabytes() * - (1LL << 20); - se::RedzoneAllocator scratch_allocator( - stream, allocator, - PtxOptsFromDebugOptions( - runtime_arguments.hlo_module_config.debug_options()), - /*memory_limit=*/rz_space_limit, - ShouldCheckConv(runtime_arguments.hlo_module_config) - ? runtime_arguments.hlo_module_config.debug_options() - .xla_gpu_redzone_padding_bytes() - : 0); + TF_ASSIGN_OR_RETURN( + se::RedzoneAllocator scratch_allocator, + AutotunerUtil::CreateRedzoneAllocator( + config_, runtime_arguments.hlo_module_config.debug_options(), + stream)); + se::dnn::ProfileResult profile_result; VLOG(4) << "Trying algorithm " << alg.ToString() << " for " << instr_str; @@ -837,23 +823,20 @@ StatusOr GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda( StatusOr GpuConvAlgorithmPicker::PickBestAlgorithmWithAllocatedBuffer( - const GpuConvConfig conv_config, + const AutotuneConfig& config, const GpuConvConfig conv_config, const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, + const DebugOptions& debug_options, const std::vector buffers, const se::DeviceMemoryBase result_buffer) { #if GOOGLE_CUDA Shape output_shape = conv_config.output_shape; HloModuleConfig hlo_module_config; - hlo_module_config.set_debug_options(*debug_options); + hlo_module_config.set_debug_options(debug_options); se::Stream* stream = run_options->stream(); se::DeviceMemoryAllocator* allocator = run_options->allocator(); - se::RedzoneAllocator input_output_allocator( - stream, allocator, PtxOptsFromDebugOptions(*debug_options), - /*memory_limit=*/std::numeric_limits::max(), - ShouldCheckConv(hlo_module_config) - ? debug_options->xla_gpu_redzone_padding_bytes() - : 0); + TF_ASSIGN_OR_RETURN( + se::RedzoneAllocator input_output_allocator, + AutotunerUtil::CreateRedzoneAllocator(config, debug_options, stream)); GpuConvAlgorithmPicker::AutotuneRuntimeArguments autotune_runtime_arguments = {output_shape, hlo_module_config, buffers, diff --git a/tensorflow/compiler/xla/service/gpu/conv_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/conv_algorithm_picker.h index 256318c6e84f06..0d265a2e7f680b 100644 --- a/tensorflow/compiler/xla/service/gpu/conv_algorithm_picker.h +++ b/tensorflow/compiler/xla/service/gpu/conv_algorithm_picker.h @@ -96,8 +96,9 @@ class GpuConvAlgorithmPicker : public HloModulePass { // Run autotuning on allocated buffers and pick the best algorithm. StatusOr PickBestAlgorithmWithAllocatedBuffer( - GpuConvConfig conv_config, const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, + const AutotuneConfig& config, GpuConvConfig conv_config, + const ServiceExecutableRunOptions* run_options, + const DebugOptions& debug_options, std::vector buffers, se::DeviceMemoryBase result_buffer); diff --git a/tensorflow/compiler/xla/service/gpu/conv_algorithm_picker_test.cc b/tensorflow/compiler/xla/service/gpu/conv_algorithm_picker_test.cc index 8e352e5e428645..73b2b6a8d52e2a 100644 --- a/tensorflow/compiler/xla/service/gpu/conv_algorithm_picker_test.cc +++ b/tensorflow/compiler/xla/service/gpu/conv_algorithm_picker_test.cc @@ -54,7 +54,8 @@ ENTRY main { bool changed = false; TF_ASSERT_OK_AND_ASSIGN(changed, RunHloPass(GpuConvRewriter(), m.get())); changed = false; - DebugOptions opts; + DebugOptions opts = DefaultDebugOptionsIgnoringFlags(); + AutotuneConfig cfg{DeviceConfig{stream_exec, nullptr}, opts}; TF_ASSERT_OK_AND_ASSIGN(changed, RunHloPass(GpuConvAlgorithmPicker(cfg), m.get())); diff --git a/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc index 1720e4e5805913..98a01d2824694f 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc @@ -53,22 +53,6 @@ limitations under the License. namespace xla { namespace gpu { -#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) -static se::RedzoneAllocator CreateRedzoneAllocator( - se::Stream* stream, se::DeviceMemoryAllocator* allocator, - const DebugOptions& debug_options, const AutotuneConfig& config) { - // TODO(jlebar): The memory limit here should by rights be - // debug_options.xla_gpu_redzone_scratch_max_megabytes(), but tests OOM when - // we do that. Are the tests wrong, or is the option named incorrectly? - return se::RedzoneAllocator( - stream, allocator, PtxOptsFromDebugOptions(debug_options), - /*memory_limit=*/std::numeric_limits::max(), - /*redzone_size=*/config.should_check_correctness() - ? debug_options.xla_gpu_redzone_padding_bytes() - : 0); -} -#endif - // Returns the index (into `algorithms`) of the fastest algorithm. template StatusOr GetBestAlgorithm( @@ -256,8 +240,9 @@ StatusOr DoGemmAutotuneNoCache( // Don't run autotuning concurrently on the same GPU. absl::MutexLock gpu_lock(&GetGpuMutex(stream->parent())); - se::RedzoneAllocator buffer_allocator = - CreateRedzoneAllocator(stream, allocator, debug_options, autotune_config); + TF_ASSIGN_OR_RETURN( + se::RedzoneAllocator buffer_allocator, + AutotunerUtil::CreateRedzoneAllocator(autotune_config, debug_options)); int64_t rng_state = 0; TF_ASSIGN_OR_RETURN( diff --git a/tensorflow/compiler/xla/service/gpu/runtime/conv.cc b/tensorflow/compiler/xla/service/gpu/runtime/conv.cc index 9663da9faef1db..c125a110463332 100644 --- a/tensorflow/compiler/xla/service/gpu/runtime/conv.cc +++ b/tensorflow/compiler/xla/service/gpu/runtime/conv.cc @@ -404,7 +404,7 @@ static absl::Status ConvImpl( TF_ASSIGN_OR_RETURN( AutotuneResult best_algo, conv_algorithm_picker.PickBestAlgorithmWithAllocatedBuffer( - gpu_conv_config, run_options, debug_options, buffers, + config, gpu_conv_config, run_options, *debug_options, buffers, result_buffer)); // Set algorithm in the convolution runner state. diff --git a/tensorflow/compiler/xla/service/gpu/triton_autotuner.cc b/tensorflow/compiler/xla/service/gpu/triton_autotuner.cc index b1cbc2f5551f5c..ccaf09e6816c54 100644 --- a/tensorflow/compiler/xla/service/gpu/triton_autotuner.cc +++ b/tensorflow/compiler/xla/service/gpu/triton_autotuner.cc @@ -148,12 +148,9 @@ class TritonAutotunerVisitor : public DfsHloRewriteVisitor { const DebugOptions& debug_opts = fusion.parent()->config().debug_options(); - se::RedzoneAllocator rz_allocator( - stream, allocator, PtxOptsFromDebugOptions(debug_opts), - /*memory_limit=*/std::numeric_limits::max(), - /*redzone_size=*/config_.should_check_correctness() - ? debug_opts.xla_gpu_redzone_padding_bytes() - : 0); + TF_ASSIGN_OR_RETURN( + se::RedzoneAllocator rz_allocator, + AutotunerUtil::CreateRedzoneAllocator(config_, debug_opts)); se::DeviceMemoryBase reference_buffer; if (config_.should_check_correctness()) { diff --git a/tensorflow/compiler/xla/stream_executor/gpu/redzone_allocator.cc b/tensorflow/compiler/xla/stream_executor/gpu/redzone_allocator.cc index 1ab21ed78506ab..cfd3de8bb50f21 100644 --- a/tensorflow/compiler/xla/stream_executor/gpu/redzone_allocator.cc +++ b/tensorflow/compiler/xla/stream_executor/gpu/redzone_allocator.cc @@ -52,7 +52,7 @@ using RedzoneCheckStatus = RedzoneAllocator::RedzoneCheckStatus; RedzoneAllocator::RedzoneAllocator(Stream* stream, DeviceMemoryAllocator* memory_allocator, - GpuAsmOpts ptx_compilation_opts, + GpuAsmOpts gpu_compilation_opts, int64_t memory_limit, int64_t redzone_size, uint8_t redzone_pattern) : device_ordinal_(stream->parent()->device_ordinal()), @@ -63,7 +63,7 @@ RedzoneAllocator::RedzoneAllocator(Stream* stream, static_cast(tsl::Allocator::kAllocatorAlignment))), redzone_pattern_(redzone_pattern), memory_allocator_(memory_allocator), - gpu_compilation_opts_(ptx_compilation_opts) {} + gpu_compilation_opts_(gpu_compilation_opts) {} tsl::StatusOr> RedzoneAllocator::AllocateBytes( int64_t byte_size) { @@ -223,6 +223,10 @@ static tsl::Status RunRedzoneChecker( const ComparisonKernelT& comparison_kernel) { StreamExecutor* executor = stream->parent(); + if (redzone.size() == 0) { + return tsl::OkStatus(); + } + int64_t num_elements = redzone.size(); int64_t threads_per_block = std::min( executor->GetDeviceDescription().threads_per_block_limit(), num_elements); From b853df2707f72021a880f482f3a394ffdf9df715 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 12 Jul 2023 09:10:17 -0700 Subject: [PATCH 193/376] [XLA StreamExecutor TPU] Unit tests for c_api_conversions.h PiperOrigin-RevId: 547513154 --- .../compiler/xla/stream_executor/tpu/BUILD | 25 +- .../stream_executor/tpu/c_api_conversions.cc | 166 +++++---- .../stream_executor/tpu/c_api_conversions.h | 35 +- .../tpu/c_api_conversions_test.cc | 350 ++++++++++++++++++ tensorflow/core/tpu/BUILD | 1 + 5 files changed, 482 insertions(+), 95 deletions(-) create mode 100644 tensorflow/compiler/xla/stream_executor/tpu/c_api_conversions_test.cc diff --git a/tensorflow/compiler/xla/stream_executor/tpu/BUILD b/tensorflow/compiler/xla/stream_executor/tpu/BUILD index a0114e09f4e3c5..958c4430d34792 100644 --- a/tensorflow/compiler/xla/stream_executor/tpu/BUILD +++ b/tensorflow/compiler/xla/stream_executor/tpu/BUILD @@ -2,6 +2,7 @@ load("//tensorflow/tsl:tsl.bzl", "set_external_visibility") load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") +load("//tensorflow/compiler/xla:xla.bzl", "xla_cc_test") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -58,18 +59,38 @@ cc_library( "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/hlo/ir:hlo", - "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/compiler/xla/service:maybe_owning_device_memory", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/stream_executor:device_memory", "//tensorflow/compiler/xla/stream_executor:device_memory_allocator", - "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/types:span", ], ) +xla_cc_test( + name = "c_api_conversions_test", + srcs = ["c_api_conversions_test.cc"], + deps = [ + ":c_api_conversions", + ":c_api_decl", + "//tensorflow/compiler/xla:executable_run_options", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:xla_proto_cc", + "//tensorflow/compiler/xla/hlo/ir:hlo", + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/service:hlo_proto_cc", + "//tensorflow/tsl/platform:protobuf", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + ], +) + cc_library( name = "libtftpu_header", hdrs = ["libtftpu.h"], diff --git a/tensorflow/compiler/xla/stream_executor/tpu/c_api_conversions.cc b/tensorflow/compiler/xla/stream_executor/tpu/c_api_conversions.cc index a7b8f793cca724..518fd6b0dd4c8d 100644 --- a/tensorflow/compiler/xla/stream_executor/tpu/c_api_conversions.cc +++ b/tensorflow/compiler/xla/stream_executor/tpu/c_api_conversions.cc @@ -16,11 +16,14 @@ limitations under the License. #include "tensorflow/compiler/xla/stream_executor/tpu/c_api_conversions.h" #include +#include #include #include #include #include "absl/types/span.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/stream_executor/device_memory_allocator.h" #include "tensorflow/compiler/xla/stream_executor/tpu/c_api_decl.h" #include "tensorflow/compiler/xla/stream_executor/tpu/c_api_defn.h" #include "tensorflow/compiler/xla/stream_executor/tpu/tpu_api.h" @@ -29,6 +32,90 @@ limitations under the License. namespace ApiConverter { +// Helper functions for copying data to possibly-inlined C arrays. + +// 'Src' and 'Dst' are allowed to be different types to make this usable with +// memory-identical types, e.g. int64_t and int64_t. This should not be used +// with types that require a static_cast. +template +static void CreateVectorBase(const absl::Span src, DstList* dst) { + dst->size = src.size(); + if (dst->size > TPU_C_API_MAX_INLINED) { + dst->heap = new Dst[dst->size]; + std::copy(src.begin(), src.end(), dst->heap); + } else { + std::copy(src.begin(), src.end(), dst->inlined); + } +} + +void CreateVector(const absl::Span src, IntList* dst) { + return CreateVectorBase(src, dst); +} + +void CreateVector(const absl::Span src, Int64List* dst) { + return CreateVectorBase(src, dst); +} + +void CreateVector(const absl::Span src, FloatList* dst) { + return CreateVectorBase(src, dst); +} + +void CreateVector(const absl::Span src, BoolList* dst) { + return CreateVectorBase(src, dst); +} + +void CreateVector(const absl::Span src, IntList* dst) { + CreateVectorBase(src, dst); +} + +static void CreateVector(const absl::Span src, IntList* dst) { + CreateVectorBase(src, dst); +} + +static void CreateVector(const absl::Span src, TileList* dst) { + dst->size = src.size(); + XLA_Tile* c_tiles; + if (dst->size > TPU_C_API_MAX_INLINED) { + dst->heap = new XLA_Tile[dst->size]; + c_tiles = dst->heap; + } else { + c_tiles = dst->inlined; + } + for (int i = 0; i < dst->size; ++i) { + ToC(src[i], &c_tiles[i]); + } +} + +// Helper functions for creating a view of possibly-inlined C arrays. + +// 'Src' and 'Dst' are allowed to be different types to make this usable with +// memory-identical types, e.g. int64_t and int64_t. This should not be used +// with types that require a static_cast. +template +static absl::Span MakeSpanBase(const SrcList& src_list) { + static_assert(sizeof(Src) == sizeof(Dst), "Mismatched types"); + const Src* src = src_list.size > TPU_C_API_MAX_INLINED ? src_list.heap + : &src_list.inlined[0]; + return absl::Span(reinterpret_cast(src), + src_list.size); +} + +absl::Span MakeSpan(const IntList& src_list) { + return MakeSpanBase(src_list); +} + +absl::Span MakeSpan(const Int64List& src_list) { + return MakeSpanBase(src_list); +} + +absl::Span MakeSpan(const FloatList& src_list) { + return MakeSpanBase(src_list); +} + +absl::Span MakeSpan(const BoolList& src_list) { + return MakeSpanBase(src_list); +} + xla::ShapedBuffer FromC(XLA_ShapedBuffer* c_buffer) { xla::Shape xla_on_device_shape = ApiConverter::FromC(&c_buffer->on_device_shape); @@ -154,85 +241,6 @@ stream_executor::DeviceMemoryBase FromC(const SE_DeviceMemoryBase& se_base) { return base; } -// Helper functions for copying data to possibly-inlined C arrays. - -// 'Src' and 'Dst' are allowed to be different types to make this usable with -// memory-identical types, e.g. int64_t and int64_t. This should not be used -// with types that require a static_cast. -template -static void CreateVectorBase(const absl::Span src, DstList* dst) { - dst->size = src.size(); - if (dst->size > TPU_C_API_MAX_INLINED) { - dst->heap = new Dst[dst->size]; - std::copy(src.begin(), src.end(), dst->heap); - } else { - std::copy(src.begin(), src.end(), dst->inlined); - } -} - -void CreateVector(const absl::Span src, IntList* dst) { - return CreateVectorBase(src, dst); -} -void CreateVector(const absl::Span src, Int64List* dst) { - return CreateVectorBase(src, dst); -} -void CreateVector(const absl::Span src, FloatList* dst) { - return CreateVectorBase(src, dst); -} -void CreateVector(const absl::Span src, BoolList* dst) { - return CreateVectorBase(src, dst); -} -static void CreateVector(const absl::Span src, - IntList* dst) { - CreateVectorBase(src, dst); -} -static void CreateVector(const absl::Span src, IntList* dst) { - CreateVectorBase(src, dst); -} - -static void CreateVector(const absl::Span src, TileList* dst) { - dst->size = src.size(); - XLA_Tile* c_tiles; - if (dst->size > TPU_C_API_MAX_INLINED) { - dst->heap = new XLA_Tile[dst->size]; - c_tiles = dst->heap; - } else { - c_tiles = dst->inlined; - } - for (int i = 0; i < dst->size; ++i) { - ToC(src[i], &c_tiles[i]); - } -} - -// Helper functions for creating a view of possibly-inlined C arrays. - -// 'Src' and 'Dst' are allowed to be different types to make this usable with -// memory-identical types, e.g. int64_t and int64_t. This should not be used -// with types that require a static_cast. -template -static absl::Span MakeSpanBase(const SrcList& src_list) { - static_assert(sizeof(Src) == sizeof(Dst), "Mismatched types"); - const Src* src = src_list.size > TPU_C_API_MAX_INLINED ? src_list.heap - : &src_list.inlined[0]; - return absl::Span(reinterpret_cast(src), - src_list.size); -} - -absl::Span MakeSpan(const IntList& src_list) { - return MakeSpanBase(src_list); -} - -absl::Span MakeSpan(const Int64List& src_list) { - return MakeSpanBase(src_list); -} - -absl::Span MakeSpan(const FloatList& src_list) { - return MakeSpanBase(src_list); -} -absl::Span MakeSpan(const BoolList& src_list) { - return MakeSpanBase(src_list); -} - void ToC(const xla::Shape& xla_shape, XLA_Shape* c_shape) { c_shape->element_type = xla_shape.element_type(); diff --git a/tensorflow/compiler/xla/stream_executor/tpu/c_api_conversions.h b/tensorflow/compiler/xla/stream_executor/tpu/c_api_conversions.h index e1d846739f1fd0..9e4aa9ab2e1b90 100644 --- a/tensorflow/compiler/xla/stream_executor/tpu/c_api_conversions.h +++ b/tensorflow/compiler/xla/stream_executor/tpu/c_api_conversions.h @@ -16,18 +16,20 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_TPU_C_API_CONVERSIONS_H_ #define TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_TPU_C_API_CONVERSIONS_H_ -#include "absl/container/inlined_vector.h" +#include +#include +#include + #include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_module.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h" -#include "tensorflow/compiler/xla/service/service_executable_run_options.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/stream_executor/device_memory.h" -#include "tensorflow/compiler/xla/stream_executor/device_memory_allocator.h" #include "tensorflow/compiler/xla/stream_executor/tpu/c_api_decl.h" #include "tensorflow/compiler/xla/stream_executor/tpu/tpu_executor_c_api.h" @@ -36,14 +38,19 @@ limitations under the License. namespace ApiConverter { absl::Span MakeSpan(const FloatList& src_list); -void CreateVector(const absl::Span src, FloatList* dst); +void CreateVector(absl::Span src, FloatList* dst); void Destroy(FloatList* float_list); absl::Span MakeSpan(const Int64List& src_list); -void CreateVector(const absl::Span src, Int64List* dst); +void CreateVector(absl::Span src, Int64List* dst); + +absl::Span MakeSpan(const IntList& src_list); +void CreateVector(absl::Span src, IntList* dst); absl::Span MakeSpan(const BoolList& src_list); -void CreateVector(const absl::Span src, BoolList* dst); +void CreateVector(absl::Span src, BoolList* dst); + +void CreateVector(absl::Span src, IntList* dst); // se::DeviceMemoryBase SE_DeviceMemoryBase ToC(const stream_executor::DeviceMemoryBase& base); @@ -52,20 +59,20 @@ void ToC(const stream_executor::DeviceMemoryBase& base, stream_executor::DeviceMemoryBase FromC(const SE_DeviceMemoryBase& se_base); void Destroy(SE_DeviceMemoryBase*); -// xla::Shape -xla::Shape FromC(const XLA_Shape* c_shape); -void ToC(const xla::Shape& xla_shape, XLA_Shape* c_shape); -void Destroy(XLA_Shape* c_shape); +// xla::Tile +xla::Tile FromC(const XLA_Tile* c_tile); +void ToC(const xla::Tile& xla_tile, XLA_Tile* c_tile); +void Destroy(XLA_Tile* c_tile); // xla::Layout xla::Layout FromC(const XLA_Layout* c_layout); void ToC(const xla::Layout& xla_layout, XLA_Layout* c_layout); void Destroy(XLA_Layout* c_layout); -// xla::Tile -xla::Tile FromC(const XLA_Tile* c_tile); -void ToC(const xla::Tile& xla_tile, XLA_Tile* c_tile); -void Destroy(XLA_Tile* c_tile); +// xla::Shape +xla::Shape FromC(const XLA_Shape* c_shape); +void ToC(const xla::Shape& xla_shape, XLA_Shape* c_shape); +void Destroy(XLA_Shape* c_shape); // xla::ShapeIndex XLA_ShapeIndex ToC(const xla::ShapeIndex& xla_shape); diff --git a/tensorflow/compiler/xla/stream_executor/tpu/c_api_conversions_test.cc b/tensorflow/compiler/xla/stream_executor/tpu/c_api_conversions_test.cc new file mode 100644 index 00000000000000..333cb4066b534e --- /dev/null +++ b/tensorflow/compiler/xla/stream_executor/tpu/c_api_conversions_test.cc @@ -0,0 +1,350 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/stream_executor/tpu/c_api_conversions.h" + +#include +#include + +#include +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/executable_run_options.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_module.h" +#include "tensorflow/compiler/xla/layout.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/stream_executor/tpu/c_api_decl.h" +#include "tensorflow/compiler/xla/xla.pb.h" +#include "tensorflow/tsl/platform/protobuf.h" + +namespace ApiConverter { + +namespace { + +constexpr absl::string_view kHloString = + R"( +HloModule TupleCreate_module: +ENTRY %TupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f32[3], f32[2,3]) { + %v1 = f32[] parameter(0) + %v2 = f32[3]{0} parameter(1) + %v3 = f32[2,3]{1,0} parameter(2) + ROOT %tuple = (f32[], f32[3]{0}, f32[2,3]{1,0}) tuple(f32[] %v1, f32[3]{0} %v2, f32[2,3]{1,0} %v3) +} +)"; + +TEST(XlaTile, ToCInlined) { + std::vector tile_dimensions{2, 3, 4, 5}; + xla::Tile cpp_tile(tile_dimensions); + XLA_Tile c_tile; + ToC(cpp_tile, &c_tile); + + absl::Span cpp_tile_dimensions = cpp_tile.dimensions(); + ASSERT_EQ(cpp_tile_dimensions, tile_dimensions); + absl::Span c_tile_dimensions = MakeSpan(c_tile.dimensions); + EXPECT_EQ(cpp_tile_dimensions, c_tile_dimensions); + + Destroy(&c_tile); +} + +TEST(XlaTile, ToCDynamic) { + std::vector tile_dimensions{2, 3, 4, 5, 6, 7, 8, 9}; + xla::Tile cpp_tile(tile_dimensions); + XLA_Tile c_tile; + ToC(cpp_tile, &c_tile); + + absl::Span cpp_tile_dimensions = cpp_tile.dimensions(); + ASSERT_EQ(cpp_tile_dimensions, tile_dimensions); + absl::Span c_tile_dimensions = MakeSpan(c_tile.dimensions); + EXPECT_EQ(cpp_tile_dimensions, c_tile_dimensions); + + Destroy(&c_tile); +} + +TEST(XlaTile, FromCInlined) { + constexpr size_t kInlinedSize = 4; + Int64List tile_dimensions; + tile_dimensions.size = kInlinedSize; + for (int i = 0; i < kInlinedSize; ++i) { + tile_dimensions.inlined[i] = i + 2; + } + XLA_Tile c_tile{tile_dimensions}; + xla::Tile cpp_tile = FromC(&c_tile); + auto cpp_dimensions = cpp_tile.dimensions(); + EXPECT_EQ(cpp_dimensions.size(), kInlinedSize); + for (int i = 0; i < kInlinedSize; ++i) { + EXPECT_EQ(cpp_dimensions[i], i + 2); + } + Destroy(&c_tile); +} + +TEST(XlaTile, FromCDynamic) { + constexpr size_t kDynamicSize = 8; + int64_t* dynamic = new int64_t[kDynamicSize]; + for (int i = 0; i < kDynamicSize; ++i) { + dynamic[i] = i + 2; + } + Int64List tile_dimensions; + tile_dimensions.size = kDynamicSize; + tile_dimensions.heap = dynamic; + XLA_Tile c_tile{tile_dimensions}; + xla::Tile cpp_tile = FromC(&c_tile); + auto cpp_dimensions = cpp_tile.dimensions(); + EXPECT_EQ(cpp_dimensions.size(), kDynamicSize); + for (int i = 0; i < kDynamicSize; ++i) { + EXPECT_EQ(cpp_dimensions[i], i + 2); + } + Destroy(&c_tile); +} + +namespace TestImpl { + +void XlaLayout_ToC(const xla::Layout& cpp_layout) { + XLA_Layout c_layout; + ToC(cpp_layout, &c_layout); + + absl::Span cpp_minor_to_major = cpp_layout.minor_to_major(); + absl::Span c_minor_to_major = + MakeSpan(c_layout.minor_to_major); + EXPECT_EQ(cpp_minor_to_major, c_minor_to_major); + + absl::Span cpp_dim_level_types = + cpp_layout.dim_level_types(); + absl::Span c_dim_level_types = MakeSpan(c_layout.dim_level_types); + EXPECT_EQ(cpp_dim_level_types.size(), c_dim_level_types.size()); + for (int i = 0; i < c_dim_level_types.size(); ++i) { + EXPECT_EQ(static_cast(cpp_dim_level_types[i]), c_dim_level_types[i]); + } + + absl::Span cpp_dim_unique = cpp_layout.dim_unique(); + absl::Span c_dim_unique = MakeSpan(c_layout.dim_unique); + EXPECT_EQ(cpp_dim_unique.size(), c_dim_unique.size()); + for (int i = 0; i < c_dim_unique.size(); ++i) { + EXPECT_EQ(cpp_dim_unique[i], static_cast(c_dim_unique[i])); + } + + absl::Span cpp_dim_ordered = cpp_layout.dim_ordered(); + absl::Span c_dim_ordered = MakeSpan(c_layout.dim_ordered); + EXPECT_EQ(cpp_dim_ordered.size(), c_dim_ordered.size()); + for (int i = 0; i < c_dim_ordered.size(); ++i) { + EXPECT_EQ(cpp_dim_ordered[i], static_cast(c_dim_ordered[i])); + } + + absl::Span cpp_tiles = cpp_layout.tiles(); + TileList c_tiles = c_layout.tiles; + EXPECT_EQ(cpp_tiles.size(), c_tiles.size); + XLA_Tile* tile_base = + (c_tiles.size > TPU_C_API_MAX_INLINED) ? c_tiles.heap : c_tiles.inlined; + for (int i = 0; i < c_tiles.size; ++i) { + xla::Tile converted_c_tile = FromC(&tile_base[i]); + EXPECT_EQ(cpp_tiles[i], converted_c_tile); + } + + EXPECT_EQ(cpp_layout.index_primitive_type(), c_layout.index_primitive_type); + EXPECT_EQ(cpp_layout.pointer_primitive_type(), + c_layout.pointer_primitive_type); + EXPECT_EQ(cpp_layout.element_size_in_bits(), c_layout.element_size_in_bits); + EXPECT_EQ(cpp_layout.memory_space(), c_layout.memory_space); + EXPECT_EQ(cpp_layout.dynamic_shape_metadata_prefix_bytes(), + c_layout.dynamic_shape_metadata_prefix_bytes); + + Destroy(&c_layout); +} + +} // namespace TestImpl + +TEST(XlaLayout, ToCScalar) { + xla::Shape cpp_shape = xla::ShapeUtil::MakeShapeWithType({4}); + xla::Layout cpp_layout = cpp_shape.layout(); + TestImpl::XlaLayout_ToC(cpp_layout); +} + +TEST(XlaLayout, ToCNested) { + xla::Shape cpp_shape = xla::ShapeUtil::MakeShapeWithType({4, 3, 2}); + xla::Layout cpp_layout = cpp_shape.layout(); + TestImpl::XlaLayout_ToC(cpp_layout); +} + +TEST(XlaLayout, FromCScalar) { + xla::Shape cpp_shape = xla::ShapeUtil::MakeShapeWithType({4}); + xla::Layout in_layout = cpp_shape.layout(); + XLA_Layout c_layout; + ToC(in_layout, &c_layout); + xla::Layout out_layout = FromC(&c_layout); + EXPECT_EQ(in_layout, out_layout); + Destroy(&c_layout); +} + +TEST(XlaLayout, FromCNested) { + xla::Shape cpp_shape = xla::ShapeUtil::MakeShapeWithType({4, 3, 2}); + xla::Layout in_layout = cpp_shape.layout(); + XLA_Layout c_layout; + ToC(in_layout, &c_layout); + xla::Layout out_layout = FromC(&c_layout); + EXPECT_EQ(in_layout, out_layout); + Destroy(&c_layout); +} + +TEST(XlaShape, ToCScalar) { + xla::Shape cpp_shape = xla::ShapeUtil::MakeShapeWithType({4}); + XLA_Shape c_shape; + ToC(cpp_shape, &c_shape); + + EXPECT_EQ(cpp_shape.element_type(), c_shape.element_type); + + absl::Span cpp_dimensions = cpp_shape.dimensions(); + absl::Span c_dimensions = MakeSpan(c_shape.dimensions); + EXPECT_EQ(cpp_dimensions, c_dimensions); + + absl::Span cpp_dynamic_dimensions = + cpp_shape.dynamic_dimensions(); + absl::Span c_dynamic_dimensions = + MakeSpan(c_shape.dynamic_dimensions); + EXPECT_EQ(cpp_dynamic_dimensions, c_dynamic_dimensions); + + int cpp_ntuple_shapes = cpp_shape.tuple_shapes_size(); + int c_ntuple_shapes = c_shape.ntuple_shapes; + EXPECT_EQ(cpp_ntuple_shapes, c_ntuple_shapes); + EXPECT_EQ(cpp_ntuple_shapes, 0); + + bool cpp_has_layout = cpp_shape.has_layout(); + bool c_has_layout = c_shape.has_layout; + EXPECT_EQ(cpp_has_layout, c_has_layout); + + Destroy(&c_shape); +} + +TEST(XlaShape, ToCNested) { + xla::Shape cpp_shape = xla::ShapeUtil::MakeShapeWithType({4, 3, 2}); + XLA_Shape c_shape; + ToC(cpp_shape, &c_shape); + + EXPECT_EQ(cpp_shape.element_type(), c_shape.element_type); + + absl::Span cpp_dimensions = cpp_shape.dimensions(); + absl::Span c_dimensions = MakeSpan(c_shape.dimensions); + EXPECT_EQ(cpp_dimensions, c_dimensions); + + absl::Span cpp_dynamic_dimensions = + cpp_shape.dynamic_dimensions(); + absl::Span c_dynamic_dimensions = + MakeSpan(c_shape.dynamic_dimensions); + EXPECT_EQ(cpp_dynamic_dimensions, c_dynamic_dimensions); + + int cpp_ntuple_shapes = cpp_shape.tuple_shapes_size(); + int c_ntuple_shapes = c_shape.ntuple_shapes; + EXPECT_EQ(cpp_ntuple_shapes, c_ntuple_shapes); + + const std::vector& cpp_tuple_shapes = cpp_shape.tuple_shapes(); + absl::Span c_tuple_shapes(c_shape.tuple_shapes, + c_ntuple_shapes); + for (int i = 0; i < c_ntuple_shapes; ++i) { + xla::Shape converted_c_shape = FromC(&c_tuple_shapes[i]); + EXPECT_EQ(cpp_tuple_shapes[i], converted_c_shape); + } + + bool cpp_has_layout = cpp_shape.has_layout(); + bool c_has_layout = c_shape.has_layout; + EXPECT_EQ(cpp_has_layout, c_has_layout); + + if (c_has_layout) { + xla::Layout converted_c_layout = FromC(&c_shape.layout); + EXPECT_EQ(cpp_shape.layout(), converted_c_layout); + } + + Destroy(&c_shape); +} + +TEST(XlaShape, FromCScalar) { + xla::Shape in_shape = xla::ShapeUtil::MakeShapeWithType({4}); + XLA_Shape c_shape; + ToC(in_shape, &c_shape); + xla::Shape out_shape = FromC(&c_shape); + EXPECT_EQ(in_shape, out_shape); + Destroy(&c_shape); +} + +TEST(XlaShape, FromCNested) { + xla::Shape in_shape = xla::ShapeUtil::MakeShapeWithType({4, 3, 2}); + XLA_Shape c_shape; + ToC(in_shape, &c_shape); + xla::Shape out_shape = FromC(&c_shape); + EXPECT_EQ(in_shape, out_shape); + Destroy(&c_shape); +} + +// TODO(b/290654348): xla::ShapeIndex, xla::Literal, xla::ShapedBuffer + +TEST(XlaHloModuleConfig, ToAndFromC) { + xla::StatusOr> hlo_module = + xla::ParseAndReturnUnverifiedModule(kHloString); + ASSERT_TRUE(hlo_module.ok()); + xla::HloModule& cpp_module = *hlo_module.value(); + xla::HloModuleConfig in_config = cpp_module.config(); + + XLA_HloModuleConfig c_config = ToC(in_config); + xla::HloModuleConfig out_config = FromC(c_config); + + TF_ASSERT_OK_AND_ASSIGN(xla::HloModuleConfigProto in_config_proto, + in_config.ToProto()); + TF_ASSERT_OK_AND_ASSIGN(xla::HloModuleConfigProto out_config_proto, + out_config.ToProto()); + + tsl::protobuf::util::MessageDifferencer diff; + diff.set_message_field_comparison( + tsl::protobuf::util::MessageDifferencer::EQUIVALENT); + EXPECT_TRUE(diff.Equals(in_config_proto, out_config_proto)); + + Destroy(&c_config); +} + +TEST(XlaHloModule, ToAndFromC) { + xla::StatusOr> hlo_module = + xla::ParseAndReturnUnverifiedModule(kHloString); + ASSERT_TRUE(hlo_module.ok()); + xla::HloModule& in_module = *hlo_module.value(); + + XLA_HloModule c_module = ToC(in_module); + xla::StatusOr> out_module_ptr = + FromC(c_module); + ASSERT_TRUE(out_module_ptr.ok()); + xla::HloModule& out_module = *out_module_ptr.value(); + + TF_ASSERT_OK_AND_ASSIGN(xla::HloModuleProtoWithConfig in_module_proto, + in_module.ToProtoWithConfig()); + TF_ASSERT_OK_AND_ASSIGN(xla::HloModuleProtoWithConfig out_module_proto, + out_module.ToProtoWithConfig()); + + tsl::protobuf::util::MessageDifferencer diff; + diff.set_message_field_comparison( + tsl::protobuf::util::MessageDifferencer::EQUIVALENT); + const auto* ignore_unique_id = + xla::HloModuleProto::GetDescriptor()->FindFieldByName("id"); + diff.IgnoreField(ignore_unique_id); + EXPECT_TRUE(diff.Compare(in_module_proto, out_module_proto)); + + Destroy(&c_module); +} + +// TODO(b/290654348): SE_DeviceMemoryBase, SE_DeviceMemoryAllocator, +// SE_MaybeOwningDeviceMemory + +} // namespace + +} // namespace ApiConverter diff --git a/tensorflow/core/tpu/BUILD b/tensorflow/core/tpu/BUILD index 096caff26ef884..1005b9d061b5fa 100644 --- a/tensorflow/core/tpu/BUILD +++ b/tensorflow/core/tpu/BUILD @@ -12,6 +12,7 @@ package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ "//tensorflow/compiler/mlir/tf2xla:__subpackages__", + "//tensorflow/compiler/xla:__subpackages__", "//tensorflow/compiler/xrt:__subpackages__", "//tensorflow/core/tpu:__subpackages__", "//tensorflow/dtensor:__subpackages__", From dbec0c00fc3ae9339d7e0518a2b6295ccc68b2d2 Mon Sep 17 00:00:00 2001 From: Andrew Goodbody Date: Wed, 12 Jul 2023 17:20:02 +0100 Subject: [PATCH 194/376] [Linaro:ARM_CI] Add broken test to skip list //tensorflow/compiler/xla/service/gpu:fusion_merger_test is broken and only works on x86 as a divide by zero is evaluated as -inf --- tensorflow/tools/ci_build/build_scripts/ARM_SKIP_TESTS.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/tools/ci_build/build_scripts/ARM_SKIP_TESTS.sh b/tensorflow/tools/ci_build/build_scripts/ARM_SKIP_TESTS.sh index d2846da30469e8..8e8ccab623261c 100644 --- a/tensorflow/tools/ci_build/build_scripts/ARM_SKIP_TESTS.sh +++ b/tensorflow/tools/ci_build/build_scripts/ARM_SKIP_TESTS.sh @@ -16,6 +16,7 @@ set -x ARM_SKIP_TESTS="-//tensorflow/lite/... \ +-//tensorflow/compiler/xla/service/gpu:fusion_merger_test \ -//tensorflow/python/kernel_tests/nn_ops:atrous_conv2d_test \ -//tensorflow/python/kernel_tests/nn_ops:conv_ops_test \ " From 38a9f7d8a269b99e7f588a8f2ff19840f2edd5cf Mon Sep 17 00:00:00 2001 From: Justin Szaday Date: Wed, 12 Jul 2023 09:28:58 -0700 Subject: [PATCH 195/376] Correct layout order on top-level multi-device call op. PiperOrigin-RevId: 547517997 --- .../mlir/dtensor_multi_device_expansion.cc | 138 +++++++++++------- 1 file changed, 86 insertions(+), 52 deletions(-) diff --git a/tensorflow/dtensor/mlir/dtensor_multi_device_expansion.cc b/tensorflow/dtensor/mlir/dtensor_multi_device_expansion.cc index 92bafef7f7e252..d0793e348cf7f6 100644 --- a/tensorflow/dtensor/mlir/dtensor_multi_device_expansion.cc +++ b/tensorflow/dtensor/mlir/dtensor_multi_device_expansion.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include #include #include #include @@ -73,7 +75,20 @@ using ExpandedArgumentMap = absl::flat_hash_map>>; -using ExpandedResultsMap = absl::flat_hash_map>; +struct ExpandedResults { + std::optional layout; + std::vector results; + + template + void insert(Value&& value) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + results.emplace_back(std::forward(value)); + } else { + results.insert(results.end(), value.begin(), value.end()); + } + } +}; mlir::BlockArgument InsertArgumentForDevice(mlir::OpBuilder& builder, mlir::func::FuncOp func, @@ -218,11 +233,11 @@ void AddMetadataToTPUCluster(const Mesh& mesh_config, int64_t num_devices, // into a cluster func that has partitioned inputs and outputs ops; // it will be rewritten by TPURewritePass into per-device TPUExecute ops. template -mlir::LogicalResult ExpandTPUOperation(mlir::func::FuncOp target_func, - mlir::func::ReturnOp return_op, - ExpandedArgumentMap& expanded_arguments, - ExpandedResultsMap& expanded_results, - const Mesh& target_mesh, Operation op) { +mlir::LogicalResult ExpandTPUOperation( + mlir::func::FuncOp target_func, mlir::func::ReturnOp return_op, + ExpandedArgumentMap& expanded_arguments, + std::vector& expanded_results, const Mesh& target_mesh, + Operation op) { const absl::Span devices = GetDevices(target_mesh); const std::size_t num_devices = devices.size(); @@ -288,9 +303,7 @@ mlir::LogicalResult ExpandTPUOperation(mlir::func::FuncOp target_func, const std::size_t result_number = search - results.begin(); const mlir::Operation::result_range replicated_results = replications.at(result_number); - expanded_results[i].insert(expanded_results[i].end(), - replicated_results.begin(), - replicated_results.end()); + expanded_results[i].insert(replicated_results); } } } @@ -302,11 +315,11 @@ mlir::LogicalResult ExpandTPUOperation(mlir::func::FuncOp target_func, // de/multiplexes the per-device inputs/outputs for each "expanded" op. // Only usable on CPU/GPU devices, which do not require additional rewriting. template -mlir::LogicalResult ExpandOperation(mlir::func::FuncOp target_func, - mlir::func::ReturnOp return_op, - ExpandedArgumentMap& expanded_arguments, - ExpandedResultsMap& expanded_results, - const Mesh& target_mesh, Operation op) { +mlir::LogicalResult ExpandOperation( + mlir::func::FuncOp target_func, mlir::func::ReturnOp return_op, + ExpandedArgumentMap& expanded_arguments, + std::vector& expanded_results, const Mesh& target_mesh, + Operation op) { mlir::OpBuilder builder(target_func.getBody()); const absl::Span devices = GetDevices(target_mesh); const std::size_t num_devices = devices.size(); @@ -355,8 +368,8 @@ mlir::LogicalResult ExpandOperation(mlir::func::FuncOp target_func, llvm::find(results, operand); const std::size_t result_number = search - results.begin(); for (const Operation& replication : replications) { - expanded_results[i].emplace_back( - replication->getResult(result_number)); + expanded_results[i].insert( + (mlir::Value)replication->getResult(result_number)); } } } @@ -479,46 +492,39 @@ struct InferredResourceAttributes { : layouts(layouts_), indices(indices_) {} }; -// Build a new main function that calls the multi-device/translated function. template -mlir::LogicalResult BuildOuterMainFunc(mlir::ModuleOp module, - mlir::func::FuncOp old_main_func, - mlir::func::FuncOp translated_func, - mlir::func::ReturnOp return_op, - mlir::ArrayAttr num_local_outputs_attr, - Operations&& call_ops) { - using CallOp = typename std::decay_t::value_type; - llvm::SmallVector output_layouts; - std::optional resource_attrs; - for (CallOp call_op : call_ops) { - // Then extract all their output layouts. - mlir::Attribute layout_attr = call_op->getAttr(kLayoutAttr); - mlir::ArrayAttr layouts = layout_attr.dyn_cast_or_null(); - if (!layouts) { - call_op.emitOpError() << "Could not find op's layouts."; - return mlir::failure(); - } +mlir::LogicalResult GetInferredResourceAttributes( + const Operations& call_ops, + std::optional* resource_attrs) { + for (auto call_op : call_ops) { // Set the resource layouts. mlir::Attribute resource_layouts_attr = call_op->getAttr(kNewResourceArgLayouts); mlir::Attribute resource_indices_attr = call_op->getAttr(kNewResourceLayoutIndices); if (resource_indices_attr && resource_layouts_attr) { - if (resource_attrs) { + if (resource_attrs->has_value()) { // TODO(twelve): Determine how to merge inferred resource attrs if there // are multiple sets of them. (when can that happen?) call_op.emitOpError() << "Multiple sets of inferred resource attributes!"; return mlir::failure(); } else { - resource_attrs.emplace(resource_layouts_attr, resource_indices_attr); + resource_attrs->emplace(resource_layouts_attr, resource_indices_attr); } } - // Here, we assume that the output layouts and the results are in the same - // ordering--this property should be guaranteed as long as all the results - // have been expanded (produced by ExpandOperation). - output_layouts.insert(output_layouts.end(), layouts.begin(), layouts.end()); } + return mlir::success(); +} + +// Build a new main function that calls the multi-device/translated function. +template +mlir::LogicalResult BuildOuterMainFunc( + mlir::ModuleOp module, mlir::func::FuncOp old_main_func, + mlir::func::FuncOp translated_func, mlir::func::ReturnOp return_op, + const std::vector& expanded_results, + mlir::ArrayAttr num_local_outputs_attr, Operations&& call_ops) { + using CallOp = typename std::decay_t::value_type; mlir::SymbolTable symbol_table(module); mlir::Block* module_body = module.getBody(); @@ -550,12 +556,19 @@ mlir::LogicalResult BuildOuterMainFunc(mlir::ModuleOp module, /*executor_type=*/builder.getStringAttr("")); // Set the output layout attribute on the new call op. - llvm::ArrayRef output_layouts_ref(output_layouts); - mlir::ArrayAttr output_layouts_attr = - builder.getArrayAttr(output_layouts_ref); - expanded_call_op->setAttr(kLayoutAttr, output_layouts_attr); + std::vector> output_layouts; + std::transform(expanded_results.begin(), expanded_results.end(), + std::back_inserter(output_layouts), + [](const ExpandedResults& result) { return result.layout; }); + SetLayoutOnOp(expanded_call_op, builder, output_layouts); + expanded_call_op->setAttr(kNumLocalOutputsAttr, num_local_outputs_attr); + std::optional resource_attrs; + if (failed(GetInferredResourceAttributes(call_ops, &resource_attrs))) { + return mlir::failure(); + } + if (resource_attrs) { expanded_call_op->setAttr(kNewResourceArgLayouts, resource_attrs->layouts); expanded_call_op->setAttr(kNewResourceLayoutIndices, @@ -586,6 +599,25 @@ mlir::LogicalResult BuildOuterMainFunc(mlir::ModuleOp module, return mlir::success(); } +Status ExtractResultLayouts(mlir::Operation* op, mlir::func::ReturnOp return_op, + std::vector& expanded_results) { + if (!return_op || (return_op.getNumOperands() == 0)) { + return OkStatus(); + } + TF_ASSIGN_OR_RETURN(std::vector> layouts, + ExtractLayoutFromOp(op)); + mlir::Operation::operand_range operands = return_op.getOperands(); + for (auto [layout_index, result] : llvm::enumerate(op->getResults())) { + auto search = std::find(operands.begin(), operands.end(), result); + if (search == operands.end()) { + continue; + } + std::size_t result_index = std::distance(operands.begin(), search); + expanded_results[result_index].layout = layouts[layout_index]; + } + return OkStatus(); +} + struct DTensorMultiDeviceExpansion : public impl::DTensorMultiDeviceExpansionBase< DTensorMultiDeviceExpansion> { @@ -656,11 +688,14 @@ struct DTensorMultiDeviceExpansion return; } - ExpandedResultsMap expanded_results; + std::vector expanded_results( + return_op ? return_op->getNumOperands() : 0); for (const mlir::TF::StatefulPartitionedCallOp& stateful_call_op : stateful_call_ops) { + const Status status = + ExtractResultLayouts(stateful_call_op, return_op, expanded_results); const StatusOr> mesh = - ExtractDeviceMeshFromOp(stateful_call_op); + status.ok() ? ExtractDeviceMeshFromOp(stateful_call_op) : status; if (!(mesh.ok() && *mesh)) { stateful_call_op->emitOpError("Failed to retrieve op mesh or layout."); return; @@ -690,13 +725,12 @@ struct DTensorMultiDeviceExpansion llvm::SmallVector num_local_outputs; if (return_op) { for (unsigned i = 0; i < return_op->getNumOperands(); ++i) { - ExpandedResultsMap::iterator search = expanded_results.find(i); + std::vector& values = expanded_results[i].results; int num_outputs; - if (search == expanded_results.end()) { + if (values.empty()) { results.emplace_back(return_op->getOperand(i)); num_outputs = 1; } else { - std::vector& values = search->second; results.insert(results.end(), values.begin(), values.end()); num_outputs = values.size(); } @@ -714,9 +748,9 @@ struct DTensorMultiDeviceExpansion builder, translated_func, absl::Span(results))); UpdateEntryFuncAttr(builder, translated_func); - mlir::LogicalResult status = - BuildOuterMainFunc(module, main_func, translated_func, return_op, - num_local_outputs_attr, stateful_call_ops); + mlir::LogicalResult status = BuildOuterMainFunc( + module, main_func, translated_func, return_op, expanded_results, + num_local_outputs_attr, stateful_call_ops); if (mlir::failed(status)) { return; } From 58e88879effc1ed045928c0ec59421712dd1f851 Mon Sep 17 00:00:00 2001 From: Justin Szaday Date: Wed, 12 Jul 2023 09:31:25 -0700 Subject: [PATCH 196/376] Mirror conditional structure between send and recv lowering paths. PiperOrigin-RevId: 547518588 --- tensorflow/dtensor/mlir/dtensor_send_recv.cc | 430 ++++++++++--------- 1 file changed, 216 insertions(+), 214 deletions(-) diff --git a/tensorflow/dtensor/mlir/dtensor_send_recv.cc b/tensorflow/dtensor/mlir/dtensor_send_recv.cc index 14e865fc789173..4631f2cfd7e905 100644 --- a/tensorflow/dtensor/mlir/dtensor_send_recv.cc +++ b/tensorflow/dtensor/mlir/dtensor_send_recv.cc @@ -538,99 +538,105 @@ StatusOr LowerDTensorSend(mlir::Operation* send_op, // Is tensor transfer is from TPU mesh to host mesh and send layout and recv // layout is identical, then tensor from each source device is sent to // target device asynchronously. + mlir::Operation* lowered_send; if (one_to_one && IsTpuToHostMeshTransfer(input_mesh, target_mesh)) { - return LowerDTensorSendToXlaOp(input_layout, dtensor_send.getInput(), - dtensor_send, - /*send_from_device_zero=*/false); + TF_ASSIGN_OR_RETURN(lowered_send, + LowerDTensorSendToXlaOp( + input_layout, dtensor_send.getInput(), dtensor_send, + /*send_from_device_zero=*/false)); } else if (one_to_one && IsGpuToHostMeshTransfer(input_mesh, target_mesh) && !recv_layout.IsFullyReplicated()) { - return LowerOneToOneDTensorSendToTFHostSend(input_layout, target_mesh, - dtensor_send); - } - - // Calculate input tensor layout of data to send and target fully replicated - // layout. For now, we ensure that all data transfer happen with fully - // replicated tensors. - const int rank = ValueRank(dtensor_send.getInput()); - const Layout target_layout = Layout::ReplicatedOnMesh(input_mesh, rank); - - // Convert tensor to send to replicated layout. - mlir::OpBuilder builder(dtensor_send); - TF_ASSIGN_OR_RETURN(mlir::Value send_input, - EmitAllGather(builder, dtensor_send.getInput(), - input_layout, target_layout)); - - // Insert control flow such that only device with device ordinal == 0 sends - // the tensor data across mesh. - auto send_cluster = - dtensor_send->getParentOfType(); - TF_ASSIGN_OR_RETURN(std::optional mesh, - ExtractDeviceMeshFromOp(send_cluster)); - if (!mesh.has_value()) - return errors::InvalidArgument( - "failed to lower DTensor CopyToMesh op as sending side mesh is not " - "specified."); - - mlir::Location loc = dtensor_send.getLoc(); - TF_ASSIGN_OR_RETURN( - mlir::Value device_ordinal, - GetDeviceOrdinal(*mesh, loc, - send_cluster->getParentOfType(), - &builder)); - mlir::Value predicate = builder.create( - loc, device_ordinal, CreateIntScalarConst(0, builder, loc), - /*incompatible_shape_error=*/builder.getBoolAttr(true)); - - auto send_if = builder.create( - loc, llvm::SmallVector{}, predicate, - /*is_stateless=*/builder.getBoolAttr(true), - GetUniqueControlflowFnName("copy_to_mesh_send_if_then", builder), - GetUniqueControlflowFnName("copy_to_mesh_send_if_else", builder)); - - // Create empty else branch region. - auto& else_branch = send_if.getElseBranch(); - else_branch.push_back(new mlir::Block); - builder.setInsertionPointToEnd(&else_branch.front()); - builder.create(loc, - /*operands=*/llvm::ArrayRef{}); - - // Create then branch region with DTensorSend op. - auto& then_branch = send_if.getThenBranch(); - then_branch.push_back(new mlir::Block); - builder.setInsertionPointToEnd(&then_branch.front()); - auto yield = builder.create( - loc, /*operands=*/llvm::ArrayRef{}); - dtensor_send->moveBefore(yield); - - // Lower DTensorSend op to actual TF op. - TF_ASSIGN_OR_RETURN(const Mesh recv_mesh, - ExtractDeviceMeshEnclosingCluster(recv_op)); - mlir::Operation* lowered_send; - if (SendRecvOpUsesXla(input_layout.mesh(), recv_mesh)) { - // Lower DTensorSend op to Xla Send ops. - TF_ASSIGN_OR_RETURN( - lowered_send, - LowerDTensorSendToXlaOp(input_layout, send_input, dtensor_send, - /*send_from_device_zero=*/true)); - } else if (input_layout.mesh().is_cpu_mesh() && recv_mesh.is_cpu_mesh()) { - // Lower DTensorSend op to TF Host Send op. - TF_ASSIGN_OR_RETURN( - lowered_send, - LowerDTensorSendFromCPUToTFOp(input_layout, send_input, dtensor_send)); + TF_ASSIGN_OR_RETURN(lowered_send, + LowerOneToOneDTensorSendToTFHostSend( + input_layout, target_mesh, dtensor_send)); } else { - mlir::TensorType send_type = send_input.getType().cast(); - if (!recv_mesh.is_cpu_mesh() && send_type.getElementType().isInteger(32)) { - builder.setInsertionPointAfter(send_input.getDefiningOp()); - auto cast_to_int64 = builder.create( - send_input.getLoc(), - mlir::RankedTensorType::get(send_type.getShape(), - builder.getIntegerType(64)), - send_input); - send_input = cast_to_int64->getResult(0); + // Calculate input tensor layout of data to send and target fully replicated + // layout. For now, we ensure that all data transfer happen with fully + // replicated tensors. + const int rank = ValueRank(dtensor_send.getInput()); + const Layout target_layout = Layout::ReplicatedOnMesh(input_mesh, rank); + + // Convert tensor to send to replicated layout. + mlir::OpBuilder builder(dtensor_send); + TF_ASSIGN_OR_RETURN(mlir::Value send_input, + EmitAllGather(builder, dtensor_send.getInput(), + input_layout, target_layout)); + + // Insert control flow such that only device with device ordinal == 0 sends + // the tensor data across mesh. + auto send_cluster = + dtensor_send->getParentOfType(); + TF_ASSIGN_OR_RETURN(std::optional mesh, + ExtractDeviceMeshFromOp(send_cluster)); + if (!mesh.has_value()) { + return absl::InvalidArgumentError( + "failed to lower DTensor CopyToMesh op as sending side mesh is not " + "specified."); } + + mlir::Location loc = dtensor_send.getLoc(); TF_ASSIGN_OR_RETURN( - lowered_send, - LowerDTensorSendToTFOp(input_layout, send_input, dtensor_send)); + mlir::Value device_ordinal, + GetDeviceOrdinal(*mesh, loc, + send_cluster->getParentOfType(), + &builder)); + mlir::Value predicate = builder.create( + loc, device_ordinal, CreateIntScalarConst(0, builder, loc), + /*incompatible_shape_error=*/builder.getBoolAttr(true)); + + auto send_if = builder.create( + loc, llvm::SmallVector{}, predicate, + /*is_stateless=*/builder.getBoolAttr(true), + GetUniqueControlflowFnName("copy_to_mesh_send_if_then", builder), + GetUniqueControlflowFnName("copy_to_mesh_send_if_else", builder)); + + // Create empty else branch region. + auto& else_branch = send_if.getElseBranch(); + else_branch.push_back(new mlir::Block); + builder.setInsertionPointToEnd(&else_branch.front()); + builder.create( + loc, + /*operands=*/llvm::ArrayRef{}); + + // Create then branch region with DTensorSend op. + auto& then_branch = send_if.getThenBranch(); + then_branch.push_back(new mlir::Block); + builder.setInsertionPointToEnd(&then_branch.front()); + auto yield = builder.create( + loc, /*operands=*/llvm::ArrayRef{}); + dtensor_send->moveBefore(yield); + + // Lower DTensorSend op to actual TF op. + TF_ASSIGN_OR_RETURN(const Mesh recv_mesh, + ExtractDeviceMeshEnclosingCluster(recv_op)); + if (SendRecvOpUsesXla(input_layout.mesh(), recv_mesh)) { + // Lower DTensorSend op to Xla Send ops. + TF_ASSIGN_OR_RETURN( + lowered_send, + LowerDTensorSendToXlaOp(input_layout, send_input, dtensor_send, + /*send_from_device_zero=*/true)); + } else if (input_layout.mesh().is_cpu_mesh() && recv_mesh.is_cpu_mesh()) { + // Lower DTensorSend op to TF Host Send op. + TF_ASSIGN_OR_RETURN( + lowered_send, LowerDTensorSendFromCPUToTFOp(input_layout, send_input, + dtensor_send)); + } else { + mlir::TensorType send_type = + send_input.getType().cast(); + if (!recv_mesh.is_cpu_mesh() && + send_type.getElementType().isInteger(32)) { + builder.setInsertionPointAfter(send_input.getDefiningOp()); + auto cast_to_int64 = builder.create( + send_input.getLoc(), + mlir::RankedTensorType::get(send_type.getShape(), + builder.getIntegerType(64)), + send_input); + send_input = cast_to_int64->getResult(0); + } + TF_ASSIGN_OR_RETURN( + lowered_send, + LowerDTensorSendToTFOp(input_layout, send_input, dtensor_send)); + } } return lowered_send; @@ -655,8 +661,7 @@ StatusOr LowerDTensorRecv(mlir::Operation* send_op, const Mesh& recv_mesh = recv_layout.mesh(); mlir::OpBuilder builder(dtensor_recv); - bool cpu_to_cpu = - dtensor_recv.getLayout().mesh().is_cpu_mesh() && send_mesh.is_cpu_mesh(); + bool cpu_to_cpu = recv_mesh.is_cpu_mesh() && send_mesh.is_cpu_mesh(); bool one_to_one = IsOneToOneMeshTransfer(send_layout, recv_layout); bool send_recv_xla = SendRecvOpUsesXla(send_mesh, recv_mesh); @@ -672,137 +677,134 @@ StatusOr LowerDTensorRecv(mlir::Operation* send_op, } return lowered_recv; - } else if (send_recv_xla || !cpu_to_cpu) { - if (send_recv_xla && - ((one_to_one && IsTpuToHostMeshTransfer(send_mesh, recv_mesh)) || - recv_mesh.is_cpu_mesh())) { - // Recv can be lowered directly for a 1-to-1 transfer between host and - // device (*for XLA/TPUs). - TF_ASSIGN_OR_RETURN(mlir::TensorType local_output_type, - LocalTypeFromGlobalType( - dtensor_recv.getLayout(), - dtensor_recv.getType().cast())); - TF_ASSIGN_OR_RETURN(lowered_recv, LowerDTensorRecvToXlaOp( - dtensor_recv, local_output_type)); - dtensor_recv->replaceAllUsesWith(lowered_recv); - dtensor_recv.erase(); - } else { - // Choose which receive lowering function to use. - auto lower_fn = - send_recv_xla - ? (decltype(&LowerDTensorRecvToTFOp))LowerDTensorRecvToXlaOp - : LowerDTensorRecvToTFOp; - - // For other send/recv layouts, the tensor needs to be replicated. - if (!dtensor_recv.getLayout().IsFullyReplicated()) { - return errors::InvalidArgument( - "CopyToMesh where target mesh is GPU/TPU requires a replicated " - "target layout."); - } + } else if (cpu_to_cpu) { + // Lower DTensorRecv op to TF Host Recv op. + TF_ASSIGN_OR_RETURN(lowered_recv, + LowerDTensorRecvFromCPUToTFOp(send_mesh, dtensor_recv)); + } else if ((one_to_one && IsTpuToHostMeshTransfer(send_mesh, recv_mesh)) || + (send_recv_xla && recv_mesh.is_cpu_mesh())) { + // Recv can be lowered directly for a 1-to-1 transfer between host and + // device (*for XLA/TPUs). + TF_ASSIGN_OR_RETURN(mlir::TensorType local_output_type, + LocalTypeFromGlobalType( + dtensor_recv.getLayout(), + dtensor_recv.getType().cast())); + TF_ASSIGN_OR_RETURN( + lowered_recv, LowerDTensorRecvToXlaOp(dtensor_recv, local_output_type)); + dtensor_recv->replaceAllUsesWith(lowered_recv); + dtensor_recv.erase(); + } else { + // Choose which receive lowering function to use. + auto lower_fn = + send_recv_xla + ? (decltype(&LowerDTensorRecvToTFOp))LowerDTensorRecvToXlaOp + : LowerDTensorRecvToTFOp; + + // For other send/recv layouts, the tensor needs to be replicated. + if (!dtensor_recv.getLayout().IsFullyReplicated()) { + return absl::InvalidArgumentError( + "CopyToMesh where target mesh is GPU/TPU requires a replicated " + "target layout."); + } - // For Receiving at GPU/TPU, only device 0 (ordinal) receives from the - // host, then it shares the tensor with its peers. - auto recv_cluster = - dtensor_recv->getParentOfType(); - mlir::Location loc = dtensor_recv.getLoc(); - TF_ASSIGN_OR_RETURN( - mlir::Value device_ordinal, - GetDeviceOrdinal(recv_mesh, loc, - recv_cluster->getParentOfType(), - &builder)); - mlir::Value predicate = builder.create( - loc, device_ordinal, CreateIntScalarConst(0, builder, loc), - /*incompatible_shape_error=*/builder.getBoolAttr(true)); - - mlir::TensorType recv_type = dtensor_recv.getType(); - bool i32_copy = recv_type.getElementType().isInteger(32); - bool need_i32_to_i64_upcast = - i32_copy && !(recv_mesh.is_cpu_mesh() || send_recv_xla); - mlir::TensorType output_type = - need_i32_to_i64_upcast - ? mlir::RankedTensorType::get(recv_type.getShape(), - builder.getIntegerType(64)) - : recv_type; - - auto recv_if = builder.create( - loc, llvm::SmallVector{output_type}, predicate, - /*is_stateless=*/builder.getBoolAttr(true), - GetUniqueControlflowFnName("copy_to_mesh_recv_if_then", builder), - GetUniqueControlflowFnName("copy_to_mesh_recv_if_else", builder)); - - // Create empty else branch region that outputs zeros. - auto& else_branch = recv_if.getElseBranch(); - else_branch.push_back(new mlir::Block); - builder.setInsertionPointToEnd(&else_branch.front()); - - // Create a zero constant. - mlir::Attribute const_attr; - auto output_element_type = output_type.getElementType(); - if (output_element_type.isIntOrIndex()) { - if (output_element_type.isInteger(64)) { - const_attr = mlir::DenseIntElementsAttr::get( - output_type, llvm::SmallVector{0}); - } else { - const_attr = mlir::DenseIntElementsAttr::get( - output_type, llvm::SmallVector{0}); - } - } else if (output_element_type.isBF16()) { - mlir::FloatAttr zero = mlir::FloatAttr::get(output_element_type, 0.); - const_attr = mlir::DenseElementsAttr::get( - output_type, llvm::SmallVector{zero}); - } else if (output_element_type.isF16() || output_element_type.isF32()) { - const_attr = mlir::DenseFPElementsAttr::get( - output_type, llvm::SmallVector{0.0}); - } else if (output_element_type.isF64()) { - const_attr = mlir::DenseFPElementsAttr::get( - output_type, llvm::SmallVector{0.0}); + // For Receiving at GPU/TPU, only device 0 (ordinal) receives from the + // host, then it shares the tensor with its peers. + auto recv_cluster = + dtensor_recv->getParentOfType(); + mlir::Location loc = dtensor_recv.getLoc(); + TF_ASSIGN_OR_RETURN( + mlir::Value device_ordinal, + GetDeviceOrdinal(recv_mesh, loc, + recv_cluster->getParentOfType(), + &builder)); + mlir::Value predicate = builder.create( + loc, device_ordinal, CreateIntScalarConst(0, builder, loc), + /*incompatible_shape_error=*/builder.getBoolAttr(true)); + + mlir::TensorType recv_type = dtensor_recv.getType(); + bool i32_copy = recv_type.getElementType().isInteger(32); + bool need_i32_to_i64_upcast = + i32_copy && !(recv_mesh.is_cpu_mesh() || send_recv_xla); + mlir::TensorType output_type = + need_i32_to_i64_upcast + ? mlir::RankedTensorType::get(recv_type.getShape(), + builder.getIntegerType(64)) + : recv_type; + + auto recv_if = builder.create( + loc, llvm::SmallVector{output_type}, predicate, + /*is_stateless=*/builder.getBoolAttr(true), + GetUniqueControlflowFnName("copy_to_mesh_recv_if_then", builder), + GetUniqueControlflowFnName("copy_to_mesh_recv_if_else", builder)); + + // Create empty else branch region that outputs zeros. + auto& else_branch = recv_if.getElseBranch(); + else_branch.push_back(new mlir::Block); + builder.setInsertionPointToEnd(&else_branch.front()); + + // Create a zero constant. + mlir::Attribute const_attr; + auto output_element_type = output_type.getElementType(); + if (output_element_type.isIntOrIndex()) { + if (output_element_type.isInteger(64)) { + const_attr = mlir::DenseIntElementsAttr::get( + output_type, llvm::SmallVector{0}); } else { - return errors::InvalidArgument("unsupported output type"); + const_attr = mlir::DenseIntElementsAttr::get( + output_type, llvm::SmallVector{0}); } + } else if (output_element_type.isBF16()) { + mlir::FloatAttr zero = mlir::FloatAttr::get(output_element_type, 0.); + const_attr = mlir::DenseElementsAttr::get( + output_type, llvm::SmallVector{zero}); + } else if (output_element_type.isF16() || output_element_type.isF32()) { + const_attr = mlir::DenseFPElementsAttr::get( + output_type, llvm::SmallVector{0.0}); + } else if (output_element_type.isF64()) { + const_attr = mlir::DenseFPElementsAttr::get( + output_type, llvm::SmallVector{0.0}); + } else { + return absl::InvalidArgumentError("unsupported output type"); + } - mlir::Value zeros = builder.create(loc, const_attr); - builder.create( - loc, /*operands=*/llvm::ArrayRef{zeros}); - - // Create then branch region with DTensorRecv op. - auto& then_branch = recv_if.getThenBranch(); - then_branch.push_back(new mlir::Block); - builder.setInsertionPointToEnd(&then_branch.front()); - dtensor_recv->moveBefore(&then_branch.front(), then_branch.front().end()); - - TF_ASSIGN_OR_RETURN(mlir::Operation * xla_recv, - lower_fn(send_mesh, dtensor_recv, output_type)); - builder.create( - loc, - /*operands=*/llvm::ArrayRef{xla_recv->getResult(0)}); - - // Broadcast the received output to all GPU/TPU devices. - mlir::Value if_output = recv_if->getResult(0); - builder.setInsertionPointAfterValue(if_output); - absl::flat_hash_set reduced_dims; - for (const auto& mesh_dim : recv_mesh.dims()) - reduced_dims.insert(mesh_dim.name); - - TF_ASSIGN_OR_RETURN(lowered_recv, - EmitAllReduce(builder, recv_layout, reduced_dims, - recv_if, kReduceOpAdd)); - - if (need_i32_to_i64_upcast) { - lowered_recv = builder.create( - loc, recv_type, lowered_recv->getResult(0)); - } + mlir::Value zeros = builder.create(loc, const_attr); + builder.create( + loc, /*operands=*/llvm::ArrayRef{zeros}); + + // Create then branch region with DTensorRecv op. + auto& then_branch = recv_if.getThenBranch(); + then_branch.push_back(new mlir::Block); + builder.setInsertionPointToEnd(&then_branch.front()); + dtensor_recv->moveBefore(&then_branch.front(), then_branch.front().end()); + + TF_ASSIGN_OR_RETURN(mlir::Operation * xla_recv, + lower_fn(send_mesh, dtensor_recv, output_type)); + builder.create( + loc, + /*operands=*/llvm::ArrayRef{xla_recv->getResult(0)}); + + // Broadcast the received output to all GPU/TPU devices. + mlir::Value if_output = recv_if->getResult(0); + builder.setInsertionPointAfterValue(if_output); + absl::flat_hash_set reduced_dims; + for (const auto& mesh_dim : recv_mesh.dims()) + reduced_dims.insert(mesh_dim.name); - // Replaces usages of DTensorRecv op with the broadcasted value. - dtensor_recv.getOutput().replaceUsesWithIf( - lowered_recv->getResult(0), [&](mlir::OpOperand& operand) { - return !recv_if->isProperAncestor(operand.getOwner()); - }); - dtensor_recv.erase(); + TF_ASSIGN_OR_RETURN( + lowered_recv, EmitAllReduce(builder, recv_layout, reduced_dims, recv_if, + kReduceOpAdd)); + + if (need_i32_to_i64_upcast) { + lowered_recv = builder.create( + loc, recv_type, lowered_recv->getResult(0)); } - } else { - // Lower DTensorRecv op to TF Host Recv op. - TF_ASSIGN_OR_RETURN(lowered_recv, - LowerDTensorRecvFromCPUToTFOp(send_mesh, dtensor_recv)); + + // Replaces usages of DTensorRecv op with the broadcasted value. + dtensor_recv.getOutput().replaceUsesWithIf( + lowered_recv->getResult(0), [&](mlir::OpOperand& operand) { + return !recv_if->isProperAncestor(operand.getOwner()); + }); + dtensor_recv.erase(); } llvm::SmallPtrSet newly_created_ops; From 3a8a4670abc532dc8fc120b71ace6b4f17e4f474 Mon Sep 17 00:00:00 2001 From: Antonio Sanchez Date: Wed, 12 Jul 2023 09:37:36 -0700 Subject: [PATCH 197/376] Fix AVX512 builds involving XLA CPU conv2d. There's an incompatibility between TensorFlow's CPU convolution implementation and Intel's new AVX512 matrix multiplication in Eigen. We currently need to disable the specialized Eigen matmul routine. Fixes #61216. PiperOrigin-RevId: 547520161 --- tensorflow/compiler/tf2xla/BUILD | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 42b2acd6d27cec..dcf4d4880a3344 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -254,7 +254,11 @@ cc_library( ":xla_compiled_cpu_runtime_hdrs", ], copts = runtime_copts() + tf_openmp_copts(), - defines = ["EIGEN_NEON_GEBP_NR=4"], + defines = [ + "EIGEN_NEON_GEBP_NR=4", + # TODO(b/238649163): remove this once no longer necessary. + "EIGEN_USE_AVX512_GEMM_KERNELS=0", + ], features = [ "fully_static_link", "-parse_headers", From ad5b2ea1d971596c35ce127f415d4f4c5669e168 Mon Sep 17 00:00:00 2001 From: Rahul Joshi Date: Wed, 12 Jul 2023 09:56:14 -0700 Subject: [PATCH 198/376] [XLA/GPU] Change PGLE to support either text or binary proto files PiperOrigin-RevId: 547525016 --- .../xla/service/gpu/gpu_hlo_schedule.cc | 46 +++++++++++-------- 1 file changed, 26 insertions(+), 20 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc index 1e63e1f7dfc37c..fb62f0e2b5c947 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc @@ -437,31 +437,37 @@ std::optional ReadPGLEProfile( return std::nullopt; } tsl::Env* env = tsl::Env::Default(); + auto read_text_or_binary_profile = [&profile, env]( + const std::string& text_path, + const std::string& binary_path) + -> std::optional { + Status s = tsl::ReadTextProto(env, text_path, &profile); + if (s.ok()) { + LOG(INFO) << "Using PGLE profile from " << text_path; + return profile; + } + profile.Clear(); + s = tsl::ReadBinaryProto(env, binary_path, &profile); + if (s.ok()) { + LOG(INFO) << "Using PGLE profile from " << binary_path; + return profile; + } + return std::nullopt; + }; + // If its a directory, use fingerprint to look for the profile for this // specific module. if (env->IsDirectory(pgle_profile_file_or_dir_path).ok()) { - std::string pgle_profile_path = - pgle_profile_file_or_dir_path + "/" + fingerprint + ".pbtxt"; - Status s = - tsl::ReadTextProto(tsl::Env::Default(), pgle_profile_path, &profile); - if (!s.ok()) { - // Unable to read PGLE using fingerprint. - return std::nullopt; - } - LOG(INFO) << "Using PGLE profile from " << pgle_profile_path; - return profile; + std::string pgle_profile_path_prefix = + pgle_profile_file_or_dir_path + "/" + fingerprint; + return read_text_or_binary_profile(pgle_profile_path_prefix + ".pbtxt", + pgle_profile_path_prefix + ".pb"); } - // The pgle_profile_file_or_dir is a file. Read the profile and see if its - // applicable for this HLO module (all instruction names in the profile should - // be present in the HLO module) - Status s = tsl::ReadTextProto(tsl::Env::Default(), - pgle_profile_file_or_dir_path, &profile); - if (s.ok()) { - LOG(INFO) << "Using PGLE profile from " << pgle_profile_file_or_dir_path; - return profile; - } - return std::nullopt; + // The pgle_profile_file_or_dir is a file. Attempt to read the profile as text + // proto or binary proto. + return read_text_or_binary_profile(pgle_profile_file_or_dir_path, + pgle_profile_file_or_dir_path); } // Return true if the profile is applicable to the module. That is true if every From 3efa230babd2dcd9ac6fdc50246a1afc1f96a0d9 Mon Sep 17 00:00:00 2001 From: Yu Feng Date: Wed, 12 Jul 2023 10:10:24 -0700 Subject: [PATCH 199/376] [KernelGen] JIT-compile most the MLIR-generated GPU kernels JIT-compile all MLIR-generated kernels for which the build rules can be reconfigured easily. For now, this excludes i64-indexed kernels and kernels with different input and output types. PiperOrigin-RevId: 547529230 --- tensorflow/core/kernels/mlir_generated/BUILD | 329 +++++++++---------- tensorflow/python/kernel_tests/linalg/BUILD | 2 +- tensorflow/python/ops/BUILD | 4 +- 3 files changed, 153 insertions(+), 182 deletions(-) diff --git a/tensorflow/core/kernels/mlir_generated/BUILD b/tensorflow/core/kernels/mlir_generated/BUILD index b94f1ffd77aa32..c4c5b6f5d99435 100644 --- a/tensorflow/core/kernels/mlir_generated/BUILD +++ b/tensorflow/core/kernels/mlir_generated/BUILD @@ -639,14 +639,13 @@ gpu_kernel_library( gpu_kernel_library( name = "gpu_atan2_kernels", - jit_types = [ + op = "atan2", + tile_size = "256", + types = [ "f16", "f32", "f64", ], - op = "atan2", - tile_size = "256", - types = [], unroll_factors = "4", ) @@ -749,27 +748,25 @@ gpu_kernel_library( gpu_kernel_library( name = "gpu_ceil_kernels", - jit_types = [ + op = "ceil", + tile_size = "256", + types = [ "f16", "f32", "f64", ], - op = "ceil", - tile_size = "256", - types = [], unroll_factors = "4", ) gpu_kernel_library( name = "gpu_floor_kernels", - jit_types = [ + op = "floor", + tile_size = "256", + types = [ "f16", "f32", "f64", ], - op = "floor", - tile_size = "256", - types = [], unroll_factors = "4", ) @@ -795,28 +792,26 @@ gpu_kernel_library( gpu_kernel_library( name = "gpu_rint_kernels", - jit_types = [ - "f16", + jit_types = ["f16"], + op = "rint", + tile_size = "1024", + types = [ "f32", "f64", ], - op = "rint", - tile_size = "1024", - types = [], ) gpu_kernel_library( name = "gpu_round_kernels", - jit_types = [ + op = "round", + tile_size = "1024", + types = [ "f16", "f32", "f64", "i32", "i64", ], - op = "round", - tile_size = "1024", - types = [], ) # Predicate kernels @@ -1034,13 +1029,12 @@ gpu_kernel_library( gpu_kernel_library( name = "gpu_conj_kernels", - jit_types = [ + op = "conj", + tile_size = "256", + types = [ "c64", "c128", ], - op = "conj", - tile_size = "256", - types = [], unroll_factors = "2", ) @@ -1177,6 +1171,10 @@ gpu_kernel_library( "ui16", "ui32", "ui64", + ], + op = "maximum", + tile_size = "1024", + types = [ "f16", "f32", "f64", @@ -1184,9 +1182,6 @@ gpu_kernel_library( "i64", "ui8", ], - op = "maximum", - tile_size = "1024", - types = [], unroll_factors = "4", ) @@ -1197,6 +1192,10 @@ gpu_kernel_library( "ui16", "ui32", "ui64", + ], + op = "minimum", + tile_size = "1024", + types = [ "f16", "f32", "f64", @@ -1204,9 +1203,6 @@ gpu_kernel_library( "i64", "ui8", ], - op = "minimum", - tile_size = "1024", - types = [], unroll_factors = "4", ) @@ -1258,7 +1254,9 @@ gpu_kernel_library( gpu_kernel_library( name = "gpu_neg_kernels", - jit_types = [ + op = "neg", + tile_size = "256", + types = [ "f16", "f32", "f64", @@ -1269,9 +1267,6 @@ gpu_kernel_library( "c64", "c128", ], - op = "neg", - tile_size = "256", - types = [], unroll_factors = "4", ) @@ -1280,19 +1275,22 @@ gpu_kernel_library( jit_types = [ "i8", "i16", + ], + op = "pow", + tile_size = "1024", + types = [ "f16", "f32", "f64", "i64", ], - op = "pow", - tile_size = "1024", - types = [], ) gpu_kernel_library( name = "gpu_reciprocal_kernels", - jit_types = [ + op = "reciprocal", + tile_size = "256", + types = [ "c64", "c128", "f16", @@ -1300,9 +1298,6 @@ gpu_kernel_library( "f64", "i64", ], - op = "reciprocal", - tile_size = "256", - types = [], unroll_factors = "4", ) @@ -1311,6 +1306,10 @@ gpu_kernel_library( jit_types = [ "i8", "i16", + ], + op = "sign", + tile_size = "256", + types = [ "f16", "f32", "f64", @@ -1319,9 +1318,6 @@ gpu_kernel_library( "c64", "c128", ], - op = "sign", - tile_size = "256", - types = [], unroll_factors = "4", ) @@ -1365,86 +1361,80 @@ gpu_kernel_library( gpu_kernel_library( name = "gpu_xdivy_kernels", - jit_types = [ + op = "xdivy", + tile_size = "1024", + types = [ "f16", "f32", "f64", "c64", "c128", ], - op = "xdivy", - tile_size = "1024", - types = [], unroll_factors = "4", ) # Logarithmic and exponential kernels gpu_kernel_library( name = "gpu_exp_kernels", - jit_types = [ + op = "exp", + tile_size = "256", + types = [ "f16", "f32", "f64", "c64", "c128", ], - op = "exp", - tile_size = "256", - types = [], unroll_factors = "4", ) gpu_kernel_library( name = "gpu_expm1_kernels", - jit_types = [ + op = "expm1", + tile_size = "256", + types = [ "f16", "f32", "f64", ], - op = "expm1", - tile_size = "256", - types = [], unroll_factors = "4", ) gpu_kernel_library( name = "gpu_log_kernels", - jit_types = [ + op = "log", + tile_size = "256", + types = [ "f16", "f32", "f64", ], - op = "log", - tile_size = "256", - types = [], unroll_factors = "4", ) gpu_kernel_library( name = "gpu_log1p_kernels", - jit_types = [ + op = "log1p", + tile_size = "256", + types = [ "f16", "f32", "f64", ], - op = "log1p", - tile_size = "256", - types = [], unroll_factors = "4", ) gpu_kernel_library( name = "gpu_xlogy_kernels", - jit_types = [ + op = "xlogy", + tile_size = "1024", + types = [ "f16", "f32", "f64", "c64", "c128", ], - op = "xlogy", - tile_size = "1024", - types = [], unroll_factors = "4", # For complex XlogyOp kernels, we don't use unrolling, it would only cause # slowdowns. @@ -1456,16 +1446,15 @@ gpu_kernel_library( gpu_kernel_library( name = "gpu_xlog1py_kernels", - jit_types = [ + op = "xlog1py", + tile_size = "1024", + types = [ "f16", "f32", "f64", "c64", "c128", ], - op = "xlog1py", - tile_size = "1024", - types = [], unroll_factors = "4", # For complex Xlog1pyOp kernels, we don't use unrolling, it would only cause # slowdowns. @@ -1479,27 +1468,25 @@ gpu_kernel_library( gpu_kernel_library( name = "gpu_sqrt_kernels", - jit_types = [ + op = "sqrt", + tile_size = "256", + types = [ "f16", "f32", "f64", ], - op = "sqrt", - tile_size = "256", - types = [], unroll_factors = "4", ) gpu_kernel_library( name = "gpu_rsqrt_kernels", - jit_types = [ + op = "rsqrt", + tile_size = "256", + types = [ "f16", "f32", "f64", ], - op = "rsqrt", - tile_size = "256", - types = [], unroll_factors = "4", ) @@ -1512,28 +1499,28 @@ gpu_kernel_library( "ui16", "ui32", "ui64", + ], + op = "square", + tile_size = "1024", + types = [ "f16", "f32", "f64", "i64", ], - op = "square", - tile_size = "1024", - types = [], unroll_factors = "4", ) gpu_kernel_library( name = "gpu_squared_difference_kernels", - jit_types = [ + op = "squared_difference", + tile_size = "1024", + types = [ "f16", "f32", "f64", "i64", ], - op = "squared_difference", - tile_size = "1024", - types = [], unroll_factors = "4", ) @@ -1541,77 +1528,74 @@ gpu_kernel_library( gpu_kernel_library( name = "gpu_bitwise_and_kernels", - jit_types = [ + op = "bitwise_and", + tile_size = "1024", + types = [ "i8", "i16", "i32", "i64", ], - op = "bitwise_and", - tile_size = "1024", - types = [], unroll_factors = "4", ) gpu_kernel_library( name = "gpu_bitwise_or_kernels", - jit_types = [ + op = "bitwise_or", + tile_size = "1024", + types = [ "i8", "i16", "i32", "i64", ], - op = "bitwise_or", - tile_size = "1024", - types = [], unroll_factors = "4", ) gpu_kernel_library( name = "gpu_bitwise_xor_kernels", - jit_types = [ + op = "bitwise_xor", + tile_size = "1024", + types = [ "i8", "i16", "i32", "i64", ], - op = "bitwise_xor", - tile_size = "1024", - types = [], unroll_factors = "4", ) gpu_kernel_library( name = "gpu_invert_kernels", - jit_types = [ + op = "invert", + tile_size = "1024", + types = [ "i8", "i16", "i32", "i64", ], - op = "invert", - tile_size = "1024", - types = [], unroll_factors = "4", ) gpu_kernel_library( name = "gpu_left_shift_kernels", - jit_types = [ + op = "left_shift", + tile_size = "1024", + types = [ "i8", "i16", "i32", "i64", ], - op = "left_shift", - tile_size = "1024", - types = [], unroll_factors = "4", ) gpu_kernel_library( name = "gpu_right_shift_kernels", - jit_types = [ + op = "right_shift", + tile_size = "1024", + types = [ "i8", "i16", "i32", @@ -1621,9 +1605,6 @@ gpu_kernel_library( "ui32", "ui64", ], - op = "right_shift", - tile_size = "1024", - types = [], unroll_factors = "4", ) @@ -1631,57 +1612,52 @@ gpu_kernel_library( gpu_kernel_library( name = "gpu_logical_not_kernels", - jit_types = ["i1"], op = "logical_not", tile_size = "256", - types = [], + types = ["i1"], ) gpu_kernel_library( name = "gpu_logical_and_kernels", - jit_types = [ - "i1", - ], op = "logical_and", tile_size = "1024", - types = [], + types = [ + "i1", + ], ) gpu_kernel_library( name = "gpu_logical_or_kernels", - jit_types = [ - "i1", - ], op = "logical_or", tile_size = "1024", - types = [], + types = [ + "i1", + ], ) # Erf kernels gpu_kernel_library( name = "gpu_erf_kernels", - jit_types = [ + op = "erf", + tile_size = "256", + types = [ "f16", "f32", "f64", ], - op = "erf", - tile_size = "256", - types = [], unroll_factors = "4", ) gpu_kernel_library( name = "gpu_erfc_kernels", - jit_types = [ + op = "erfc", + tile_size = "256", + types = [ "f16", "f32", "f64", ], - op = "erfc", - tile_size = "256", - types = [], unroll_factors = "4", ) @@ -1689,49 +1665,45 @@ gpu_kernel_library( gpu_kernel_library( name = "gpu_polygamma_kernels", - jit_types = [ + op = "polygamma", + tile_size = "256", + types = [ "f32", "f64", ], - op = "polygamma", - tile_size = "256", - types = [], ) gpu_kernel_library( name = "gpu_digamma_kernels", - jit_types = [ + op = "digamma", + tile_size = "256", + types = [ "f16", "f32", "f64", ], - op = "digamma", - tile_size = "256", - types = [], ) gpu_kernel_library( name = "gpu_lgamma_kernels", - jit_types = [ + op = "lgamma", + tile_size = "256", + types = [ "f16", "f32", "f64", ], - op = "lgamma", - tile_size = "256", - types = [], ) gpu_kernel_library( # The zeta kernels needs many registers so tile at 256. name = "gpu_zeta_kernels", - jit_types = [ + op = "zeta", + tile_size = "256", + types = [ "f32", "f64", ], - op = "zeta", - tile_size = "256", - types = [], # TODO(b/178388085): Enable unrolling after vectorization is fixed. # unroll_factors = "4", ) @@ -1758,64 +1730,61 @@ gpu_kernel_library( "ui16", "ui32", "ui64", + ], + op = "relu", + tile_size = "256", + types = [ "f16", "f32", "f64", ], - op = "relu", - tile_size = "256", - types = [], unroll_factors = "16B", ) gpu_kernel_library( name = "gpu_elu_kernels", - jit_types = [ + op = "elu", + tile_size = "256", + types = [ "f16", "f32", "f64", ], - op = "elu", - tile_size = "256", - types = [], ) gpu_kernel_library( name = "gpu_selu_kernels", - jit_types = [ + op = "selu", + tile_size = "256", + types = [ "f16", "f32", "f64", ], - op = "selu", - tile_size = "256", - types = [], ) gpu_kernel_library( name = "gpu_sigmoid_kernels", - jit_types = [ + op = "sigmoid", + tile_size = "256", + types = [ "f16", "f32", "f64", ], - op = "sigmoid", - tile_size = "256", - types = [], ) # Kernels that support all floating-point types. [ gpu_kernel_library( name = "gpu_" + op + "_kernels", - jit_types = [ + op = op, + tile_size = "256", + types = [ "f16", "f32", "f64", ], - op = op, - tile_size = "256", - types = [], unroll_factors = "4", ) for op in [ @@ -1867,6 +1836,11 @@ gpu_kernel_library( "ui16", "ui32", "ui64", + ], + max_supported_rank = 8, + op = "select_v2", + tile_size = "256", + types = [ "i1", "i32", "i64", @@ -1876,10 +1850,6 @@ gpu_kernel_library( "c64", "c128", ], - max_supported_rank = 8, - op = "select_v2", - tile_size = "256", - types = [], ) gpu_kernel_library( @@ -1891,6 +1861,10 @@ gpu_kernel_library( "ui16", "ui32", "ui64", + ], + op = "zeros_like", + tile_size = "1024", + types = [ "i1", "i64", "f16", @@ -1899,9 +1873,6 @@ gpu_kernel_library( "c64", "c128", ], - op = "zeros_like", - tile_size = "1024", - types = [], ) gpu_kernel_library( @@ -1913,6 +1884,10 @@ gpu_kernel_library( "ui16", "ui32", "ui64", + ], + op = "ones_like", + tile_size = "1024", + types = [ "i1", "i64", "f16", @@ -1921,18 +1896,14 @@ gpu_kernel_library( "c64", "c128", ], - op = "ones_like", - tile_size = "1024", - types = [], ) gpu_kernel_library( name = "gpu_next_after_kernels", - jit_types = [ + op = "next_after", + tile_size = "1024", + types = [ "f32", "f64", ], - op = "next_after", - tile_size = "1024", - types = [], ) diff --git a/tensorflow/python/kernel_tests/linalg/BUILD b/tensorflow/python/kernel_tests/linalg/BUILD index 42fd131137743f..ada2c64403c01b 100644 --- a/tensorflow/python/kernel_tests/linalg/BUILD +++ b/tensorflow/python/kernel_tests/linalg/BUILD @@ -271,7 +271,7 @@ cuda_py_strict_test( name = "linear_operator_circulant_test", size = "medium", srcs = ["linear_operator_circulant_test.py"], - shard_count = 32, + shard_count = 15, tags = [ "no_cuda11", # TODO(b/197522782): reenable test after fixing. "optonly", # times out, b/79171797 diff --git a/tensorflow/python/ops/BUILD b/tensorflow/python/ops/BUILD index 83dc031e86559e..1a8d02a84cd15f 100644 --- a/tensorflow/python/ops/BUILD +++ b/tensorflow/python/ops/BUILD @@ -3158,7 +3158,7 @@ py_strict_library( cuda_py_strict_test( name = "bitwise_ops_test", - size = "medium", + size = "small", srcs = ["bitwise_ops_test.py"], main = "bitwise_ops_test.py", python_version = "PY3", @@ -3503,7 +3503,7 @@ cuda_py_strict_test( cuda_py_strict_test( name = "math_grad_test", - size = "medium", + size = "small", srcs = ["math_grad_test.py"], main = "math_grad_test.py", python_version = "PY3", From a4d063f70c4f4a62ba683cd49368f26387cd44ae Mon Sep 17 00:00:00 2001 From: Junwhan Ahn Date: Wed, 12 Jul 2023 10:45:09 -0700 Subject: [PATCH 200/376] Fix a copy-and-paste error PiperOrigin-RevId: 547539747 --- tensorflow/compiler/xla/python/ifrt/sharding.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/compiler/xla/python/ifrt/sharding.h b/tensorflow/compiler/xla/python/ifrt/sharding.h index a03ce4ebda8e57..375cedc16a0a68 100644 --- a/tensorflow/compiler/xla/python/ifrt/sharding.h +++ b/tensorflow/compiler/xla/python/ifrt/sharding.h @@ -72,7 +72,7 @@ class Sharding : public llvm::RTTIExtends { DeviceList devices_; }; -std::ostream& operator<<(std::ostream& os, const Shape& shape); +std::ostream& operator<<(std::ostream& os, const Sharding& sharding); // Single-device sharding. // From 2e86aa79d38698d43e9aebdeab55b32ddcaa819c Mon Sep 17 00:00:00 2001 From: Shiqing Yan Date: Wed, 12 Jul 2023 10:58:25 -0700 Subject: [PATCH 201/376] Add CompatibilityStatusToString and StringToCompatibilityStatus util functions to GPUCompatibilityList. PiperOrigin-RevId: 547544036 --- .../compatibility/gpu_compatibility.cc | 34 ++++++++++++++----- .../compatibility/gpu_compatibility.h | 11 ++++++ .../compatibility/gpu_compatibility_test.cc | 30 ++++++++++++++++ 3 files changed, 67 insertions(+), 8 deletions(-) diff --git a/tensorflow/lite/experimental/acceleration/compatibility/gpu_compatibility.cc b/tensorflow/lite/experimental/acceleration/compatibility/gpu_compatibility.cc index 4546fe0d4ff2bb..a561992ab54ec2 100644 --- a/tensorflow/lite/experimental/acceleration/compatibility/gpu_compatibility.cc +++ b/tensorflow/lite/experimental/acceleration/compatibility/gpu_compatibility.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/lite/experimental/acceleration/compatibility/database_generated.h" #include "tensorflow/lite/experimental/acceleration/compatibility/devicedb.h" #include "tensorflow/lite/experimental/acceleration/compatibility/gpu_compatibility_binary.h" +#include "tensorflow/lite/experimental/acceleration/compatibility/variables.h" namespace tflite { namespace acceleration { @@ -111,14 +112,7 @@ gpu::CompatibilityStatus GPUCompatibilityList::GetStatus( CanonicalizeValues(&variables); if (!database_) return gpu::CompatibilityStatus::kUnknown; UpdateVariablesFromDatabase(&variables, *database_); - const std::string& status = variables[gpu::kStatus]; - if (status == gpu::kStatusSupported) { - return gpu::CompatibilityStatus::kSupported; - } else if (status == gpu::kStatusUnsupported) { - return gpu::CompatibilityStatus::kUnsupported; - } else { - return gpu::CompatibilityStatus::kUnknown; - } + return StringToCompatibilityStatus(variables[gpu::kStatus]); } TfLiteGpuDelegateOptionsV2 GPUCompatibilityList::GetBestOptionsFor( @@ -156,5 +150,29 @@ std::map GPUCompatibilityList::InfosToMap( return variables; } +// static +std::string GPUCompatibilityList::CompatibilityStatusToString( + gpu::CompatibilityStatus status) { + switch (status) { + case gpu::CompatibilityStatus::kSupported: + return gpu::kStatusSupported; + case gpu::CompatibilityStatus::kUnsupported: + return gpu::kStatusUnsupported; + case gpu::CompatibilityStatus::kUnknown: + return gpu::kStatusUnknown; + } +} + +// static +gpu::CompatibilityStatus GPUCompatibilityList::StringToCompatibilityStatus( + absl::string_view status) { + if (status == gpu::kStatusSupported) { + return gpu::CompatibilityStatus::kSupported; + } else if (status == gpu::kStatusUnsupported) { + return gpu::CompatibilityStatus::kUnsupported; + } + return gpu::CompatibilityStatus::kUnknown; +} + } // namespace acceleration } // namespace tflite diff --git a/tensorflow/lite/experimental/acceleration/compatibility/gpu_compatibility.h b/tensorflow/lite/experimental/acceleration/compatibility/gpu_compatibility.h index de2b66c5d7f2b8..59f73c2c9a7759 100644 --- a/tensorflow/lite/experimental/acceleration/compatibility/gpu_compatibility.h +++ b/tensorflow/lite/experimental/acceleration/compatibility/gpu_compatibility.h @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" #include "tensorflow/lite/delegates/gpu/common/gpu_info.h" #include "tensorflow/lite/delegates/gpu/delegate_options.h" #include "tensorflow/lite/experimental/acceleration/compatibility/android_info.h" @@ -113,6 +114,16 @@ class GPUCompatibilityList { const AndroidInfo& android_info, const ::tflite::gpu::GpuInfo& gpu_info) const; + // Converts the compatibility status enum value to the corresponding status + // string. + static std::string CompatibilityStatusToString( + gpu::CompatibilityStatus status); + + // Converts the status string to the corresponding compatibility status enum + // value. + static gpu::CompatibilityStatus StringToCompatibilityStatus( + absl::string_view status); + protected: const DeviceDatabase* database_; diff --git a/tensorflow/lite/experimental/acceleration/compatibility/gpu_compatibility_test.cc b/tensorflow/lite/experimental/acceleration/compatibility/gpu_compatibility_test.cc index c7427cea792dde..d7d4538c94c718 100644 --- a/tensorflow/lite/experimental/acceleration/compatibility/gpu_compatibility_test.cc +++ b/tensorflow/lite/experimental/acceleration/compatibility/gpu_compatibility_test.cc @@ -152,4 +152,34 @@ TEST(GPUCompatibility, CreationWithNullCompatibilityListFlatbuffer) { EXPECT_EQ(list, nullptr); } +TEST(GPUCompatibility, ConvertCompatibilityStatusToStringCorrectly) { + EXPECT_EQ( + tflite::acceleration::GPUCompatibilityList::CompatibilityStatusToString( + tflite::acceleration::gpu::CompatibilityStatus::kSupported), + tflite::acceleration::gpu::kStatusSupported); + EXPECT_EQ( + tflite::acceleration::GPUCompatibilityList::CompatibilityStatusToString( + tflite::acceleration::gpu::CompatibilityStatus::kUnsupported), + tflite::acceleration::gpu::kStatusUnsupported); + EXPECT_EQ( + tflite::acceleration::GPUCompatibilityList::CompatibilityStatusToString( + tflite::acceleration::gpu::CompatibilityStatus::kUnknown), + tflite::acceleration::gpu::kStatusUnknown); +} + +TEST(GPUCompatibility, ConvertStringToCompatibilityStatusCorrectly) { + EXPECT_EQ( + tflite::acceleration::GPUCompatibilityList::StringToCompatibilityStatus( + tflite::acceleration::gpu::kStatusSupported), + tflite::acceleration::gpu::CompatibilityStatus::kSupported); + EXPECT_EQ( + tflite::acceleration::GPUCompatibilityList::StringToCompatibilityStatus( + tflite::acceleration::gpu::kStatusUnsupported), + tflite::acceleration::gpu::CompatibilityStatus::kUnsupported); + EXPECT_EQ( + tflite::acceleration::GPUCompatibilityList::StringToCompatibilityStatus( + tflite::acceleration::gpu::kStatusUnknown), + tflite::acceleration::gpu::CompatibilityStatus::kUnknown); +} + } // namespace From 24f323096dd48615ca59c208708ca2cb8d88685f Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Wed, 12 Jul 2023 10:58:27 -0700 Subject: [PATCH 202/376] [XLA:GPU] Support Conv-Bias-Relu6/LeakyRelu fusion in XLA using cuDNN runtime fusion This is https://github.com/tensorflow/tensorflow/pull/60377 authored by https://github.com/Young768, with some modifications: - Fix formatting and remove trailing whitespace. - Tests in cudnn_fused_conv_rewriter_test.cc were checking GetCudaComputeCapability(), i.e. the compute capability of the installed GPU. This isn't necessary: We're not *running* the convs, we're just checking that they get pattern-matched correctly. Change the tests so that they explicitly pass a compute capability to the pass and then unconditionally check the pattern-matching. - The relu6 and leaky-relu matchers (and indeed the existing elu matcher) were incorrect. They created a gep_pattern variable and used it three times. The intent was that these three instances of the matcher would all match the same gep instruction, but that's not how the matchers work. This has been reworked to match the intent of the code. - The existing elu matcher bails if the conv has a side-input, saying that the cudnn runtime fusion engine does not do well in this case. We saw the same thing for relu6, we get pathologically-slow convs when we have side-inputs. So exclude these from relu6 and leaky-relu fusion. PiperOrigin-RevId: 547544055 --- .../transforms/lmhlo_gpu_to_gpu_runtime.cc | 1 + .../gpu/transforms/tests/lmhlo_gpu_conv.mlir | 1 + .../transforms/tests/outline_cuda_graphs.mlir | 1 + .../xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.td | 3 +- .../lhlo_gpu/IR/lhlo_gpu_ops_enums.td | 3 +- .../tests/Dialect/lhlo_gpu/lhlo_gpu_ops.mlir | 1 + tensorflow/compiler/xla/service/gpu/BUILD | 1 + .../xla/service/gpu/backend_configs.proto | 9 + .../xla/service/gpu/conv_algorithm_picker.cc | 8 +- .../service/gpu/cudnn_fused_conv_rewriter.cc | 187 ++++++++++++--- .../gpu/cudnn_fused_conv_rewriter_test.cc | 225 ++++++++++++++++++ .../xla/service/gpu/gpu_conv_runner.cc | 9 +- .../xla/service/gpu/gpu_conv_runner.h | 1 + .../xla/service/gpu/ir_emitter_unnested.cc | 2 + .../compiler/xla/service/gpu/runtime/conv.cc | 27 ++- .../compiler/xla/service/hlo_graph_dumper.cc | 3 + ...aot_compile_test_autotune_results.prototxt | 2 +- .../compiler/xla/stream_executor/dnn.cc | 2 + .../mhlo_to_hlo/attribute_exporter.cc | 2 + .../mhlo_to_lhlo_with_xla.cc | 4 + 20 files changed, 440 insertions(+), 52 deletions(-) diff --git a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/lmhlo_gpu_to_gpu_runtime.cc b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/lmhlo_gpu_to_gpu_runtime.cc index cb07346409e8be..1f43ff28261b0a 100644 --- a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/lmhlo_gpu_to_gpu_runtime.cc +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/lmhlo_gpu_to_gpu_runtime.cc @@ -372,6 +372,7 @@ class ConvOpLowering : public OpRewritePattern { if (auto fused = dyn_cast(op.getOperation())) { call->setAttr(b.getStringAttr("activation_mode"), fused.getActivationModeAttr()); + set_attr("leakyrelu_alpha", fused.getLeakyreluAlphaAttr()); } // Copy attributes specific for fused convolutions with side input. diff --git a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/lmhlo_gpu_conv.mlir b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/lmhlo_gpu_conv.mlir index cd52e0ae826194..59d6ca10a9bdca 100644 --- a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/lmhlo_gpu_conv.mlir +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/lmhlo_gpu_conv.mlir @@ -212,6 +212,7 @@ func.func @conv_forward_fused(%input: memref<8x5x5x1xf32, #map1>, reverse = [0, 0] } { activation_mode = #lmhlo_gpu, + leakyrelu_alpha = 0.0 : f64, backend_config = #lmhlo_gpu.convolution_backend_config< algorithm = 11, is_cudnn_frontend = true, diff --git a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/outline_cuda_graphs.mlir b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/outline_cuda_graphs.mlir index 3cdae0c117489d..ce2a7227b97a89 100644 --- a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/outline_cuda_graphs.mlir +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/outline_cuda_graphs.mlir @@ -412,6 +412,7 @@ module attributes {gpu.container_module} { reverse = [0, 0] } { activation_mode = #lmhlo_gpu, + leakyrelu_alpha = 0.0 : f64, backend_config = #lmhlo_gpu.convolution_backend_config< algorithm = -1, is_cudnn_frontend = true, diff --git a/tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.td b/tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.td index 72fd56b2a15ed5..7e61139b761477 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.td +++ b/tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.td @@ -100,7 +100,8 @@ def LHLOGPU_ConvForwardFusedOp : LHLOGPU_ConvBaseOp<"conv_forward_fused"> { Arg:$output, Arg:$scratch), GpuConvolutionAttributes<(ins - ActivationAttr:$activation_mode)>.attributes); + ActivationAttr:$activation_mode, + F64Attr:$leakyrelu_alpha)>.attributes); } // output = activation(result_scale * conv(input, filter) + diff --git a/tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops_enums.td b/tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops_enums.td index 2a527ac553779f..1e151212a3718e 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops_enums.td +++ b/tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops_enums.td @@ -30,12 +30,13 @@ def ActivationModeRelu6 : I32EnumAttrCase<"Relu6", 4>; def ActivationModeReluX : I32EnumAttrCase<"ReluX", 5>; def ActivationModeBandPass : I32EnumAttrCase<"BandPass", 6>; def ActivationModeElu: I32EnumAttrCase<"Elu", 7>; +def ActivationModeLeakyRelu: I32EnumAttrCase<"LeakyRelu", 8>; def Activation: I32EnumAttr<"Activation", "Activation applied with fused convolution", [ActivationModeNone, ActivationModeSigmoid, ActivationModeTanh, ActivationModeRelu, ActivationModeRelu6, ActivationModeReluX, - ActivationModeBandPass, ActivationModeElu]> { + ActivationModeBandPass, ActivationModeElu, ActivationModeLeakyRelu]> { let genSpecializedAttr = 0; let cppNamespace = "::mlir::lmhlo_gpu"; } diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/lhlo_gpu/lhlo_gpu_ops.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/lhlo_gpu/lhlo_gpu_ops.mlir index d925e6199d1be2..ca4b2c0be2a117 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/lhlo_gpu/lhlo_gpu_ops.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/lhlo_gpu/lhlo_gpu_ops.mlir @@ -141,6 +141,7 @@ func.func @conv_fused(%input : memref<1x17x9x9xf16>, %filter : memref<3x3x17x32x dim_numbers = [b, f, 0, 1]x[0, 1, i, o]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} { activation_mode = #lmhlo_gpu, + leakyrelu_alpha = 0.0 : f64, backend_config = #lmhlo_gpu.convolution_backend_config< algorithm = 0, tensor_ops_enabled = true, diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 4e654f9bf38446..a3a5b5c0d3fa78 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -3167,6 +3167,7 @@ xla_cc_test( xla_cc_test( name = "cudnn_fused_conv_rewriter_test", srcs = ["cudnn_fused_conv_rewriter_test.cc"], + shard_count = 10, tags = [ "gpu", "no_oss", diff --git a/tensorflow/compiler/xla/service/gpu/backend_configs.proto b/tensorflow/compiler/xla/service/gpu/backend_configs.proto index ffabfd56d3bb93..cd2f21a9eb8656 100644 --- a/tensorflow/compiler/xla/service/gpu/backend_configs.proto +++ b/tensorflow/compiler/xla/service/gpu/backend_configs.proto @@ -39,6 +39,15 @@ message CudnnConvBackendConfig { // is provided, this field must be 0. double side_input_scale = 5; + // The negative slope coefficient alpha for leaky_relu activation, used only + // when activation_mode is kLeakyRelu. + // + // leakyrelu(x) is defined as x > 0 ? x : alpha * x. + // + // Since this is a proto3 proto, leakyrelu_alpha is 0 if not specified (in + // which case the leakyrelu activation is equivalent to relu). + double leakyrelu_alpha = 8; + // If the filter (and bias, if present) have been reordered, set this flag. // It's placed into a `oneof` block to skip the serialization when not set. oneof filter_and_bias_reordering_oneof { diff --git a/tensorflow/compiler/xla/service/gpu/conv_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/conv_algorithm_picker.cc index b263969a5ee2a5..46e87ec03531e8 100644 --- a/tensorflow/compiler/xla/service/gpu/conv_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/conv_algorithm_picker.cc @@ -142,10 +142,10 @@ StatusOr> GetAlgorithms( BiasTypeForInputType(input_type), output_type, /* conv_input_scale = */ config.conv_result_scale, /* side_input_scale = */ config.fusion->side_input_scale, - /* leakyrelu_alpha = */ 0.0, stream, config.input_descriptor, - config.filter_descriptor, config.bias_descriptor, - config.output_descriptor, config.conv_desc, use_fallback, - config.fusion->mode, numeric_options, &runners)); + /* leakyrelu_alpha = */ config.fusion->leakyrelu_alpha, stream, + config.input_descriptor, config.filter_descriptor, + config.bias_descriptor, config.output_descriptor, config.conv_desc, + use_fallback, config.fusion->mode, numeric_options, &runners)); for (auto& runner : runners) { TF_ASSIGN_OR_RETURN( auto runner_cache, diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc index 1041dd4de5f5b9..6fa52445f637fa 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc @@ -59,17 +59,20 @@ bool IsConvDepthwise(const HloInstruction* instr) { return input_feature_count == feature_group_count; } +// We don't want to upgrade depthwise convolutions to ConvBiasActivation, +// because the fused CUDNN functions are slower for some of those. bool IsNonDepthwiseConvCustomCall(const HloInstruction* instr) { return IsConvCustomCall(instr) && !IsConvDepthwise(instr); } -bool IsExponentialMinusOne(const HloInstruction* instr) { - return instr->opcode() == HloOpcode::kExpm1; -} - -bool HasThreeUsers(const HloInstruction* instr) { - int64_t user_count = instr->user_count(); - return user_count == 3; +// elu, relu6, and leaky-relu activations are supported in cudnn via the +// "runtime fusion" engine, which JIT compiles C++ code. This can be slow to +// compile, so we guard it with a debug option. Also nvidia currently +// recommends that we enable this only on Ampere+. +bool ShouldUseCudnnRuntimeFusion(const DebugOptions& debug_opts, + se::CudaComputeCapability cc) { + return debug_opts.xla_gpu_use_runtime_fusion() && + cc.IsAtLeast(se::CudaComputeCapability::AMPERE); } // Can instr be converted to type `dst_ty` without losing any precision? For @@ -247,8 +250,6 @@ StatusOr FuseConvAlpha(HloComputation* comp) { HloInstruction* gte = nullptr; HloInstruction* alpha = nullptr; - // We don't want to upgrade depthwise convolutions to ConvBiasActivation, - // because the fused CUDNN functions are slower for some of those. auto pattern = m::MultiplyAnyOrder( m::GetTupleElement( >e, m::Op(&conv).WithPredicate(IsNonDepthwiseConvCustomCall), 0) @@ -298,8 +299,6 @@ StatusOr FuseBiasOrSideInput(HloComputation* comp) { HloInstruction* gte = nullptr; HloInstruction* addend = nullptr; - // We don't want to upgrade depthwise convolutions to ConvBiasActivation, - // because the fused CUDNN functions are slower for some of those. auto pattern = m::AddAnyOrder( m::GetTupleElement(>e, m::Op(&conv) @@ -510,43 +509,42 @@ StatusOr FuseSideInputAlpha(HloComputation* comp) { } StatusOr FuseElu(HloComputation* comp, se::CudaComputeCapability cc) { + if (!ShouldUseCudnnRuntimeFusion(comp->parent()->config().debug_options(), + cc)) { + return false; + } + bool changed = false; for (HloInstruction* instr : comp->MakeInstructionPostOrder()) { - const DebugOptions& debug_options = - instr->GetModule()->config().debug_options(); - if (!debug_options.xla_gpu_use_runtime_fusion() || - !cc.IsAtLeast(se::CudaComputeCapability::AMPERE)) { - return false; - } - - HloInstruction* gte; + HloInstruction *gte1, *gte2, *gte3; HloInstruction* conv; HloInstruction* expm1; - // In Elu computation, the GetTupleElement node will have three users: - // Compare, ExponentialMinusOnem, and Select. - // We don't want to upgrade depthwise convolutions to ConvBiasActivation, - // because the fused CUDNN functions are slower for some of those. - auto gte_pattern = - m::GetTupleElement(>e, - m::Op(&conv) - .WithPredicate(IsNonDepthwiseConvCustomCall) - .WithOneUse()) - .WithElementType(F16) - .WithPredicate(HasThreeUsers); if (!Match(instr, - m::Select(m::Compare(gte_pattern, + m::Select(m::Compare(m::GetTupleElement(>e1, m::Op()), m::Broadcast(m::ConstantEffectiveScalar(0))) .WithComparisonDirection(ComparisonDirection::kGt) .WithOneUse(), - gte_pattern, + m::GetTupleElement( + >e2, + m::Op(&conv) + .WithPredicate(IsNonDepthwiseConvCustomCall) + .WithOneUse(), + /*tuple_index=*/0) + // TODO(jlebar): Why only fp16? + .WithElementType(F16), m::Op(&expm1) - .WithPredicate(IsExponentialMinusOne) - .WithOperand(0, gte_pattern) + .WithOpcode(HloOpcode::kExpm1) + .WithOperand(0, m::GetTupleElement(>e3, m::Op())) .WithOneUse()))) { continue; } + // The three GTEs should be the same, and these should be the only uses. + if (gte1 != gte2 || gte2 != gte3 || gte1->user_count() != 3) { + continue; + } + // In some cases, the XLA optimizes the inputs of the convolution by // moving and broadcasting the bias to the side input, e.g., when the input // spatial dimensions are all ones and filter spatial dimentsions are all @@ -584,7 +582,7 @@ StatusOr FuseElu(HloComputation* comp, se::CudaComputeCapability cc) { TF_ASSIGN_OR_RETURN(conv, EnsureIsConvBiasActivation(conv)); config.set_activation_mode(se::dnn::kElu); TF_RETURN_IF_ERROR(conv->set_backend_config(config)); - TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, gte)); + TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, gte1)); changed = true; } return changed; @@ -595,8 +593,6 @@ StatusOr FuseRelu(HloComputation* comp) { for (HloInstruction* instr : comp->MakeInstructionPostOrder()) { HloInstruction* gte; HloInstruction* conv; - // We don't want to upgrade depthwise convolutions to ConvBiasActivation, - // because the fused CUDNN functions are slower for some of those. if (!Match(instr, m::MaximumAnyOrder( m::Broadcast(m::ConstantEffectiveScalar(0)), @@ -627,6 +623,115 @@ StatusOr FuseRelu(HloComputation* comp) { return changed; } +StatusOr FuseRelu6(HloComputation* comp, se::CudaComputeCapability cc) { + if (!ShouldUseCudnnRuntimeFusion(comp->parent()->config().debug_options(), + cc)) { + return false; + } + + bool changed = false; + for (HloInstruction* instr : comp->MakeInstructionPostOrder()) { + HloInstruction *gte, *conv; + if (!Match( + instr, + m::Clamp(m::Broadcast(m::ConstantEffectiveScalar(0)), + m::GetTupleElement( + >e, m::Op(&conv) + .WithPredicate(IsNonDepthwiseConvCustomCall) + .WithOneUse()) + // TODO(jlebar): Why only fp16? + .WithElementType(F16) + .WithOneUse(), + m::Broadcast(m::ConstantEffectiveScalar(6))))) { + continue; + } + TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig config, + conv->backend_config()); + if (config.activation_mode() != se::dnn::kNone) { + continue; + } + + // cudnn runtime fusions seem to be very slow when a side input is present. + // TODO(kaixih@nvidia): remove this check when cuDNN fixes it. + if (conv->operands().size() > 3) { + continue; + } + + if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] { + return absl::StrCat("FuseRelu6: ", conv->ToString()); + })) { + continue; + } + TF_ASSIGN_OR_RETURN(conv, EnsureIsConvBiasActivation(conv)); + config.set_activation_mode(se::dnn::kRelu6); + TF_RETURN_IF_ERROR(conv->set_backend_config(config)); + TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, gte)); + changed = true; + } + return changed; +} + +StatusOr FuseLeakyRelu(HloComputation* comp, + se::CudaComputeCapability cc) { + if (!ShouldUseCudnnRuntimeFusion(comp->parent()->config().debug_options(), + cc)) { + return false; + } + + bool changed = false; + for (HloInstruction* instr : comp->MakeInstructionPostOrder()) { + HloInstruction *gte1, *gte2, *gte3, *conv, *alpha; + if (!Match(instr, + m::Select( + m::Compare(m::GetTupleElement(>e1, m::Op()), + m::Broadcast(m::ConstantEffectiveScalar(0))) + .WithComparisonDirection(ComparisonDirection::kGt) + .WithOneUse(), + m::GetTupleElement( + >e2, m::Op(&conv) + .WithPredicate(IsNonDepthwiseConvCustomCall) + .WithOneUse()) + // TODO(jlebar): Why only fp16? + .WithElementType(F16), + m::Multiply(m::GetTupleElement(>e3, m::Op()), + m::Broadcast(m::ConstantEffectiveScalar(&alpha))) + .WithOneUse()))) { + continue; + } + + // The three GTEs should be the same, and these should be the only uses. + if (gte1 != gte2 || gte2 != gte3 || gte1->user_count() != 3) { + continue; + } + + TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig config, + conv->backend_config()); + if (config.activation_mode() != se::dnn::kNone) { + continue; + } + + // cudnn runtime fusions seem to be very slow when a side input is present. + // TODO(kaixih@nvidia): remove this check when cuDNN fixes it. + if (conv->operands().size() > 3) { + continue; + } + + if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] { + return absl::StrCat("FuseLeakyRelu: ", conv->ToString()); + })) { + continue; + } + TF_ASSIGN_OR_RETURN(conv, EnsureIsConvBiasActivation(conv)); + config.set_activation_mode(se::dnn::kLeakyRelu); + TF_ASSIGN_OR_RETURN(Literal alpha_f64, alpha->literal().Convert(F64)); + config.set_leakyrelu_alpha(alpha_f64.GetFirstElement()); + TF_RETURN_IF_ERROR(conv->set_backend_config(config)); + TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, gte1)); + changed = true; + } + return changed; +} + StatusOr FuseConvertToF16(HloComputation* comp) { bool changed = false; for (HloInstruction* instr : comp->MakeInstructionPostOrder()) { @@ -934,6 +1039,10 @@ StatusOr CudnnFusedConvRewriter::Run( any_changed |= changed; TF_ASSIGN_OR_RETURN(changed, FuseElu(comp, compute_capability_)); any_changed |= changed; + TF_ASSIGN_OR_RETURN(changed, FuseRelu6(comp, compute_capability_)); + any_changed |= changed; + TF_ASSIGN_OR_RETURN(changed, FuseLeakyRelu(comp, compute_capability_)); + any_changed |= changed; TF_ASSIGN_OR_RETURN(changed, FuseConvertToF16(comp)); any_changed |= changed; @@ -953,6 +1062,10 @@ StatusOr CudnnFusedConvRewriter::Run( any_changed |= changed; TF_ASSIGN_OR_RETURN(changed, FuseElu(comp, compute_capability_)); any_changed |= changed; + TF_ASSIGN_OR_RETURN(changed, FuseRelu6(comp, compute_capability_)); + any_changed |= changed; + TF_ASSIGN_OR_RETURN(changed, FuseLeakyRelu(comp, compute_capability_)); + any_changed |= changed; // Check that we don't have any convs outputting integer types other than // s8 - cudnn does not support these. They should have been transformed to diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc index 0e6d84cd2960fd..42326417821860 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc @@ -260,6 +260,59 @@ TEST_F(CudnnFusedConvRewriterTest, DontFuseEluWithDepthwiseConv) { })"); } +TEST_F(CudnnFusedConvRewriterTest, TestRelu6) { + if (!GetCudaComputeCapability().IsAtLeast( + se::CudaComputeCapability::AMPERE)) { + GTEST_SKIP() << "Conv-Bias-Relu6 fusion is supported and recommended with " + "the Nvidia Ampere+ GPUs."; + } + // sum = conv(x, w) + bias + // clamp(0, sum, 6); + TestMatchWithAllTypes(R"( + HloModule Test + ENTRY Test { + zero = TYPE[] constant(0) + zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={} + six = TYPE[] constant(6) + sixes = TYPE[1,3,3,64] broadcast(six), dimensions={} + input = TYPE[1,3,3,64] parameter(0) + filter = TYPE[3,3,64,64] parameter(1) + bias = TYPE[64] parameter(2) + conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1 + broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3} + sum = TYPE[1,3,3,64] add(conv, broadcasted_bias) + ROOT relu6 = TYPE[1,3,3,64] clamp(zeros, sum, sixes) + })"); +} + +TEST_F(CudnnFusedConvRewriterTest, TestLeakyRelu) { + if (!GetCudaComputeCapability().IsAtLeast( + se::CudaComputeCapability::AMPERE)) { + GTEST_SKIP() + << "Conv-Bias-LeakyRelu fusion is supported and recommended with " + "the Nvidia Ampere+ GPUs."; + } + // sum = conv(x, w) + bias + // select(compare(sum, 0, GT), sum, multiply(sum, alpha)); + TestMatchWithAllTypes(R"( + HloModule Test + ENTRY Test { + zero = TYPE[] constant(0) + zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={} + alpha = TYPE[] constant(0.2) + alphas = TYPE[1,3,3,64] broadcast(alpha), dimensions={} + input = TYPE[1,3,3,64] parameter(0) + filter = TYPE[3,3,64,64] parameter(1) + bias = TYPE[64] parameter(2) + conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1 + broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3} + sum = TYPE[1,3,3,64] add(conv, broadcasted_bias) + cmp = pred[1,3,3,64] compare(sum, zeros), direction=GT + mul = TYPE[1,3,3,64] multiply(sum, alphas) + ROOT elu = TYPE[1,3,3,64] select(cmp, sum, mul) + })"); +} + TEST_F(CudnnFusedConvRewriterTest, TestSideInputOnly) { // max(0, conv(x, w) + side_input); TestMatchWithAllTypes(R"( @@ -1013,6 +1066,178 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseEluIfMultipleUses) { EXPECT_EQ(config.activation_mode(), se::dnn::kNone); } +TEST_F(CudnnFusedConvRewriterHloTest, FuseRelu6) { + const std::string module_str = R"( + HloModule Test + ENTRY Test { + inputs = f16[1,17,9,9] parameter(0) + filters = f16[3,3,17,32] parameter(1) + bias = f16[32] parameter(2) + bias_broadcast = f16[1,32,9,9] broadcast(bias), dimensions={1} + zero = f16[] constant(0) + zeros = f16[1,32,9,9] broadcast(zero), dimensions={} + sixes = f16[1,32,9,9] broadcast(f16[] constant(6)), dimensions={} + conv = f16[1,32,9,9] convolution(inputs, filters), + window={size=3x3 pad=1_1x1_1}, + dim_labels=bf01_01io->bf01 + sum = add(conv, bias_broadcast) + ROOT relu = clamp(zeros, sum, sixes) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); + + GpuConvRewriter rewriter; + TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); + // relu6 fusion is only enabled on Ampere+. + CudnnFusedConvRewriter fuser{se::CudaComputeCapability(8, 0)}; + TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); + SCOPED_TRACE(m->ToString()); + const HloInstruction* conv; + ASSERT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch( + m::GetTupleElement( + m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget}, + m::Parameter(0), m::Parameter(1), m::Parameter(2)), + 0) + .WithShape(F16, {1, 32, 9, 9}))); + TF_ASSERT_OK_AND_ASSIGN(auto config, + conv->backend_config()); + EXPECT_EQ(config.activation_mode(), se::dnn::kRelu6); +} + +TEST_F(CudnnFusedConvRewriterHloTest, DontFuseRelu6IfMultipleUses) { + const std::string module_str = R"( + HloModule Test + ENTRY Test { + inputs = f16[1,17,9,9] parameter(0) + filters = f16[3,3,17,32] parameter(1) + bias = f16[1,32,9,9] broadcast(f16[32] parameter(2)), dimensions={1} + zeros = f16[1,32,9,9] broadcast(f16[] constant(0)), dimensions={} + sixes = f16[1,32,9,9] broadcast(f16[] constant(6)), dimensions={} + conv = f16[1,32,9,9] convolution(inputs, filters), + window={size=3x3 pad=1_1x1_1}, + dim_labels=bf01_01io->bf01 + sum = add(conv, bias) + relu = clamp(zeros, sum, sixes) + not_relu = minimum(sum, zeros) + ROOT root = tuple(relu, not_relu) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); + + GpuConvRewriter rewriter; + TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); + CudnnFusedConvRewriter fuser{GetCudaComputeCapability()}; + TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); + + SCOPED_TRACE(m->ToString()); + const HloInstruction* conv; + ASSERT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::Tuple( + m::Clamp(m::Broadcast(m::ConstantEffectiveScalar(0)), + m::GetTupleElement( + m::CustomCall( + &conv, {kCudnnConvBiasActivationForwardCallTarget}, + m::Parameter(0), m::Parameter(1), m::Parameter(2)), + 0) + .WithShape(F16, {1, 32, 9, 9}), + m::Broadcast(m::ConstantEffectiveScalar(6))), + m::Minimum()))); + TF_ASSERT_OK_AND_ASSIGN(auto config, + conv->backend_config()); + EXPECT_EQ(config.activation_mode(), se::dnn::kNone); +} + +TEST_F(CudnnFusedConvRewriterHloTest, FuseLeakyRelu) { + const std::string module_str = R"( + HloModule Test + ENTRY Test { + inputs = f16[1,16,9,9] parameter(0) + filters = f16[3,3,16,32] parameter(1) + bias = f16[1,32,9,9] broadcast(f16[32] parameter(2)), dimensions={1} + zeros = f16[1,32,9,9] broadcast(f16[] constant(0)), dimensions={} + alphas = f16[1,32,9,9] broadcast(f16[] constant(0.2)), dimensions={} + conv = f16[1,32,9,9] convolution(inputs, filters), + window={size=3x3 pad=1_1x1_1}, + dim_labels=bf01_01io->bf01 + sum = add(conv, bias) + cmp = compare(sum, zeros), direction=GT + mul = multiply(sum, alphas) + ROOT leaky_relu = select(cmp, sum, mul) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); + + GpuConvRewriter rewriter; + TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); + // Leaky-relu fusion is only enabled on Ampere+. + CudnnFusedConvRewriter fuser{se::CudaComputeCapability(8, 0)}; + TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); + + SCOPED_TRACE(m->ToString()); + const HloInstruction* conv; + ASSERT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch( + m::GetTupleElement( + m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget}, + m::Parameter(0), m::Parameter(1), m::Parameter(2)), + 0) + .WithShape(F16, {1, 32, 9, 9}))); + TF_ASSERT_OK_AND_ASSIGN(auto config, + conv->backend_config()); + EXPECT_EQ(config.activation_mode(), se::dnn::kLeakyRelu); +} + +TEST_F(CudnnFusedConvRewriterHloTest, DontFuseLeakyReluIfMultipleUses) { + const std::string module_str = R"( + HloModule Test + ENTRY Test { + inputs = f16[1,16,9,9] parameter(0) + filters = f16[3,3,16,32] parameter(1) + bias = f16[1,32,9,9] broadcast(f16[32] parameter(2)), dimensions={1} + zeros = f16[1,32,9,9] broadcast(f16[] constant(0)), dimensions={} + alphas = f16[1,32,9,9] broadcast(f16[] constant(0.2)), dimensions={} + conv = f16[1,32,9,9] convolution(inputs, filters), + window={size=3x3 pad=1_1x1_1}, + dim_labels=bf01_01io->bf01 + sum = add(conv, bias) + cmp = compare(sum, zeros), direction=GT + mul = multiply(sum, alphas) + leaky_relu = select(cmp, sum, mul) + not_leaky_relu = minimum(sum, zeros) + ROOT root = tuple(leaky_relu, not_leaky_relu) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); + + GpuConvRewriter rewriter; + TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); + CudnnFusedConvRewriter fuser{GetCudaComputeCapability()}; + TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); + + SCOPED_TRACE(m->ToString()); + const HloInstruction* conv; + auto gte_pattern = + m::GetTupleElement( + m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget}, + m::Parameter(0), m::Parameter(1), m::Parameter(2)), + 0) + .WithShape(F16, {1, 32, 9, 9}); + ASSERT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::Tuple( + m::Select(m::Compare(gte_pattern, + m::Broadcast(m::ConstantEffectiveScalar(0))) + .WithComparisonDirection(ComparisonDirection::kGt) + .WithOneUse(), + gte_pattern, + m::Multiply(gte_pattern, + m::Broadcast(m::ConstantEffectiveScalar()))), + m::Minimum()))); + TF_ASSERT_OK_AND_ASSIGN(auto config, + conv->backend_config()); + EXPECT_EQ(config.activation_mode(), se::dnn::kNone); +} + TEST_F(CudnnFusedConvRewriterHloTest, DontFuseAlphaIfMultipleUsers) { const std::string module_str = R"( HloModule Test diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.cc b/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.cc index 3915160c065e7a..3e9ba809238283 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.cc @@ -133,7 +133,7 @@ Status RunGpuConvForwardActivation(const GpuConvParams& params, output_type, params.config->conv_result_scale, params.config->fusion->side_input_scale, - /* leakyrelu_alpha = */ 0.0, + params.config->fusion->leakyrelu_alpha, params.config->input_descriptor, params.config->filter_descriptor, params.config->bias_descriptor, @@ -295,15 +295,18 @@ StatusOr GetGpuConvConfig( } if (config.kind == CudnnConvKind::kForwardActivation) { - config.fusion.emplace(); - GpuConvConfig::FusionConfig& fusion = *config.fusion; if (!se::dnn::ActivationMode_IsValid(backend_config.activation_mode())) { return InternalError("Bad activation mode: %s", backend_config.ShortDebugString()); } + + GpuConvConfig::FusionConfig fusion; fusion.mode = static_cast(backend_config.activation_mode()); fusion.side_input_scale = backend_config.side_input_scale(); + fusion.leakyrelu_alpha = backend_config.leakyrelu_alpha(); + + config.fusion = fusion; } const Window& window = desc.window; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h b/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h index 0a1d22f7073d3a..0c27d10099121e 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h @@ -50,6 +50,7 @@ struct GpuConvConfig { struct FusionConfig { se::dnn::ActivationMode mode; double side_input_scale; + double leakyrelu_alpha = 0.0; }; PrimitiveType input_type; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index a8524de7f3bc46..5c00bb6ab9aebd 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -1086,6 +1086,8 @@ Status IrEmitterUnnested::EmitConvolutionThunk(mlir::Operation* op) { descriptor.kind = CudnnConvKind::kForwardActivation; fill_conv_descriptor(conv); TF_RETURN_IF_ERROR(set_activation_mode(conv)); + descriptor.backend_config.set_leakyrelu_alpha( + conv.getLeakyreluAlpha().convertToDouble()); } else if (auto conv = dyn_cast(op)) { descriptor.kind = CudnnConvKind::kForwardActivation; fill_conv_descriptor(conv); diff --git a/tensorflow/compiler/xla/service/gpu/runtime/conv.cc b/tensorflow/compiler/xla/service/gpu/runtime/conv.cc index c125a110463332..ccc8bbf79cb1ec 100644 --- a/tensorflow/compiler/xla/service/gpu/runtime/conv.cc +++ b/tensorflow/compiler/xla/service/gpu/runtime/conv.cc @@ -228,6 +228,10 @@ struct SideInputAttrs { double side_input_scale; }; +struct LeakyReluAlphaAttrs { + double leaky_relu_alpha; +}; + } // namespace static GpuConvDescriptor GetConvDescriptor( @@ -239,7 +243,8 @@ static GpuConvDescriptor GetConvDescriptor( ConvDimensionNumbers dims, Window w, ConvBackendConfig b, ConvAttrs attrs, // Conv-specific arguments and attributes std::optional fused = std::nullopt, - std::optional side_input = std::nullopt) { + std::optional side_input = std::nullopt, + std::optional leakyrelu_alpha = std::nullopt) { // Build a convolution descriptor from the attributes. GpuConvDescriptor descriptor; descriptor.kind = kind; @@ -313,6 +318,11 @@ static GpuConvDescriptor GetConvDescriptor( if (fused.has_value()) descriptor.backend_config.set_activation_mode(fused->activation_mode); + // Set attributes specific for fused convolutions with leaky_relu_alpha. + if (leakyrelu_alpha.has_value()) + descriptor.backend_config.set_leakyrelu_alpha( + leakyrelu_alpha->leaky_relu_alpha); + // Set attributes specific for convolutions with side input. if (side_input.has_value()) descriptor.backend_config.set_side_input_scale( @@ -344,7 +354,8 @@ static absl::Status ConvImpl( int64_t feature_group_count, double result_scale, // Optional attributes for fused convolutions. std::optional activation_mode = std::nullopt, - std::optional side_input_scale = std::nullopt) { + std::optional side_input_scale = std::nullopt, + std::optional leakyrelu_alpha = std::nullopt) { // Build config for optional attributes. std::optional fused_attrs = std::nullopt; if (activation_mode.has_value()) fused_attrs = {*activation_mode}; @@ -352,6 +363,9 @@ static absl::Status ConvImpl( std::optional side_input_attrs = std::nullopt; if (side_input_scale.has_value()) side_input_attrs = {*side_input_scale}; + std::optional leakyrelu_alpha_attrs = std::nullopt; + if (leakyrelu_alpha.has_value()) leakyrelu_alpha_attrs = {*leakyrelu_alpha}; + bool runtime_autotuning = false; if (backend_config.algorithm == -1) { // Set the algorithm back to the default algorithm to avoid error from @@ -369,7 +383,7 @@ static absl::Status ConvImpl( {window_strides, padding, lhs_dilation, rhs_dilation, window_reversal}, backend_config, {feature_group_count, result_scale}, fused_attrs, - side_input_attrs); + side_input_attrs, leakyrelu_alpha_attrs); TF_ASSIGN_OR_RETURN(GpuConvConfig conv_config, GetGpuConvConfig(descriptor, "")); @@ -495,6 +509,7 @@ XLA_RUNTIME_DEFINE_CUSTOM_CALL_TEMPLATE( ) .Value(std::optional()) // activation_mode .Value(std::optional()) // side_input_scale + .Value(std::optional()) // leaky_relu_alpha ); XLA_RUNTIME_DEFINE_CUSTOM_CALL( @@ -513,7 +528,8 @@ XLA_RUNTIME_DEFINE_CUSTOM_CALL( .Arg() // scratch ) .Attr("activation_mode") - .Value(std::optional()) // side_input_scale + .Value(std::optional()) // side_input_scale + .Attr("leakyrelu_alpha") // leaky_relu_alpha ); XLA_RUNTIME_DEFINE_CUSTOM_CALL( @@ -532,7 +548,8 @@ XLA_RUNTIME_DEFINE_CUSTOM_CALL( .Arg() // scratch ) .Attr("activation_mode") - .Attr("side_input_scale")); + .Attr("side_input_scale") + .Value(std::optional())); // leaky_relu_alpha //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 1d258ee5deeff1..5c4c4ba2a02566 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -1205,6 +1205,9 @@ ExtractCudnnConvBackendConfigProps(const gpu::CudnnConvBackendConfig& config) { if (config.side_input_scale() != 0 && config.side_input_scale() != 1) { props.emplace_back("side_input_scale", StrCat(config.side_input_scale())); } + if (config.activation_mode() == se::dnn::ActivationMode::kLeakyRelu) { + props.emplace_back("leakyrelu_alpha", StrCat(config.leakyrelu_alpha())); + } props.emplace_back( "activation_mode", se::dnn::ActivationModeString( diff --git a/tensorflow/compiler/xla/service/xla_aot_compile_test_autotune_results.prototxt b/tensorflow/compiler/xla/service/xla_aot_compile_test_autotune_results.prototxt index 7db99942aaec54..6b2ad5fab19bdf 100644 --- a/tensorflow/compiler/xla/service/xla_aot_compile_test_autotune_results.prototxt +++ b/tensorflow/compiler/xla/service/xla_aot_compile_test_autotune_results.prototxt @@ -23,7 +23,7 @@ results { } results { device: "sm_6.0 with 17071734784B RAM, 56 cores, 1480500KHz clock, 715000KHz mem clock, 4194304B L2$" - hlo: "(f32[1,1,2,3]{3,2,1,0}, u8[0]{0}) custom-call(f32[2,1,4,4]{3,2,1,0}, f32[2,1,3,2]{3,2,1,0}), window={size=2x3}, dim_labels=bf01_oi01->bf01, custom_call_target=\"__cudnn$convBackwardFilter\", backend_config={\"activation_mode\":\"kNone\",\"conv_result_scale\":1,\"side_input_scale\":0}" + hlo: "(f32[1,1,2,3]{3,2,1,0}, u8[0]{0}) custom-call(f32[2,1,4,4]{3,2,1,0}, f32[2,1,3,2]{3,2,1,0}), window={size=2x3}, dim_labels=bf01_oi01->bf01, custom_call_target=\"__cudnn$convBackwardFilter\", backend_config={\"activation_mode\":\"kNone\",\"conv_result_scale\":1,\"side_input_scale\":0,\"leakyrelu_alpha\":0}" result { run_time { nanos: 45408 diff --git a/tensorflow/compiler/xla/stream_executor/dnn.cc b/tensorflow/compiler/xla/stream_executor/dnn.cc index 03815ff24df596..592f354c5addf0 100644 --- a/tensorflow/compiler/xla/stream_executor/dnn.cc +++ b/tensorflow/compiler/xla/stream_executor/dnn.cc @@ -367,6 +367,8 @@ std::string ActivationModeString(ActivationMode mode) { return "bandpass"; case ActivationMode::kElu: return "elu"; + case ActivationMode::kLeakyRelu: + return "leakyrelu"; default: return absl::StrCat("unknown: ", static_cast(mode)); } diff --git a/tensorflow/compiler/xla/translate/mhlo_to_hlo/attribute_exporter.cc b/tensorflow/compiler/xla/translate/mhlo_to_hlo/attribute_exporter.cc index 2f5161e5f5c4a3..8796c517915fb1 100644 --- a/tensorflow/compiler/xla/translate/mhlo_to_hlo/attribute_exporter.cc +++ b/tensorflow/compiler/xla/translate/mhlo_to_hlo/attribute_exporter.cc @@ -76,6 +76,8 @@ StatusOr ConvertConvActivationMode( return stream_executor::dnn::kBandPass; case mlir::lmhlo_gpu::Activation::Elu: return stream_executor::dnn::kElu; + case mlir::lmhlo_gpu::Activation::LeakyRelu: + return stream_executor::dnn::kLeakyRelu; default: return InternalError("Unexpected activation"); } diff --git a/tensorflow/compiler/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc b/tensorflow/compiler/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc index 40babb062f280a..6b5622727e2d23 100644 --- a/tensorflow/compiler/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc +++ b/tensorflow/compiler/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc @@ -1104,6 +1104,8 @@ static tsl::StatusOr GetLHLOActivation( return mlir::lmhlo_gpu::Activation::BandPass; case stream_executor::dnn::kElu: return mlir::lmhlo_gpu::Activation::Elu; + case stream_executor::dnn::kLeakyRelu: + return mlir::lmhlo_gpu::Activation::LeakyRelu; default: return xla::InternalError("Unknown activation"); } @@ -1235,6 +1237,8 @@ tsl::StatusOr LhloDialectEmitter::EmitDnnConvolution( auto cnn_fused, CreateOpWithoutAttrs(custom_call)); TF_RETURN_IF_ERROR(set_activation(cnn_fused)); + cnn_fused.setLeakyreluAlphaAttr( + builder_.getF64FloatAttr(backend_config.leakyrelu_alpha())); return set_common_conv_attributes(cnn_fused); } From 05f947cb4a8393ca84be5cc879b4d030b965114a Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Wed, 12 Jul 2023 11:25:09 -0700 Subject: [PATCH 203/376] [xla:gpu] Remove async passes from default Gpu pipeline PiperOrigin-RevId: 547552220 --- .../mlir/backends/gpu/transforms/passes.cc | 4 --- .../transforms/compilation_pipeline_gpu.cc | 28 +++++++++++-------- .../transforms/compilation_pipeline_gpu.h | 5 ++-- .../compiler/xla/runtime/custom_call_test.cc | 3 +- 4 files changed, 20 insertions(+), 20 deletions(-) diff --git a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.cc b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.cc index d0c4861dbf3cac..54f91186b8be8f 100644 --- a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.cc +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.cc @@ -15,9 +15,6 @@ limitations under the License. #include "tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.h" -#include -#include - #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project @@ -34,7 +31,6 @@ void populateXlaGpuRuntimePasses(mlir::OpPassManager& pm, // Clean up IR before converting it to the runtime operations. pm.addPass(createCSEPass()); - pm.addPass(createCanonicalizerPass()); // Convert global memrefs corresponding to constant arguments. pm.addPass(createConvertMemrefGetGlobalToArgPass()); diff --git a/tensorflow/compiler/xla/mlir/runtime/transforms/compilation_pipeline_gpu.cc b/tensorflow/compiler/xla/mlir/runtime/transforms/compilation_pipeline_gpu.cc index a110e81c681531..871fbb41aa0214 100644 --- a/tensorflow/compiler/xla/mlir/runtime/transforms/compilation_pipeline_gpu.cc +++ b/tensorflow/compiler/xla/mlir/runtime/transforms/compilation_pipeline_gpu.cc @@ -28,7 +28,6 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project #include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" // from @llvm-project #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project @@ -63,20 +62,24 @@ void RegisterTestlibDialect(DialectRegistry& dialects) { } static void CreateDefaultXlaGpuRuntimeCompilationPipeline( - mlir::OpPassManager& pm, const CompilationPipelineOptions& opts) { + mlir::OpPassManager& pm, const CompilationPipelineOptions& opts, + bool add_async_passes) { pm.addPass(mlir::createConvertSCFToCFPass()); - pm.addPass(mlir::createAsyncFuncToAsyncRuntimePass()); + + if (add_async_passes) pm.addPass(mlir::createAsyncFuncToAsyncRuntimePass()); // Export functions to the XLA runtime. pm.addPass(CreateExportRuntimeFunctionsPass()); pm.addPass(CreateConvertCustomCallsPass()); pm.addPass(CreateConvertAssertsPass()); - // Lower from high level async operations to async runtime. - pm.addPass(mlir::createAsyncToAsyncRuntimePass()); + if (add_async_passes) { + // Lower from high level async operations to async runtime. + pm.addPass(mlir::createAsyncToAsyncRuntimePass()); - // Add async.runtime reference counting operations. - pm.addPass(mlir::createAsyncRuntimePolicyBasedRefCountingPass()); + // Add async.runtime reference counting operations. + pm.addPass(mlir::createAsyncRuntimePolicyBasedRefCountingPass()); + } // Convert runtime operations and custom calls to LLVM dialect. ConvertRuntimeToLLvmOpts rt_to_llvm_opts = { @@ -86,7 +89,7 @@ static void CreateDefaultXlaGpuRuntimeCompilationPipeline( pm.addPass(CreateConvertRuntimeToLLVMPass(std::move(rt_to_llvm_opts))); // Convert async dialect to LLVM once everything else is in the LLVM dialect. - pm.addPass(mlir::createConvertAsyncToLLVMPass()); + if (add_async_passes) pm.addPass(mlir::createConvertAsyncToLLVMPass()); // Convert everything else to LLVM dialect. pm.addPass(mlir::createFinalizeMemRefToLLVMConversionPass()); @@ -94,13 +97,14 @@ static void CreateDefaultXlaGpuRuntimeCompilationPipeline( pm.addPass(mlir::createReconcileUnrealizedCastsPass()); // Clean up IR before passing it to LLVM. - pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(mlir::createCSEPass()); } void CreateDefaultXlaGpuRuntimeCompilationPipeline( - PassManager& passes, const CompilationPipelineOptions& opts) { - CreateDefaultXlaGpuRuntimeCompilationPipeline(*passes, opts); + PassManager& passes, const CompilationPipelineOptions& opts, + bool add_async_passes) { + CreateDefaultXlaGpuRuntimeCompilationPipeline(*passes, opts, + add_async_passes); } void AppendXlaGpuDialectRegistry(mlir::MLIRContext& context) { @@ -111,7 +115,7 @@ void AppendXlaGpuDialectRegistry(mlir::MLIRContext& context) { static void CreateDefaultGpuPipeline(mlir::OpPassManager& pm) { CompilationPipelineOptions copts; - CreateDefaultXlaGpuRuntimeCompilationPipeline(pm, copts); + CreateDefaultXlaGpuRuntimeCompilationPipeline(pm, copts, false); } static mlir::PassPipelineRegistration<> kXlaRuntimePipeline( diff --git a/tensorflow/compiler/xla/mlir/runtime/transforms/compilation_pipeline_gpu.h b/tensorflow/compiler/xla/mlir/runtime/transforms/compilation_pipeline_gpu.h index 5f78e16fbf3719..4bf1bc7c2d4e66 100644 --- a/tensorflow/compiler/xla/mlir/runtime/transforms/compilation_pipeline_gpu.h +++ b/tensorflow/compiler/xla/mlir/runtime/transforms/compilation_pipeline_gpu.h @@ -16,8 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_MLIR_RUNTIME_TRANSFORMS_COMPILATION_PIPELINE_GPU_H_ #define TENSORFLOW_COMPILER_XLA_MLIR_RUNTIME_TRANSFORMS_COMPILATION_PIPELINE_GPU_H_ -#include - #include "tensorflow/compiler/xla/mlir/runtime/transforms/compilation_pipeline_options.h" #include "tensorflow/compiler/xla/runtime/compiler.h" @@ -42,7 +40,8 @@ void RegisterTestlibDialect(DialectRegistry& dialects); // it is expected that all end users will construct their own compilation // pipelines from the available XLA and MLIR passes. void CreateDefaultXlaGpuRuntimeCompilationPipeline( - PassManager& passes, const CompilationPipelineOptions& opts); + PassManager& passes, const CompilationPipelineOptions& opts, + bool add_async_passes = false); void AppendXlaGpuDialectRegistry(mlir::MLIRContext& context); diff --git a/tensorflow/compiler/xla/runtime/custom_call_test.cc b/tensorflow/compiler/xla/runtime/custom_call_test.cc index 28888354a2e09c..0098ec53cf5f9a 100644 --- a/tensorflow/compiler/xla/runtime/custom_call_test.cc +++ b/tensorflow/compiler/xla/runtime/custom_call_test.cc @@ -72,7 +72,8 @@ static absl::StatusOr Compile( }; opts.compiler.create_compilation_pipeline = [=](PassManager& passes) { - CreateDefaultXlaGpuRuntimeCompilationPipeline(passes, copts); + CreateDefaultXlaGpuRuntimeCompilationPipeline(passes, copts, + /*add_async_passes=*/true); }; return JitExecutable::Instantiate(source, opts, exported); From 729bf538dd1c69a4746aadd51c7ce72c4ff92b3b Mon Sep 17 00:00:00 2001 From: Matt Callanan Date: Wed, 12 Jul 2023 11:29:48 -0700 Subject: [PATCH 204/376] #tf-data Set up `file_locality` experiment. PiperOrigin-RevId: 547553450 --- tensorflow/core/data/dataset_utils.cc | 2 ++ tensorflow/core/data/utils.cc | 2 ++ tensorflow/core/data/utils.h | 4 ++++ 3 files changed, 8 insertions(+) diff --git a/tensorflow/core/data/dataset_utils.cc b/tensorflow/core/data/dataset_utils.cc index 71bfed7bb1d8ae..7332de6026402f 100644 --- a/tensorflow/core/data/dataset_utils.cc +++ b/tensorflow/core/data/dataset_utils.cc @@ -977,6 +977,8 @@ REGISTER_DATASET_EXPERIMENT("stage_based_autotune_v2", RandomJobSamplePercentage<0>, IndependentHostTasks); REGISTER_DATASET_EXPERIMENT("data_transfer", RandomJobSamplePercentage<50>, AllTasks); +REGISTER_DATASET_EXPERIMENT("file_locality", RandomJobSamplePercentage<0>, + IndependentHostTasks); } // namespace } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/data/utils.cc b/tensorflow/core/data/utils.cc index a8f72ce1773bac..4e5d1211644ec8 100644 --- a/tensorflow/core/data/utils.cc +++ b/tensorflow/core/data/utils.cc @@ -33,5 +33,7 @@ std::string TranslateFileName(const std::string& fname) { return fname; } std::string DefaultDataTransferProtocol() { return "grpc"; } +std::string LocalityOptimizedPath(const std::string& path) { return path; } + } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/data/utils.h b/tensorflow/core/data/utils.h index 710a480e927a3d..43ae1263490580 100644 --- a/tensorflow/core/data/utils.h +++ b/tensorflow/core/data/utils.h @@ -34,6 +34,10 @@ std::string TranslateFileName(const std::string& fname); // user. std::string DefaultDataTransferProtocol(); +// Returns a path pointing to the same file as `path` with a potential locality +// optimization. +std::string LocalityOptimizedPath(const std::string& path); + } // namespace data } // namespace tensorflow From bd2188bfd66080bc86002657c6a6e97f9c5b8508 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Wed, 12 Jul 2023 11:32:20 -0700 Subject: [PATCH 205/376] [XLA] Add missing argument `shard_count` to the table for AllGather. PiperOrigin-RevId: 547554197 --- tensorflow/compiler/xla/g3doc/operation_semantics.md | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/compiler/xla/g3doc/operation_semantics.md b/tensorflow/compiler/xla/g3doc/operation_semantics.md index 83e21ca58de84f..aebf14c80eb0fe 100644 --- a/tensorflow/compiler/xla/g3doc/operation_semantics.md +++ b/tensorflow/compiler/xla/g3doc/operation_semantics.md @@ -44,6 +44,7 @@ channel_id)` | `operand` | `XlaOp` | Array to concatenate across | : : : replicas. : | `all_gather_dim` | `int64` | Concatenation dimension. | +| `shard_count` | `int64` | Size of each replica group. | | `replica_groups` | vector of vectors of | Groups between which the | : : `int64` : concatenation is performed. : | `channel_id` | optional `int64` | Optional channel ID for | From 9d0bf9026b8df05dbfbf55a9633b831f37972ec4 Mon Sep 17 00:00:00 2001 From: Victor Stone Date: Wed, 12 Jul 2023 11:42:34 -0700 Subject: [PATCH 206/376] Extract HloRematerialization options out into a new struct called Options. Lots of arguments are being passed into the HloRematerialization pass and more will be added. These arguments are also passed around in the implementation of this pass. This refactoring groups together all of the arguments of HloRematerialization into a single struct to make the interfaces of classes & functions cleaner. PiperOrigin-RevId: 547557424 --- .../service/gpu/compile_module_to_llvm_ir.cc | 6 +- .../xla/service/hlo_rematerialization.cc | 112 +++++++-------- .../xla/service/hlo_rematerialization.h | 130 +++++++++--------- .../xla/service/hlo_rematerialization_test.cc | 11 +- 4 files changed, 126 insertions(+), 133 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/compile_module_to_llvm_ir.cc b/tensorflow/compiler/xla/service/gpu/compile_module_to_llvm_ir.cc index 1f9d1b00094a69..1a23214065617d 100644 --- a/tensorflow/compiler/xla/service/gpu/compile_module_to_llvm_ir.cc +++ b/tensorflow/compiler/xla/service/gpu/compile_module_to_llvm_ir.cc @@ -321,18 +321,18 @@ Status CompileModuleToLlvmIrImpl( { HloPassPipeline pipeline("remat-pipeline"); - HloRematerialization::RematerializationSizes sizes; - pipeline.AddPass( + HloRematerialization::Options options( [pointer_size](const Shape& shape) { return GetSizeOfShape(shape, pointer_size); }, // Assume 75% of the total device memory is available for XLA. /*memory_limit_bytes=*/gpu_device_info.device_memory_size * 0.75, - /*sizes=*/&sizes, HloRematerialization::RematerializationPass::kPostFusion, /*block_size_limit=*/1, /*block_rematerialization_factor=*/1, /*compact_shape_function=*/nullptr, HloRematerialization::RematerializationMode::kRecomputeAndCompress); + HloRematerialization::RematerializationSizes sizes; + pipeline.AddPass(options, sizes); TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(hlo_module)); if (changed) { diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index ce804404efb9ea..b2f517962c6e75 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -498,13 +498,10 @@ UsesList GetUsers(const InstructionList& instruction_list, // (LogicalBuffers) at the current point in the instruction sequence. class MemoryUsageTracker { public: - MemoryUsageTracker( - const HloComputation* computation, - const HloRematerialization::ShapeSizeFunction& size_function, - const HloRematerialization::CompactShapeFunction& compact_shape_function, - const TuplePointsToAnalysis& points_to_analysis, - const InstructionList& instruction_list, - HloRematerialization::RematerializationMode mode); + MemoryUsageTracker(const HloRematerialization::Options& options, + const HloComputation* computation, + const TuplePointsToAnalysis& points_to_analysis, + const InstructionList& instruction_list); // Starts the placement of the given instruction. This adds the sizes of the // LogicalBuffers defined by the instruction to the current memory @@ -609,6 +606,8 @@ class MemoryUsageTracker { const HloComputation* computation() const { return computation_; } + const HloRematerialization::Options& options() const { return options_; } + // Check invariants of the data structure. This is expensive to call. bool Check() const; @@ -758,12 +757,15 @@ class MemoryUsageTracker { } return users_set.size(); }; - buffers_.push_back(Buffer{ - buffer_id, defining_instruction, size_function_(shape), shape, live_out, - has_indirect_uses, index, uses, get_num_of_unique_users(uses)}); + buffers_.push_back(Buffer{buffer_id, defining_instruction, + options_.size_function(shape), shape, live_out, + has_indirect_uses, index, uses, + get_num_of_unique_users(uses)}); return buffers_.back(); } + const HloRematerialization::Options& options_; + const HloComputation* computation_; // Instruction list containing the ordering of instructions in @@ -771,13 +773,6 @@ class MemoryUsageTracker { // (BeginInstruction/EndInstruction calls). const InstructionList& instruction_list_; - // Size function returns the bytes of a given buffer. - const HloRematerialization::ShapeSizeFunction& size_function_; - - // Converts a shape into compact form, returns the same shape if a shape is - // already considered compact. - const HloRematerialization::CompactShapeFunction& compact_shape_function_; - // A map that caches existing known compact shape for each instruction. absl::flat_hash_map compact_shape_; @@ -788,23 +783,18 @@ class MemoryUsageTracker { // between the calling of BeginInstruction and EndInstruction. Item* in_progress_item_ = nullptr; - HloRematerialization::RematerializationMode mode_; // All buffers in the computation. std::vector buffers_; }; MemoryUsageTracker::MemoryUsageTracker( + const HloRematerialization::Options& options, const HloComputation* computation, - const HloRematerialization::ShapeSizeFunction& size_function, - const HloRematerialization::CompactShapeFunction& compact_shape_function, const TuplePointsToAnalysis& points_to_analysis, - const InstructionList& instruction_list, - HloRematerialization::RematerializationMode mode) - : computation_(computation), - instruction_list_(instruction_list), - size_function_(size_function), - compact_shape_function_(compact_shape_function), - mode_(mode) { + const InstructionList& instruction_list) + : options_(options), + computation_(computation), + instruction_list_(instruction_list) { PointsToSet::BufferSet live_out_set = points_to_analysis.GetPointsToSet(computation_->root_instruction()) .CreateFlattenedSet(); @@ -958,7 +948,7 @@ int64_t MemoryUsageTracker::MemoryReducedIfCompressed( const Buffer& buffer = buffers_.at(buffer_id); memory_reduced += buffer.size; - int64_t compact_shape_size = size_function_(compact_shape); + int64_t compact_shape_size = options_.size_function(compact_shape); // Account for buffers that are compressed after instruction. memory_reduced -= compact_shape_size; } @@ -1027,9 +1017,10 @@ Status MemoryUsageTracker::AddCompressInstructions(Item* original_item, Item* compressed_item, Item* uncompressed_item) { // Original buffer is now dead. - memory_usage_ -= size_function_(original_item->instruction->shape()); + memory_usage_ -= options_.size_function(original_item->instruction->shape()); // Compressed buffer is now alive. - memory_usage_ += size_function_(compressed_item->instruction->shape()); + memory_usage_ += + options_.size_function(compressed_item->instruction->shape()); UsesList placed_users; UsesList unplaced_users; @@ -1261,7 +1252,8 @@ StatusOr MemoryUsageTracker::GetCompactShape(const HloInstruction* hlo) { return it->second; } const Shape& original_shape = hlo->shape(); - TF_ASSIGN_OR_RETURN(Shape min_shape, compact_shape_function_(original_shape)); + TF_ASSIGN_OR_RETURN(Shape min_shape, + options_.compact_shape_function(original_shape)); compact_shape_[hlo] = min_shape; return min_shape; } @@ -1424,10 +1416,10 @@ MemoryUsageTracker::PickRematerializationCandidates( auto* item = block[0]; auto* candidate = item->instruction; if (item->buffers_output.size() == 1 && - (mode_ == + (options_.mode == HloRematerialization::RematerializationMode::kCompressOnly || - mode_ == HloRematerialization::RematerializationMode:: - kRecomputeAndCompress)) { + options_.mode == HloRematerialization::RematerializationMode:: + kRecomputeAndCompress)) { // Only consider compressing single output instruction. const Buffer& output_buffer = buffers_.at(item->buffers_output[0]); @@ -1442,8 +1434,10 @@ MemoryUsageTracker::PickRematerializationCandidates( // while performing the compression/uncompression, only perform // the compression if the sum of the two sizes is less than the // peak memory. - const int64_t size = size_function_(item->instruction->shape()); - const int64_t reduced_size = size_function_(compact_shape); + const int64_t size = + options_.size_function(item->instruction->shape()); + const int64_t reduced_size = + options_.size_function(compact_shape); effort++; if (memory_reduced > 0 && size + reduced_size < peak_memory_bytes) { @@ -1464,7 +1458,8 @@ MemoryUsageTracker::PickRematerializationCandidates( } } // Do not consider recomputation in compress-only mode. - if (mode_ == HloRematerialization::RematerializationMode::kCompressOnly) { + if (options_.mode == + HloRematerialization::RematerializationMode::kCompressOnly) { // break out of this loop. Move on to the next start_item. break; } @@ -1861,9 +1856,8 @@ StatusOr HloRematerialization::ComputePeakMemory( const HloComputation* computation, const HloInstructionSequence& order, const absl::flat_hash_set& execution_threads) const { InstructionList instruction_list(order); - MemoryUsageTracker tracker(computation, size_function_, - compact_shape_function_, *points_to_analysis_, - instruction_list, mode_); + MemoryUsageTracker tracker(options_, computation, *points_to_analysis_, + instruction_list); int64_t peak_memory = tracker.memory_usage(); for (auto* item = instruction_list.first(); item != nullptr; item = instruction_list.next(item)) { @@ -1927,9 +1921,8 @@ StatusOr HloRematerialization::RematerializeComputation( CHECK(!ContainsKey(rematerialized_computations_, computation)); InstructionList instruction_list(schedule->sequence(computation)); - MemoryUsageTracker memory_tracker( - computation, size_function_, compact_shape_function_, - *points_to_analysis_, instruction_list, mode_); + MemoryUsageTracker memory_tracker(options_, computation, *points_to_analysis_, + instruction_list); instruction_list.PromoteNodesToSkip([&](Item* item) { return memory_tracker.AllocatedSize(item) >= min_remat_size; @@ -2028,9 +2021,9 @@ StatusOr HloRematerialization::RematerializeComputation( min_block_size = 1; max_block_size = 1; } - if (max_block_size > block_size_limit_ || + if (max_block_size > options_.block_size_limit || second_phase_effort > - block_rematerialization_factor_ * first_phase_effort) { + options_.block_rematerialization_factor * first_phase_effort) { break; } } @@ -2112,7 +2105,7 @@ StatusOr HloRematerialization::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { VLOG(1) << "HloRematerialization() with memory limit of " - << HumanReadableNumBytes(memory_limit_bytes_); + << HumanReadableNumBytes(options_.memory_limit_bytes); XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString()); // Initialize pass object state. @@ -2132,13 +2125,12 @@ StatusOr HloRematerialization::Run( int64_t module_output_size = 0; ShapeUtil::ForEachSubshape( module->result_shape(), - [&module_output_size, module, this](const Shape& subshape, - const ShapeIndex& output_index) { - module_output_size += size_function_(subshape); + [&](const Shape& subshape, const ShapeIndex& output_index) { + module_output_size += options_.size_function(subshape); }); const int64_t adjusted_memory_limit_bytes = - memory_limit_bytes_ - module_output_size; + std::max(0, options_.memory_limit_bytes - module_output_size); VLOG(1) << "Adjusted memory limit accounting for output (" << HumanReadableNumBytes(module_output_size) << "): " << HumanReadableNumBytes(adjusted_memory_limit_bytes); @@ -2175,8 +2167,8 @@ StatusOr HloRematerialization::Run( TF_ASSIGN_OR_RETURN( bool changed, RematerializeComputation(module->entry_computation(), &module->schedule(), - adjusted_memory_limit_bytes, min_remat_size_, - execution_threads)); + adjusted_memory_limit_bytes, + options_.min_remat_size, execution_threads)); // Rematerialization can introduce dead code. This occurs if all uses of an // instruction are replaced with rematerializations of the instruction. @@ -2207,19 +2199,19 @@ StatusOr HloRematerialization::Run( << HumanReadableNumBytes(reduced_peak_memory) << " (" << reduced_peak_memory << " bytes)"; - if (sizes_ != nullptr) { - sizes_->before_bytes = before_peak_memory; - sizes_->after_bytes = current_peak_memory; - } + sizes_.before_bytes = before_peak_memory; + sizes_.after_bytes = current_peak_memory; XLA_VLOG_LINES(5, "After HloRematerialization:\n" + module->ToString()); - if (current_peak_memory > memory_limit_bytes_) { + if (current_peak_memory > options_.memory_limit_bytes) { LOG(WARNING) << absl::StrFormat( "Can't reduce memory use below %s (%d bytes) by rematerialization; " - "only reduced to %s (%d bytes)", - HumanReadableNumBytes(memory_limit_bytes_), memory_limit_bytes_, - HumanReadableNumBytes(current_peak_memory), current_peak_memory); + "only reduced to %s (%d bytes), down from %s (%d bytes) originally", + HumanReadableNumBytes(options_.memory_limit_bytes), + options_.memory_limit_bytes, HumanReadableNumBytes(current_peak_memory), + current_peak_memory, HumanReadableNumBytes(before_peak_memory), + before_peak_memory); } return changed; } diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h index e5237451956767..ab85c9764cb5d8 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.h +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h @@ -15,6 +15,8 @@ #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REMATERIALIZATION_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REMATERIALIZATION_H_ +#include + #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" @@ -66,39 +68,64 @@ class HloRematerialization : public HloModulePass { static Shape DefaultCompactShapeFunction(const Shape& shape) { return shape; } - // Constructor parameters: - // - // size_function: Function which returns the size in bytes of the top-level - // buffer of the given shape. - // - // memory_limit_bytes: The threshold number of bytes to reduce memory use to - // via rematerialization. Size of aliased outputs should be subtracted - // from this. - // - // sizes: Pointer to data structure which records the peak memory usage of - // the HLO module before/after rematerialization. Value are set during - // Run(). Can be nullptr. - // - // compact_shape_function: Function which returns the compact form of a - // shape. If nullptr is provided, an default identity function is used. - explicit HloRematerialization( - const ShapeSizeFunction& size_function, int64_t memory_limit_bytes, - RematerializationSizes* sizes, RematerializationPass pass_location, - int block_size_limit, int block_rematerialization_factor, - CompactShapeFunction compact_shape_function = nullptr, - RematerializationMode mode = RematerializationMode::kRecomputeAndCompress, - int64_t min_remat_size = 0) - : size_function_(size_function), - memory_limit_bytes_(memory_limit_bytes), - sizes_(sizes), - pass_location_(pass_location), - block_size_limit_(block_size_limit), - block_rematerialization_factor_(block_rematerialization_factor), - compact_shape_function_(compact_shape_function == nullptr - ? DefaultCompactShapeFunction - : std::move(compact_shape_function)), - mode_(mode), - min_remat_size_(min_remat_size) {} + struct Options { + explicit Options(const ShapeSizeFunction& size_function, + int64_t memory_limit_bytes, + RematerializationPass pass_location, int block_size_limit, + int block_rematerialization_factor, + CompactShapeFunction compact_shape_function = nullptr, + RematerializationMode mode = + RematerializationMode::kRecomputeAndCompress, + int64_t min_remat_size = 0) + : size_function(size_function), + memory_limit_bytes(memory_limit_bytes), + pass_location(pass_location), + block_size_limit(block_size_limit), + block_rematerialization_factor(block_rematerialization_factor), + compact_shape_function(compact_shape_function == nullptr + ? DefaultCompactShapeFunction + : std::move(compact_shape_function)), + mode(mode), + min_remat_size(min_remat_size) {} + + // Function which computes the size of the top-level buffer of a shape. + const ShapeSizeFunction size_function; + + // The threshold number of bytes to reduce memory use to via + // rematerialization. Size of aliased outputs should be subtracted + // from this. + int64_t memory_limit_bytes; + + // Specifies whether this rematerialization pass occurs before or after + // multi-output fusion. + RematerializationPass pass_location; + + // Maximum number of consecutive instructions to consider for + // rematerialization. + int block_size_limit; + + // Controls the amount of effort spent trying to find large blocks for + // rematerialization. Larger values leads to longer compilation times in + // return for potentially reduced memory consumption. + int block_rematerialization_factor; + + // Converts a shape into compact form, returns the same shape if a shape is + // already considered compact. + const CompactShapeFunction compact_shape_function; + + // Holds the rematerialization strategy configuration to be used by the + // pass. + RematerializationMode mode; + + // The minimim size, in bytes, of a tensor to be considered for + // rematerialization. All tensors smaller than this size will be skipped + // over. + int64_t min_remat_size; + }; + + explicit HloRematerialization(Options options, RematerializationSizes& sizes) + : options_(std::move(options)), sizes_(sizes) {} + ~HloRematerialization() override = default; absl::string_view name() const override { return "rematerialization"; } @@ -160,36 +187,11 @@ class HloRematerialization : public HloModulePass { const absl::flat_hash_set& execution_threads, absl::string_view thread) const; - // Selects an algorithm to use for HLO scheduling. - MemorySchedulerAlgorithm scheduler_algorithm_; + const Options options_; - // Function which computes the size of the top-level buffer of a shape. - const ShapeSizeFunction size_function_; - - // The threshold number of bytes to reduce memory use to via - // rematerialization. - const int64_t memory_limit_bytes_; - - // Pointer to data structure which records the peak memory usage of the HLO - // module before/after rematerialization - RematerializationSizes* sizes_; - - // Specifies whether this rematerialization pass occurs before or after - // multi-output fusion. - RematerializationPass pass_location_; - - // Maximum number of consecutive instructions to consider for - // rematerialization. - int block_size_limit_; - - // Controls the amount of effort spent trying to find large blocks for - // rematerialization. Larger values leads to longer compilation times in - // return for potentially reduced memory consumption. - int block_rematerialization_factor_ = 1; - - // Converts a shape into compact form, returns the same shape if a shape is - // already considered compact. - const CompactShapeFunction compact_shape_function_; + // Reference to data structure which records the peak memory usage of the HLO + // module before/after rematerialization. + RematerializationSizes& sizes_; // Call graph of the hlo_module. std::unique_ptr call_graph_; @@ -221,10 +223,6 @@ class HloRematerialization : public HloModulePass { // upper bound (within a factor of 2) on the block size. int max_rematerialized_block_size_ = 0; - RematerializationMode mode_; - - int64_t min_remat_size_; - // Tracking available channel id numbers to use to apply to rematerialized // channel instructions int64_t next_channel_id_; diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc index c5a17ae983d6bb..84f470c5da87ab 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc @@ -51,13 +51,15 @@ class HloRematerializationTest : public RematerializationTestBase { ComputationSchedulerToModuleScheduler(DefaultMemoryScheduler)); TF_EXPECT_OK(scheduler.Run(module).status()); } - HloRematerialization remat( + + HloRematerialization::Options options( ByteSizeOf, memory_limit_bytes, - /*sizes=*/nullptr, HloRematerialization::RematerializationPass::kPreFusion, /*block_size_limit=*/1, /*block_rematerialization_factor=*/1, nullptr, HloRematerialization::RematerializationMode::kRecomputeAndCompress, min_remat_size); + HloRematerialization::RematerializationSizes sizes; + HloRematerialization remat(options, sizes); return remat.Run(module); } }; @@ -607,14 +609,15 @@ class CompressingRematerializationTest : public RematerializationTestBase { HloModule* module, int64_t min_remat_size = 0) { TF_EXPECT_OK(verifier().Run(module).status()); - HloRematerialization remat( + HloRematerialization::Options options( ShapeSizePadMinorTo64, memory_limit_bytes, - /*sizes=*/nullptr, HloRematerialization::RematerializationPass::kPreFusion, /*block_size_limit=*/1, /*block_rematerialization_factor=*/1, ChooseCompactLayoutForShape, HloRematerialization::RematerializationMode::kCompressOnly, min_remat_size); + HloRematerialization::RematerializationSizes sizes; + HloRematerialization remat(options, sizes); return remat.Run(module); } }; From 30813d0ad7b631241a72df10d2e8a1ab1b88add8 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Wed, 12 Jul 2023 11:51:19 -0700 Subject: [PATCH 207/376] Integrate StableHLO at openxla/stablehlo@4add5f0 Manual changes: * CMakeLists.txt: keep MLIR-HLO-only customizations to the CMake build. * BUILD.bazel, stablehlo/dialect/CMakeLists.txt, stablehlo/dialect/Base.h, stablehlo/dialect/Base.cpp, stablehlo/dialect/ExperimentalOps.h, stablehlo/dialect/ExperimentalOps.cpp, stablehlo/transforms/StablehloCanonicalizeDynamism.cpp, stablehlo/transforms/StablehloRefineShapes.cpp, stablehlo/tests/stablehlo_canonicalize_dynamism.mlir, stablehlo/tests/stablehlo_refine_shapes.mlir: keep XLA-only customizations to StableHLO shape refinement. PiperOrigin-RevId: 547559812 --- third_party/stablehlo/temporary.patch | 9 --------- third_party/stablehlo/workspace.bzl | 4 ++-- 2 files changed, 2 insertions(+), 11 deletions(-) diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch index 3ad60b9c2c5bb0..3bb703ca414a73 100644 --- a/third_party/stablehlo/temporary.patch +++ b/third_party/stablehlo/temporary.patch @@ -1275,15 +1275,6 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehl } }; -@@ -1143,7 +1210,7 @@ - // function. This is sufficient because we only support one function per - // program at the moment. - // TODO(#1048): Find out why .maxIterations = 1 no longer works. -- // There have been recent refactors to applyPatternsAndFoldGreedily -+ // There have been recent refactors in applyPatternsAndFoldGreedily - // upstream, and that might be the reason. - GreedyRewriteConfig config; - config.useTopDownTraversal = true; @@ -1181,7 +1248,9 @@ patterns.add(&getContext()); patterns.add(&getContext()); diff --git a/third_party/stablehlo/workspace.bzl b/third_party/stablehlo/workspace.bzl index 1934d33bf9868a..4319edf4033bb6 100644 --- a/third_party/stablehlo/workspace.bzl +++ b/third_party/stablehlo/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): # LINT.IfChange - STABLEHLO_COMMIT = "20b1da42266a1f351b8315bc195faabceaa74f3e" - STABLEHLO_SHA256 = "7ab70ba2d0aa3c7331df912b674c2825cc168cb691db171a2343d453e4a53811" + STABLEHLO_COMMIT = "4add5f0e890bc66b333e86961978f066325f8a86" + STABLEHLO_SHA256 = "4d3014703aa8d18477790b2f3040163276b50f647aa2da32396f390ea8bf6f7c" # LINT.ThenChange(Google-internal path) tf_http_archive( From d4b760aa36517523bdf6136ef2dece3cc8239dfb Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 12 Jul 2023 11:59:48 -0700 Subject: [PATCH 208/376] Add input type inference logic and WeakTensor input and output support for unary and binary ops. 'weak_tensor_unary_op_wrapper' and 'weak_tensor_binary_op_wrapper' infer various input types as a Tensor/WeakTensor, then return the result in the correct output type using the auto dtype conversion semantics. PiperOrigin-RevId: 547561912 --- tensorflow/python/BUILD | 4 + tensorflow/python/framework/BUILD | 9 +- .../python/framework/flexible_dtypes.py | 46 +- .../python/framework/flexible_dtypes_test.py | 14 + tensorflow/python/framework/weak_tensor.py | 33 ++ tensorflow/python/ops/BUILD | 68 ++- tensorflow/python/ops/numpy_ops/BUILD | 2 - .../python/ops/numpy_ops/np_array_ops.py | 42 -- .../python/ops/numpy_ops/np_math_ops.py | 45 -- .../python/ops/weak_tensor_math_ops_test.py | 458 +++++++++++++++++ tensorflow/python/ops/weak_tensor_ops.py | 464 +++++++++++++++++- tensorflow/python/ops/weak_tensor_ops_list.py | 251 ---------- tensorflow/python/ops/weak_tensor_ops_test.py | 244 +++++++-- .../python/ops/weak_tensor_test_util.py | 21 + tensorflow/tools/pip_package/BUILD | 1 - 15 files changed, 1267 insertions(+), 435 deletions(-) create mode 100644 tensorflow/python/ops/weak_tensor_math_ops_test.py delete mode 100644 tensorflow/python/ops/weak_tensor_ops_list.py diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 5d37854d700fc7..2940055b78d377 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -182,12 +182,14 @@ py_library( "//tensorflow/python/framework:config", "//tensorflow/python/framework:errors", "//tensorflow/python/framework:extension_type", + "//tensorflow/python/framework:flexible_dtypes", "//tensorflow/python/framework:for_generated_wrappers", "//tensorflow/python/framework:graph_util", "//tensorflow/python/framework:kernels", "//tensorflow/python/framework:subscribe", "//tensorflow/python/framework:tensor_spec", "//tensorflow/python/framework:test_ops", # TODO(b/183988750): Break testing code out into separate rule. + "//tensorflow/python/framework:weak_tensor", "//tensorflow/python/grappler:tf_cluster", "//tensorflow/python/grappler:tf_item", "//tensorflow/python/grappler:tf_optimizer", @@ -245,6 +247,8 @@ py_library( "//tensorflow/python/ops:tensor_array_ops", "//tensorflow/python/ops:uniform_quant_ops_gen", "//tensorflow/python/ops:variable_v1", + "//tensorflow/python/ops:weak_tensor_ops", + "//tensorflow/python/ops:weak_tensor_test_util", "//tensorflow/python/ops:weights_broadcast_ops", "//tensorflow/python/ops:while_loop", "//tensorflow/python/ops:while_v2", diff --git a/tensorflow/python/framework/BUILD b/tensorflow/python/framework/BUILD index 856015a16a00cf..ec09f6dff7626f 100644 --- a/tensorflow/python/framework/BUILD +++ b/tensorflow/python/framework/BUILD @@ -1606,7 +1606,9 @@ py_strict_library( ":dtypes", ":errors", ":extension_type", + ":ops", ":tensor", + ":tensor_conversion_registry", "//tensorflow/python/eager:context", "//third_party/py/numpy", ], @@ -1618,7 +1620,6 @@ tf_py_strict_test( main = "weak_tensor_test.py", python_version = "PY3", srcs_version = "PY3", - tags = ["no_pip"], # weak_tensor_test is not available in pip. deps = [ ":constant_op", ":dtypes", @@ -1818,10 +1819,11 @@ pytype_strict_library( name = "flexible_dtypes", srcs = ["flexible_dtypes.py"], deps = [ - ":constant_op", ":dtypes", ":ops", + ":tensor_shape", ":weak_tensor", + "//tensorflow/python/types:core", "//tensorflow/python/util:nest", "//third_party/py/numpy", ], @@ -3114,14 +3116,15 @@ py_strict_test( name = "flexible_dtypes_test", srcs = ["flexible_dtypes_test.py"], tags = [ - "no_pip", "no_windows", # TODO(b/286939592): Enable this test on Windows. ], deps = [ ":constant_op", ":dtypes", + ":extension_type", ":flexible_dtypes", ":ops", + ":tensor", ":weak_tensor", "//tensorflow/python/ops:variables", "//tensorflow/python/ops:weak_tensor_test_util", diff --git a/tensorflow/python/framework/flexible_dtypes.py b/tensorflow/python/framework/flexible_dtypes.py index 5d5abc970ebbe1..909f731faa96e9 100644 --- a/tensorflow/python/framework/flexible_dtypes.py +++ b/tensorflow/python/framework/flexible_dtypes.py @@ -19,6 +19,8 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import weak_tensor +from tensorflow.python.framework.tensor_shape import TensorShape +from tensorflow.python.types import core as core_types from tensorflow.python.util import nest # PromoMode Enum that denotes safe and all mode. @@ -372,6 +374,32 @@ def _initialize(): ) +def _is_acceptable_input_type(x): + """Determines if x is an acceptable input type for auto dtype conversion semantics.""" + acceptable_types = [ + core_types.Tensor, + core_types.TensorProtocol, + int, + float, + bool, + str, + bytes, + complex, + tuple, + list, + np.ndarray, + np.generic, + dtypes.DType, + np.dtype, + TensorShape, + weak_tensor.WeakTensor, + ] + for t in acceptable_types: + if isinstance(x, t): + return True + return False + + def _get_dtype_and_weakness(x): """Returns a TF type and weak type information from x. @@ -438,12 +466,22 @@ def _result_type_impl(*arrays_and_dtypes): TypeError: when the promotion between the input dtypes is disabled in the current mode + + NotImplementedError: when arrays_and_dtypes contains an unsupported input + type (e.g. CompositeTensor). """ promo_safety_mode = ops.get_dtype_conversion_mode() - # Drop None inputs. - valid_arrays_and_dtypes = [ - inp for inp in arrays_and_dtypes if inp is not None - ] + # Drop None inputs and check if input type is supported. + valid_arrays_and_dtypes = [] + for inp in arrays_and_dtypes: + if inp is not None: + if _is_acceptable_input_type(inp): + valid_arrays_and_dtypes.append(inp) + else: + raise NotImplementedError( + 'Auto dtype conversion semantics does not support' + f' {type(inp)} type.' + ) dtypes_and_is_weak = [ _get_dtype_and_weakness(x) for x in nest.flatten(valid_arrays_and_dtypes) diff --git a/tensorflow/python/framework/flexible_dtypes_test.py b/tensorflow/python/framework/flexible_dtypes_test.py index dfd5ec3022621b..5205e9e9dd16d0 100644 --- a/tensorflow/python/framework/flexible_dtypes_test.py +++ b/tensorflow/python/framework/flexible_dtypes_test.py @@ -19,8 +19,10 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import extension_type from tensorflow.python.framework import flexible_dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import weak_tensor from tensorflow.python.ops import variables from tensorflow.python.ops import weak_tensor_test_util @@ -862,6 +864,18 @@ def testResultTypeEmptyInput(self): self.assertEqual(dtype, dtypes.float32) self.assertTrue(is_weak) + def testResultTypeUnsupportedInputType(self): + class MyTensor(extension_type.ExtensionType): + value: tensor.Tensor + + with DtypeConversionTestEnv('all'): + a = MyTensor(constant_op.constant(1)) + with self.assertRaisesRegex( + NotImplementedError, + f'Auto dtype conversion semantics does not support {type(a)} type.', + ): + _ = flexible_dtypes.result_type(a) + # Test v1 + v2 = v2 + v1. def testCommunicativity(self): with DtypeConversionTestEnv('all'): diff --git a/tensorflow/python/framework/weak_tensor.py b/tensorflow/python/framework/weak_tensor.py index 98ecce25026539..6b3f456028972d 100644 --- a/tensorflow/python/framework/weak_tensor.py +++ b/tensorflow/python/framework/weak_tensor.py @@ -24,7 +24,10 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import extension_type +from tensorflow.python.framework import ops from tensorflow.python.framework import tensor as tensor_lib +from tensorflow.python.framework import tensor_conversion_registry + _ALLOWED_WEAK_DTYPES = ( dtypes.int32, @@ -164,6 +167,16 @@ def to_tensor(self): """Converts this 'WeakTensor' into a 'tf.Tensor'.""" return self.tensor + def numpy(self): + """Copy of the contents of this WeakTensor into a NumPy array or scalar.""" + if not isinstance(self.tensor, ops.EagerTensor): + raise ValueError("WeakTensor.numpy() is only supported in eager mode.") + return self.tensor.numpy() + + def _as_graph_element(self): + """Convert `self` to a graph element.""" + return self.tensor + @classmethod def from_tensor(cls, tensor): """Converts a 'tf.Tensor' into a 'WeakTensor'.""" @@ -179,6 +192,10 @@ def dtype(self): def shape(self): return self.tensor.shape + @property + def is_tensor_like(self): + return True + __composite_gradient__ = WeakTensorGradient() @@ -201,3 +218,19 @@ def __next__(self): result = WeakTensor(self._weak_tensor.tensor[self._index]) self._index += 1 return result + + +def maybe_convert_to_weak_tensor(t, is_weak): + return WeakTensor(t) if is_weak else t + + +# convert_to_tensor(WeakTensor) should return a WeakTensor because WeakTensor is +# a 'Tensor' with a special dtype. +def weak_tensor_conversion_function(t): + if isinstance(t, WeakTensor): + return t + + +tensor_conversion_registry.register_tensor_conversion_function( + WeakTensor, weak_tensor_conversion_function +) diff --git a/tensorflow/python/ops/BUILD b/tensorflow/python/ops/BUILD index 1a8d02a84cd15f..f6d7afeb89343b 100644 --- a/tensorflow/python/ops/BUILD +++ b/tensorflow/python/ops/BUILD @@ -1,4 +1,3 @@ -load("//tensorflow:pytype.default.bzl", "pytype_strict_library") load("//tensorflow:strict.default.bzl", "py_strict_library", "py_strict_test") load("//tensorflow:tensorflow.default.bzl", "cuda_py_strict_test", "tf_py_strict_test") load("//tensorflow/python:build_defs.bzl", "tf_gen_op_strict_wrapper_private_py") @@ -4449,10 +4448,26 @@ py_strict_library( name = "weak_tensor_ops", srcs = ["weak_tensor_ops.py"], deps = [ - ":weak_tensor_ops_list", + ":array_ops", + ":array_ops_gen", + ":bitwise_ops_gen", + ":clip_ops", + ":image_ops_impl", + ":math_ops", + ":math_ops_gen", + ":nn_impl", + ":nn_ops", + ":nn_ops_gen", + ":special_math_ops", + "//tensorflow/python/framework:flexible_dtypes", + "//tensorflow/python/framework:ops", "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:weak_tensor", + "//tensorflow/python/ops/numpy_ops:np_array_ops", + "//tensorflow/python/ops/numpy_ops:np_math_ops", + "//tensorflow/python/platform:tf_logging", "//tensorflow/python/util:dispatch", + "//tensorflow/python/util:tf_decorator", ], ) @@ -4467,9 +4482,10 @@ py_strict_test( ":image_ops_impl", ":math_ops", ":weak_tensor_ops", - ":weak_tensor_ops_list", + ":weak_tensor_test_util", "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:extension_type", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:test_lib", @@ -4477,33 +4493,47 @@ py_strict_test( "//tensorflow/python/ops/numpy_ops:np_array_ops", "//tensorflow/python/ops/numpy_ops:np_config", "//tensorflow/python/ops/numpy_ops:np_math_ops", + "//tensorflow/python/ops/ragged:ragged_tensor", "//tensorflow/python/platform:test", + "//tensorflow/python/util:dispatch", + "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", ], ) -py_strict_library( - name = "weak_tensor_ops_list", - srcs = ["weak_tensor_ops_list.py"], +py_strict_test( + name = "weak_tensor_math_ops_test", + srcs = ["weak_tensor_math_ops_test.py"], deps = [ ":array_ops", - ":array_ops_gen", - ":bitwise_ops_gen", - ":clip_ops", - ":image_ops_impl", ":math_ops", - ":math_ops_gen", - ":nn_impl", - ":nn_ops", - ":nn_ops_gen", - ":special_math_ops", - "//tensorflow/python/ops/numpy_ops:np_array_ops", - "//tensorflow/python/ops/numpy_ops:np_math_ops", + ":tensor_array_ops", + ":weak_tensor_ops", + ":weak_tensor_test_util", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:tf2", + "//tensorflow/python/eager:backprop", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:def_function", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", + "//tensorflow/python/framework:test_lib", + "//tensorflow/python/framework:weak_tensor", + "//tensorflow/python/ops/ragged:ragged_factory_ops", + "//tensorflow/python/platform:test", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", ], ) -pytype_strict_library( +py_strict_library( name = "weak_tensor_test_util", srcs = ["weak_tensor_test_util.py"], - deps = ["//tensorflow/python/framework:ops"], + deps = [ + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:weak_tensor", + "//third_party/py/numpy", + ], ) diff --git a/tensorflow/python/ops/numpy_ops/BUILD b/tensorflow/python/ops/numpy_ops/BUILD index ad6c320f0ce6be..11a521d3d98053 100644 --- a/tensorflow/python/ops/numpy_ops/BUILD +++ b/tensorflow/python/ops/numpy_ops/BUILD @@ -67,7 +67,6 @@ py_strict_library( "//tensorflow/python/ops:manip_ops", "//tensorflow/python/ops:math_ops", "//tensorflow/python/ops:sort_ops", - "//tensorflow/python/util:dispatch", "//tensorflow/python/util:nest", "//third_party/py/numpy", ], @@ -139,7 +138,6 @@ py_strict_library( "//tensorflow/python/ops:sort_ops", "//tensorflow/python/ops:special_math_ops", "//tensorflow/python/ops:while_loop", - "//tensorflow/python/util:dispatch", "//third_party/py/numpy", ], ) diff --git a/tensorflow/python/ops/numpy_ops/np_array_ops.py b/tensorflow/python/ops/numpy_ops/np_array_ops.py index cd795bae931c0e..638e4935f95029 100644 --- a/tensorflow/python/ops/numpy_ops/np_array_ops.py +++ b/tensorflow/python/ops/numpy_ops/np_array_ops.py @@ -40,7 +40,6 @@ from tensorflow.python.ops.numpy_ops import np_dtypes from tensorflow.python.ops.numpy_ops import np_export from tensorflow.python.ops.numpy_ops import np_utils -from tensorflow.python.util import dispatch from tensorflow.python.util import nest @@ -52,7 +51,6 @@ def empty(shape, dtype=float): # pylint: disable=redefined-outer-name return zeros(shape, dtype) -@dispatch.add_dispatch_support @np_utils.np_doc('empty_like') def empty_like(a, dtype=None): return zeros_like(a, dtype) @@ -65,7 +63,6 @@ def zeros(shape, dtype=float): # pylint: disable=redefined-outer-name return array_ops.zeros(shape, dtype=dtype) -@dispatch.add_dispatch_support @np_utils.np_doc('zeros_like') def zeros_like(a, dtype=None): # pylint: disable=missing-docstring dtype = np_utils.result_type_unary(a, dtype) @@ -81,7 +78,6 @@ def ones(shape, dtype=float): # pylint: disable=redefined-outer-name return array_ops.ones(shape, dtype=dtype) -@dispatch.add_dispatch_support @np_utils.np_doc('ones_like') def ones_like(a, dtype=None): dtype = np_utils.result_type_unary(a, dtype) @@ -200,7 +196,6 @@ def true_fn(): # TODO(wangpeng): investigate whether we can make `copy` default to False. # pylint: disable=g-short-docstring-punctuation,g-no-space-after-docstring-summary,g-doc-return-or-yield,g-doc-args -@dispatch.add_dispatch_support @np_utils.np_doc_only('array') def array(val, dtype=None, copy=True, ndmin=0): # pylint: disable=redefined-outer-name """Since Tensors are immutable, a copy is made only if val is placed on a @@ -217,7 +212,6 @@ def array(val, dtype=None, copy=True, ndmin=0): # pylint: disable=redefined-out # pylint: enable=g-short-docstring-punctuation,g-no-space-after-docstring-summary,g-doc-return-or-yield,g-doc-args -@dispatch.add_dispatch_support @np_utils.np_doc('asarray') def asarray(a, dtype=None): if dtype: @@ -228,20 +222,17 @@ def asarray(a, dtype=None): return array(a, dtype, copy=False) -@dispatch.add_dispatch_support @np_utils.np_doc('asanyarray') def asanyarray(a, dtype=None): return asarray(a, dtype) -@dispatch.add_dispatch_support @np_utils.np_doc('ascontiguousarray') def ascontiguousarray(a, dtype=None): return array(a, dtype, ndmin=1) # Numerical ranges. -@dispatch.add_dispatch_support @np_utils.np_doc('arange') def arange(start, stop=None, step=1, dtype=None): """Returns `step`-separated values in the range [start, stop). @@ -284,7 +275,6 @@ def arange(start, stop=None, step=1, dtype=None): # Building matrices. -@dispatch.add_dispatch_support @np_utils.np_doc('diag') def diag(v, k=0): # pylint: disable=missing-docstring """Raises an error if input is not 1- or 2-d.""" @@ -320,7 +310,6 @@ def _diag_part(v, k): return result -@dispatch.add_dispatch_support @np_utils.np_doc('diagonal') def diagonal(a, offset=0, axis1=0, axis2=1): # pylint: disable=missing-docstring a = asarray(a) @@ -352,7 +341,6 @@ def _zeros(): # pylint: disable=missing-docstring return a -@dispatch.add_dispatch_support @np_utils.np_doc('diagflat') def diagflat(v, k=0): v = asarray(v) @@ -421,7 +409,6 @@ def compress(condition, a, axis=None): # pylint: disable=redefined-outer-name,m return array_ops.boolean_mask(tensor=a_t, mask=condition_t, axis=axis) -@dispatch.add_dispatch_support @np_utils.np_doc('copy') def copy(a): return array(a, copy=True) @@ -439,7 +426,6 @@ def _maybe_promote_to_int(a): return a -@dispatch.add_dispatch_support @np_utils.np_doc('cumprod') def cumprod(a, axis=None, dtype=None): # pylint: disable=missing-docstring a = asarray(a, dtype=dtype) @@ -456,7 +442,6 @@ def cumprod(a, axis=None, dtype=None): # pylint: disable=missing-docstring return math_ops.cumprod(a, axis) -@dispatch.add_dispatch_support @np_utils.np_doc('cumsum') def cumsum(a, axis=None, dtype=None): # pylint: disable=missing-docstring a = asarray(a, dtype=dtype) @@ -473,7 +458,6 @@ def cumsum(a, axis=None, dtype=None): # pylint: disable=missing-docstring return math_ops.cumsum(a, axis) -@dispatch.add_dispatch_support @np_utils.np_doc('imag') def imag(val): val = asarray(val) @@ -571,7 +555,6 @@ def size(x, axis=None): # pylint: disable=missing-docstring return array_ops.size_v2(x) -@dispatch.add_dispatch_support @np_utils.np_doc('sum') def sum(a, axis=None, dtype=None, keepdims=None): # pylint: disable=redefined-builtin return _reduce( @@ -583,7 +566,6 @@ def sum(a, axis=None, dtype=None, keepdims=None): # pylint: disable=redefined-b tf_bool_fn=math_ops.reduce_any) -@dispatch.add_dispatch_support @np_utils.np_doc('prod') def prod(a, axis=None, dtype=None, keepdims=None): return _reduce( @@ -595,7 +577,6 @@ def prod(a, axis=None, dtype=None, keepdims=None): tf_bool_fn=math_ops.reduce_all) -@dispatch.add_dispatch_support @np_utils.np_doc('mean', unsupported_params=['out']) def mean(a, axis=None, dtype=None, out=None, keepdims=None): if out is not None: @@ -609,7 +590,6 @@ def mean(a, axis=None, dtype=None, out=None, keepdims=None): promote_int=_TO_FLOAT) -@dispatch.add_dispatch_support @np_utils.np_doc('amax', unsupported_params=['out']) def amax(a, axis=None, out=None, keepdims=None): if out is not None: @@ -625,7 +605,6 @@ def amax(a, axis=None, out=None, keepdims=None): preserve_bool=True) -@dispatch.add_dispatch_support @np_utils.np_doc('amin', unsupported_params=['out']) def amin(a, axis=None, out=None, keepdims=None): if out is not None: @@ -641,7 +620,6 @@ def amin(a, axis=None, out=None, keepdims=None): preserve_bool=True) -@dispatch.add_dispatch_support @np_utils.np_doc('var') def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=None): # pylint: disable=missing-docstring if dtype: @@ -689,7 +667,6 @@ def reduce_fn(input_tensor, axis, keepdims): return result -@dispatch.add_dispatch_support @np_utils.np_doc('std') def std(a, axis=None, keepdims=None): # pylint: disable=missing-function-docstring return _reduce( @@ -701,14 +678,12 @@ def std(a, axis=None, keepdims=None): # pylint: disable=missing-function-docstr promote_int=_TO_FLOAT) -@dispatch.add_dispatch_support @np_utils.np_doc('ravel') def ravel(a): # pylint: disable=missing-docstring a = asarray(a) return array_ops.reshape(a, [-1]) -@dispatch.add_dispatch_support @np_utils.np_doc('real') def real(val): val = asarray(val) @@ -748,7 +723,6 @@ def repeat(a, repeats, axis=None): # pylint: disable=missing-docstring return result -@dispatch.add_dispatch_support @np_utils.np_doc('around') def around(a, decimals=0): # pylint: disable=missing-docstring a = asarray(a) @@ -771,7 +745,6 @@ def around(a, decimals=0): # pylint: disable=missing-docstring setattr(np_arrays.ndarray, '__round__', around) -@dispatch.add_dispatch_support @np_utils.np_doc('reshape') def reshape(a, newshape, order='C'): """order argument can only b 'C' or 'F'.""" @@ -802,21 +775,18 @@ def _reshape_method_wrapper(a, *newshape, **kwargs): return reshape(a, newshape, order=order) -@dispatch.add_dispatch_support @np_utils.np_doc('expand_dims') def expand_dims(a, axis): a = asarray(a) return array_ops.expand_dims(a, axis=axis) -@dispatch.add_dispatch_support @np_utils.np_doc('squeeze') def squeeze(a, axis=None): a = asarray(a) return array_ops.squeeze(a, axis) -@dispatch.add_dispatch_support @np_utils.np_doc('flatten', link=np_utils.NoLink()) def flatten(a, order='C'): a = asarray(a) @@ -831,7 +801,6 @@ def flatten(a, order='C'): '(column major).') -@dispatch.add_dispatch_support @np_utils.np_doc('transpose') def transpose(a, axes=None): a = asarray(a) @@ -840,7 +809,6 @@ def transpose(a, axes=None): return array_ops.transpose(a=a, perm=axes) -@dispatch.add_dispatch_support @np_utils.np_doc('swapaxes') def swapaxes(a, axis1, axis2): # pylint: disable=missing-docstring a = asarray(a) @@ -873,7 +841,6 @@ def f(x): return a -@dispatch.add_dispatch_support @np_utils.np_doc('moveaxis') def moveaxis(a, source, destination): # pylint: disable=missing-docstring """Raises ValueError if source, destination not in (-ndim(a), ndim(a)).""" @@ -1281,7 +1248,6 @@ def tril(m, k=0): # pylint: disable=missing-docstring array_ops.broadcast_to(mask, array_ops.shape(m)), m, z) -@dispatch.add_dispatch_support @np_utils.np_doc('triu') def triu(m, k=0): # pylint: disable=missing-docstring m = asarray(m) @@ -1303,7 +1269,6 @@ def triu(m, k=0): # pylint: disable=missing-docstring array_ops.broadcast_to(mask, array_ops.shape(m)), z, m) -@dispatch.add_dispatch_support @np_utils.np_doc('flip') def flip(m, axis=None): # pylint: disable=missing-docstring m = asarray(m) @@ -1316,13 +1281,11 @@ def flip(m, axis=None): # pylint: disable=missing-docstring return array_ops.reverse(m, [axis]) -@dispatch.add_dispatch_support @np_utils.np_doc('flipud') def flipud(m): # pylint: disable=missing-docstring return flip(m, 0) -@dispatch.add_dispatch_support @np_utils.np_doc('fliplr') def fliplr(m): # pylint: disable=missing-docstring return flip(m, 1) @@ -1341,7 +1304,6 @@ def roll(a, shift, axis=None): # pylint: disable=missing-docstring return array_ops.reshape(a, original_shape) -@dispatch.add_dispatch_support @np_utils.np_doc('rot90') def rot90(m, k=1, axes=(0, 1)): # pylint: disable=missing-docstring m_rank = array_ops.rank(m) @@ -1362,7 +1324,6 @@ def rot90(m, k=1, axes=(0, 1)): # pylint: disable=missing-docstring return flip(transpose(m, perm), ax2) -@dispatch.add_dispatch_support @np_utils.np_doc('vander') def vander(x, N=None, increasing=False): # pylint: disable=missing-docstring,invalid-name x = asarray(x) @@ -1521,19 +1482,16 @@ def take_along_axis(arr, indices, axis): # pylint: disable=missing-docstring # pylint: disable=redefined-builtin,undefined-variable -@dispatch.add_dispatch_support @np_utils.np_doc('max', link=np_utils.AliasOf('amax')) def max(a, axis=None, keepdims=None): return amax(a, axis=axis, keepdims=keepdims) -@dispatch.add_dispatch_support @np_utils.np_doc('min', link=np_utils.AliasOf('amin')) def min(a, axis=None, keepdims=None): return amin(a, axis=axis, keepdims=keepdims) -@dispatch.add_dispatch_support @np_utils.np_doc('round', link=np_utils.AliasOf('around')) def round(a, decimals=0): return around(a, decimals=decimals) diff --git a/tensorflow/python/ops/numpy_ops/np_math_ops.py b/tensorflow/python/ops/numpy_ops/np_math_ops.py index 8d8dae2a69fbab..f93cc1c708185b 100644 --- a/tensorflow/python/ops/numpy_ops/np_math_ops.py +++ b/tensorflow/python/ops/numpy_ops/np_math_ops.py @@ -41,7 +41,6 @@ from tensorflow.python.ops.numpy_ops import np_dtypes from tensorflow.python.ops.numpy_ops import np_export from tensorflow.python.ops.numpy_ops import np_utils -from tensorflow.python.util import dispatch pi = np_export.np_export_constant(__name__, 'pi', np.pi) @@ -570,7 +569,6 @@ def bitwise_xor(x1, x2): return _bitwise_binary_op(bitwise_ops.bitwise_xor, x1, x2) -@dispatch.add_dispatch_support @np_utils.np_doc('bitwise_not', link=np_utils.AliasOf('invert')) def bitwise_not(x): @@ -603,73 +601,61 @@ def _scalar(tf_fn, x, promote_to_float=False): return tf_fn(x) -@dispatch.add_dispatch_support @np_utils.np_doc('log') def log(x): return _scalar(math_ops.log, x, True) -@dispatch.add_dispatch_support @np_utils.np_doc('exp') def exp(x): return _scalar(math_ops.exp, x, True) -@dispatch.add_dispatch_support @np_utils.np_doc('sqrt') def sqrt(x): return _scalar(math_ops.sqrt, x, True) -@dispatch.add_dispatch_support @np_utils.np_doc('abs', link=np_utils.AliasOf('absolute')) def abs(x): # pylint: disable=redefined-builtin return _scalar(math_ops.abs, x) -@dispatch.add_dispatch_support @np_utils.np_doc('absolute') def absolute(x): return abs(x) -@dispatch.add_dispatch_support @np_utils.np_doc('fabs') def fabs(x): return abs(x) -@dispatch.add_dispatch_support @np_utils.np_doc('ceil') def ceil(x): return _scalar(math_ops.ceil, x, True) -@dispatch.add_dispatch_support @np_utils.np_doc('floor') def floor(x): return _scalar(math_ops.floor, x, True) -@dispatch.add_dispatch_support @np_utils.np_doc('conj') def conj(x): return _scalar(math_ops.conj, x) -@dispatch.add_dispatch_support @np_utils.np_doc('negative') def negative(x): return _scalar(math_ops.negative, x) -@dispatch.add_dispatch_support @np_utils.np_doc('reciprocal') def reciprocal(x): return _scalar(math_ops.reciprocal, x) -@dispatch.add_dispatch_support @np_utils.np_doc('signbit') def signbit(x): @@ -681,79 +667,66 @@ def f(x): return _scalar(f, x) -@dispatch.add_dispatch_support @np_utils.np_doc('sin') def sin(x): return _scalar(math_ops.sin, x, True) -@dispatch.add_dispatch_support @np_utils.np_doc('cos') def cos(x): return _scalar(math_ops.cos, x, True) -@dispatch.add_dispatch_support @np_utils.np_doc('tan') def tan(x): return _scalar(math_ops.tan, x, True) -@dispatch.add_dispatch_support @np_utils.np_doc('sinh') def sinh(x): return _scalar(math_ops.sinh, x, True) -@dispatch.add_dispatch_support @np_utils.np_doc('cosh') def cosh(x): return _scalar(math_ops.cosh, x, True) -@dispatch.add_dispatch_support @np_utils.np_doc('tanh') def tanh(x): return _scalar(math_ops.tanh, x, True) -@dispatch.add_dispatch_support @np_utils.np_doc('arcsin') def arcsin(x): return _scalar(math_ops.asin, x, True) -@dispatch.add_dispatch_support @np_utils.np_doc('arccos') def arccos(x): return _scalar(math_ops.acos, x, True) -@dispatch.add_dispatch_support @np_utils.np_doc('arctan') def arctan(x): return _scalar(math_ops.atan, x, True) -@dispatch.add_dispatch_support @np_utils.np_doc('arcsinh') def arcsinh(x): return _scalar(math_ops.asinh, x, True) -@dispatch.add_dispatch_support @np_utils.np_doc('arccosh') def arccosh(x): return _scalar(math_ops.acosh, x, True) -@dispatch.add_dispatch_support @np_utils.np_doc('arctanh') def arctanh(x): return _scalar(math_ops.atanh, x, True) -@dispatch.add_dispatch_support @np_utils.np_doc('deg2rad') def deg2rad(x): @@ -763,7 +736,6 @@ def f(x): return _scalar(f, x, True) -@dispatch.add_dispatch_support @np_utils.np_doc('rad2deg') def rad2deg(x): return x * (180.0 / np.pi) @@ -774,7 +746,6 @@ def rad2deg(x): ] -@dispatch.add_dispatch_support @np_utils.np_doc('angle') def angle(z, deg=False): # pylint: disable=missing-function-docstring @@ -791,7 +762,6 @@ def f(x): return y -@dispatch.add_dispatch_support @np_utils.np_doc('cbrt') def cbrt(x): @@ -803,13 +773,11 @@ def f(x): return _scalar(f, x, True) -@dispatch.add_dispatch_support @np_utils.np_doc('conjugate', link=np_utils.AliasOf('conj')) def conjugate(x): return _scalar(math_ops.conj, x) -@dispatch.add_dispatch_support @np_utils.np_doc('exp2') def exp2(x): @@ -819,13 +787,11 @@ def f(x): return _scalar(f, x, True) -@dispatch.add_dispatch_support @np_utils.np_doc('expm1') def expm1(x): return _scalar(math_ops.expm1, x, True) -@dispatch.add_dispatch_support @np_utils.np_doc('fix') def fix(x): @@ -881,7 +847,6 @@ def nan_reduction(a, axis=None, dtype=None, keepdims=False): nanprod = _make_nan_reduction('nanprod', np_array_ops.prod, 1) -@dispatch.add_dispatch_support @np_utils.np_doc('nanmean') def nanmean(a, axis=None, dtype=None, keepdims=None): # pylint: disable=missing-docstring a = np_array_ops.array(a) @@ -922,31 +887,26 @@ def isposinf(x): return False -@dispatch.add_dispatch_support @np_utils.np_doc('log2') def log2(x): return log(x) / np.log(2) -@dispatch.add_dispatch_support @np_utils.np_doc('log10') def log10(x): return log(x) / np.log(10) -@dispatch.add_dispatch_support @np_utils.np_doc('log1p') def log1p(x): return _scalar(math_ops.log1p, x, True) -@dispatch.add_dispatch_support @np_utils.np_doc('positive') def positive(x): return _scalar(lambda x: x, x) -@dispatch.add_dispatch_support @np_utils.np_doc('sinc') def sinc(x): @@ -958,13 +918,11 @@ def f(x): return _scalar(f, x, True) -@dispatch.add_dispatch_support @np_utils.np_doc('square') def square(x): return _scalar(math_ops.square, x) -@dispatch.add_dispatch_support @np_utils.np_doc('diff') def diff(a, n=1, axis=-1): # pylint: disable=missing-function-docstring @@ -1246,7 +1204,6 @@ def _argsort(a, axis, stable): return np_array_ops.array(tf_ans, dtype=np.intp) -@dispatch.add_dispatch_support @np_utils.np_doc('sort') def sort(a, axis=-1, kind='quicksort', order=None): # pylint: disable=missing-docstring if kind != 'quicksort': @@ -1293,7 +1250,6 @@ def append(arr, values, axis=None): return concatenate([arr, values], axis=axis) -@dispatch.add_dispatch_support @np_utils.np_doc('average') def average(a, axis=None, weights=None, returned=False): # pylint: disable=missing-docstring if axis is not None and not isinstance(axis, int): @@ -1356,7 +1312,6 @@ def rank_not_equal_case(): return avg -@dispatch.add_dispatch_support @np_utils.np_doc('trace') def trace(a, offset=0, axis1=0, axis2=1, dtype=None): # pylint: disable=missing-docstring if dtype: diff --git a/tensorflow/python/ops/weak_tensor_math_ops_test.py b/tensorflow/python/ops/weak_tensor_math_ops_test.py new file mode 100644 index 00000000000000..82dda558c67c8f --- /dev/null +++ b/tensorflow/python/ops/weak_tensor_math_ops_test.py @@ -0,0 +1,458 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow.ops.math_ops on WeakTensor.""" +from absl.testing import parameterized +import numpy as np + +from tensorflow.core.framework import full_type_pb2 +from tensorflow.python import tf2 +from tensorflow.python.eager import backprop +from tensorflow.python.eager import context +from tensorflow.python.eager import def_function +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor +from tensorflow.python.framework import test_util +from tensorflow.python.framework.weak_tensor import WeakTensor +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import tensor_array_ops +from tensorflow.python.ops import weak_tensor_ops # pylint: disable=unused-import +from tensorflow.python.ops import weak_tensor_test_util +from tensorflow.python.ops.ragged import ragged_factory_ops +from tensorflow.python.platform import googletest + +_convert_to_input_type = weak_tensor_test_util.convert_to_input_type +_get_weak_tensor = weak_tensor_test_util.get_weak_tensor + + +@test_util.run_all_in_graph_and_eager_modes +class ReduceTest(test_util.TensorFlowTestCase, parameterized.TestCase): + # Test unary ops with optional dtype arg. + + @parameterized.parameters( + ("WeakTensor", WeakTensor), + ("Python", WeakTensor), + ("NumPy", tensor.Tensor), + ("Tensor", tensor.Tensor), + ) + def testReduceAllDims(self, input_type, result_type): + test_input = _convert_to_input_type( + [[1, 2, 3], [4, 5, 6]], input_type, np.int32 + ) + with test_util.device(use_gpu=True): + res = math_ops.reduce_sum(test_input) + self.assertIsInstance(res, result_type) + self.assertEqual(self.evaluate(res), 21) + + def testReduceExtendType(self): + test_in = np.random.randn(1000, 1000).astype(np.float32) + in_f32 = _get_weak_tensor(test_in, dtypes.float32) + in_bfl6 = math_ops.cast(test_in, dtypes.bfloat16) + + out_f32 = self.evaluate(math_ops.reduce_sum(in_f32)) + out_bf16 = self.evaluate(math_ops.reduce_sum(in_bfl6)) + expected = math_ops.cast(out_f32, dtypes.bfloat16) + + self.assertAllClose(out_bf16, expected, 1e-3) + + def testCountNonzero(self): + # simple case + x = _get_weak_tensor([[0, -2, 0], [4, 0, 0]], dtypes.int32) + self.assertEqual(self.evaluate(math_ops.count_nonzero(x)), 2) + + # boolean input + x = math_ops.not_equal(x, 0) + self.assertEqual(self.evaluate(math_ops.count_nonzero(x)), 2) + + # would overflow if int8 would be used for internal calculations + x = 2 * np.ones(512, dtype=np.int8) + self.assertEqual(self.evaluate(math_ops.count_nonzero(x)), 512) + + @parameterized.parameters( + ("WeakTensor", WeakTensor), + ("Python", WeakTensor), + ("NumPy", tensor.Tensor), + ("Tensor", tensor.Tensor), + ) + def testReduceExplicitAxes(self, input_type, result_type): + x = _convert_to_input_type([[1, 2, 3], [4, 5, 6]], input_type, np.int32) + with test_util.device(use_gpu=True): + for axis in (0, -2): + res = math_ops.reduce_sum(x, axis=axis) + self.assertIsInstance(res, result_type) + self.assertAllEqual(res, [5, 7, 9]) + for axis in (1, -1): + res = math_ops.reduce_sum(x, axis=axis) + self.assertIsInstance(res, result_type) + self.assertAllEqual(res, [6, 15]) + for axis in (None, (0, 1), (1, 0), (-1, 0), (0, -1), (-2, 1), (1, -2), + (-1, -2), (-2, -1)): + res = math_ops.reduce_sum(x, axis=axis) + self.assertIsInstance(res, result_type) + self.assertEqual(self.evaluate(res), 21) + + def testReduceInvalidAxis(self): + if context.executing_eagerly(): + # The shape check is in run a graph construction time. In eager mode, + # it misses the check, magically return result given wrong shape. + return + x = _get_weak_tensor([[1, 2, 3], [4, 5, 6]], dtype=np.int32) + axis = np.array([[0], [1]]) + with self.assertRaisesRegex(ValueError, "must be at most rank 1"): + math_ops.reduce_sum(x, axis) + + def testReduceVar(self): + x = _get_weak_tensor([[0, 0, 0], [0, 0, 0]], dtype=dtypes.float32) + self.assertAllClose(self.evaluate(math_ops.reduce_variance(x)), 0) + self.assertAllClose( + self.evaluate(math_ops.reduce_variance(x, axis=0)), [0, 0, 0]) + + x = _get_weak_tensor([[1, 2, 1, 1], [1, 1, 0, 1]]) + with self.assertRaisesRegex(TypeError, "must be either real or complex"): + math_ops.reduce_variance(x) + + x = _get_weak_tensor([[1.0, 2.0, 1.0, 1.0], [1.0, 1.0, 0.0, 1.0]]) + self.assertEqual(self.evaluate(math_ops.reduce_variance(x)), 0.25) + x_np = np.array([[1, 2, 1, 1], [1, 1, 0, 1]], "float32") + self.assertEqual(np.var(x_np), 0.25) + self.assertEqual(self.evaluate(math_ops.reduce_variance(x_np)), 0.25) + + x = ragged_factory_ops.constant([[5., 1., 4., 1.], [], [5., 9., 2.], [5.], + []]) + self.assertAllClose(math_ops.reduce_variance(x, axis=0), [0., 16., 1., 0.]) + + def testReduceVarComplex(self): + # Ensure that complex values are handled to be consistent with numpy + complex_ys = [([0 - 1j, 0 + 1j], dtypes.float64), + (np.array([0 - 1j, 0 + 1j], "complex64"), dtypes.float32), + (np.array([0 - 1j, 0 + 1j], "complex128"), dtypes.float64)] + for y, dtype in complex_ys: + y_result = math_ops.reduce_variance(y) + self.assertEqual(np.var(y), 1.0) + self.assertEqual(self.evaluate(y_result), 1.0) + self.assertEqual(y_result.dtype, dtype) + + def testReduceStd(self): + x = _get_weak_tensor([[0, 0, 0], [0, 0, 0]], dtypes.float32) + self.assertAllClose(self.evaluate(math_ops.reduce_std(x)), 0) + self.assertAllClose( + self.evaluate(math_ops.reduce_std(x, axis=0)), [0, 0, 0]) + + x = _get_weak_tensor([[1, 2, 1, 1], [1, 1, 0, 1]]) + with self.assertRaisesRegex(TypeError, "must be either real or complex"): + math_ops.reduce_std(x) + + x = [[1., 2., 1., 1.], [1., 1., 0., 1.]] + res = math_ops.reduce_std(x) + self.assertEqual(self.evaluate(res), 0.5) + self.assertIsInstance(res, WeakTensor) + x_np = np.array(x) + self.assertEqual(np.std(x_np), 0.5) + self.assertEqual(self.evaluate(math_ops.reduce_std(x_np)), 0.5) + self.assertIsInstance(math_ops.reduce_std(x_np), tensor.Tensor) + + x = ragged_factory_ops.constant([[5., 1., 4., 1.], [], [5., 9., 2.], [5.], + []]) + self.assertAllClose(math_ops.reduce_std(x, axis=0), [0., 4., 1., 0.]) + + def testReduceStdComplex(self): + # Ensure that complex values are handled to be consistent with numpy + complex_ys = [([0 - 1j, 0 + 1j], dtypes.float64), + (np.array([0 - 1j, 0 + 1j], "complex64"), dtypes.float32), + (np.array([0 - 1j, 0 + 1j], "complex128"), dtypes.float64)] + for y, dtype in complex_ys: + y_result = math_ops.reduce_std(y) + self.assertEqual(np.std(y), 1.0) + self.assertEqual(self.evaluate(y_result), 1.0) + self.assertEqual(y_result.dtype, dtype) + + +@test_util.run_all_in_graph_and_eager_modes +class LogSumExpTest(test_util.TensorFlowTestCase): + + def testReduceLogSumExp(self): + for dtype in [np.float16, np.float32, np.double]: + x_np = np.random.rand(5, 5).astype(dtype) + with test_util.use_gpu(): + y_tf_np = math_ops.reduce_logsumexp(x_np) + y_np = np.log(np.sum(np.exp(x_np))) + self.assertAllClose(y_tf_np, y_np) + + def testReductionIndices(self): + for dtype in [np.float16, np.float32, np.double]: + x_np = np.random.rand(5, 5).astype(dtype) + with test_util.use_gpu(): + y_tf = math_ops.reduce_logsumexp(x_np, axis=[0]) + y_np = np.log(np.sum(np.exp(x_np), axis=0)) + self.assertShapeEqual(y_np, y_tf) + y_tf_np = self.evaluate(y_tf) + self.assertAllClose(y_tf_np, y_np) + + def testReductionIndices2(self): + for dtype in [np.float16, np.float32, np.double]: + x_np = np.random.rand(5, 5).astype(dtype) + with test_util.use_gpu(): + y_tf = math_ops.reduce_logsumexp(x_np, axis=0) + y_np = np.log(np.sum(np.exp(x_np), axis=0)) + self.assertShapeEqual(y_np, y_tf) + y_tf_np = self.evaluate(y_tf) + self.assertAllClose(y_tf_np, y_np) + + def testKeepDims(self): + for dtype in [np.float16, np.float32, np.double]: + x_np = np.random.rand(5, 5).astype(dtype) + with test_util.use_gpu(): + y_tf_np = math_ops.reduce_logsumexp(x_np, keepdims=True) + self.assertEqual(y_tf_np.shape.rank, x_np.ndim) + y_np = np.log(np.sum(np.exp(x_np), keepdims=True)) + self.assertAllClose(y_tf_np, y_np) + + def testOverflow(self): + x = [1000, 1001, 1002, 1003] + for dtype in [np.float32, np.double]: + x_np = np.array(x, dtype=dtype) + max_np = np.max(x_np) + with self.assertRaisesRegex(RuntimeWarning, + "overflow encountered in exp"): + out = np.log(np.sum(np.exp(x_np))) + if out == np.inf: + raise RuntimeWarning("overflow encountered in exp") + + with test_util.use_gpu(): + x_tf = _get_weak_tensor(x_np, shape=x_np.shape) + y_tf_np = math_ops.reduce_logsumexp(x_tf) + y_np = np.log(np.sum(np.exp(x_np - max_np))) + max_np + self.assertAllClose(y_tf_np, y_np) + + def testUnderflow(self): + x = [-1000, -1001, -1002, -1003] + for dtype in [np.float32, np.double]: + x_np = np.array(x, dtype=dtype) + max_np = np.max(x_np) + with self.assertRaisesRegex(RuntimeWarning, + "divide by zero encountered in log"): + out = np.log(np.sum(np.exp(x_np))) + if out == -np.inf: + raise RuntimeWarning("divide by zero encountered in log") + + with test_util.use_gpu(): + x_tf = _get_weak_tensor(x_np, shape=x_np.shape) + y_tf_np = math_ops.reduce_logsumexp(x_tf) + y_np = np.log(np.sum(np.exp(x_np - max_np))) + max_np + self.assertAllClose(y_tf_np, y_np) + + def testInfinity(self): + with test_util.use_gpu(): + res = math_ops.reduce_logsumexp(-np.inf) + self.assertEqual(-np.inf, self.evaluate(res)) + + +@test_util.run_all_in_graph_and_eager_modes +class RoundTest(test_util.TensorFlowTestCase): + + def testRounding(self): + x = np.arange(-5.0, 5.0, .25) + for dtype in [np.float32, np.double, np.int32]: + x_np = np.array(x, dtype=dtype) + with test_util.device(use_gpu=True): + x_tf = _get_weak_tensor(x_np, shape=x_np.shape) + y_tf = math_ops.round(x_tf) + y_tf_np = self.evaluate(y_tf) + y_np = np.round(x_np) + self.assertAllClose(y_tf_np, y_np, atol=1e-2) + + +class SignTest(test_util.TensorFlowTestCase): + + def test_complex_sign_gradient(self): + with context.eager_mode(): + x = math_ops.complex(1., 1.) + with backprop.GradientTape() as t: + t.watch(x) + y = math_ops.sign(x) + self.assertAllClose( + t.gradient(y, x), math_ops.complex(0.353553, -0.353553)) + + +@test_util.run_all_in_graph_and_eager_modes +class ReciprocalNoNanTest(test_util.TensorFlowTestCase): + + allowed_dtypes = [dtypes.float32, dtypes.float64, dtypes.complex128] + + def testBasic(self): + for dtype in self.allowed_dtypes: + x = _get_weak_tensor([1.0, 2.0, 0.0, 4.0], dtype=dtype) + + y = math_ops.reciprocal_no_nan(x) + + target = _get_weak_tensor([1.0, 0.5, 0.0, 0.25], dtype=dtype) + + self.assertAllEqual(y, target) + self.assertEqual(y.dtype.base_dtype, target.dtype.base_dtype) + + def testInverse(self): + for dtype in self.allowed_dtypes: + x = np.random.choice([0, 1, 2, 4, 5], size=(5, 5, 5)) + x = _get_weak_tensor(x, dtype=dtype) + + y = math_ops.reciprocal_no_nan(math_ops.reciprocal_no_nan(x)) + + self.assertAllClose(y, x) + self.assertEqual(y.dtype.base_dtype, x.dtype.base_dtype) + + +class EqualityTest(test_util.TensorFlowTestCase, parameterized.TestCase): + + @test_util.run_all_in_graph_and_eager_modes + def testEqualityNone(self): + x = _get_weak_tensor([1.0, 2.0, 0.0, 4.0], dtype=dtypes.float32) + self.assertNotEqual(x, None) + self.assertNotEqual(None, x) + self.assertFalse(math_ops.tensor_equals(x, None)) + self.assertTrue(math_ops.tensor_not_equals(x, None)) + + @parameterized.named_parameters( + (f"-is_equals={is_equals}-float_literal_type={type(float_literal)}" # pylint: disable=g-complex-comprehension + f"-float_literal={float_literal}", is_equals, float_literal) + for float_literal in [4.6, np.float32(4.6), 4.4, np.float32(4.4)] + for is_equals in [True, False]) + def testEqualityNoDowncast(self, is_equals, float_literal): + if (tf2.enabled() and isinstance(float_literal, np.float32) or + not tf2.enabled() and isinstance(float_literal, float)): + # TODO(b/199262800): Remove this skip + self.skipTest("There is a bug in type promotion.") + if is_equals: + op = math_ops.tensor_equals + else: + op = math_ops.tensor_not_equals + x = _get_weak_tensor(4) + try: + result = op(x, float_literal) + if isinstance(result, tensor.Tensor): + result = self.evaluate(result) + except TypeError: + # Throwing a TypeError is OK + return + self.assertEqual(result, not is_equals) + + +@test_util.run_all_in_graph_and_eager_modes +class ErfcinvTest(test_util.TensorFlowTestCase): + + def testErfcinv(self): + values = _get_weak_tensor( + np.random.uniform(0.1, 1.9, size=int(1e4)).astype(np.float32) + ) + approx_id = math_ops.erfc(math_ops.erfcinv(values)) + self.assertAllClose(values, self.evaluate(approx_id)) + + +@test_util.run_all_in_graph_and_eager_modes +class ArgMaxMinTest(test_util.TensorFlowTestCase): + + def _generateRandomWeakTensor(self, dtype, shape): + if dtype.is_integer: + array = np.random.default_rng().integers( + low=dtype.min, high=dtype.max, size=shape, endpoint=True) + return _get_weak_tensor(array, dtype=dtype) + else: + array = np.random.default_rng().uniform(low=-1.0, high=1.0, size=shape) + return _get_weak_tensor(array, dtype=dtype) + + def _getValidDtypes(self): + return (dtypes.float32, dtypes.float64, dtypes.int32, dtypes.int64) + + def testArgMax(self): + shape = (24, 8) + for dtype in self._getValidDtypes(): + tf_values = self._generateRandomWeakTensor(dtype, shape) + np_values = self.evaluate(tf_values) + for axis in range(0, len(shape)): + np_max = np.argmax(np_values, axis=axis) + tf_max = math_ops.argmax(tf_values, axis=axis) + self.assertAllEqual(tf_max, np_max) + + def testArgMaxReturnsFirstOccurence(self): + for dtype in self._getValidDtypes(): + values = _get_weak_tensor( + [[10, 11, 15, 15, 10], [12, 12, 10, 10, 12]], dtype=dtype + ) + self.assertAllEqual( + math_ops.argmax(values, axis=1), + np.argmax(self.evaluate(values), axis=1)) + + # Long tensor to ensure works with multithreading/GPU + values = array_ops.zeros(shape=(193681,), dtype=dtype) + self.assertAllEqual(math_ops.argmax(values), 0) + + def testArgMaxUint16(self): + shape = (24, 8) + for dtype in self._getValidDtypes(): + tf_values = self._generateRandomWeakTensor(dtype, shape) + np_values = self.evaluate(tf_values) + for axis in range(0, len(shape)): + np_max = np.argmax(np_values, axis=axis) + tf_max = math_ops.argmax( + tf_values, axis=axis, output_type=dtypes.uint16) + self.assertAllEqual(tf_max, np_max) + + def testArgMin(self): + shape = (24, 8) + for dtype in self._getValidDtypes(): + tf_values = self._generateRandomWeakTensor(dtype, shape) + np_values = self.evaluate(tf_values) + for axis in range(0, len(shape)): + np_min = np.argmin(np_values, axis=axis) + tf_min = math_ops.argmin(tf_values, axis=axis) + self.assertAllEqual(tf_min, np_min) + + def testArgMinReturnsFirstOccurence(self): + for dtype in self._getValidDtypes(): + values = _get_weak_tensor( + [[10, 11, 15, 15, 10], [12, 12, 10, 10, 12]], dtype=dtype + ) + self.assertAllEqual( + math_ops.argmin(values, axis=1), + np.argmin(self.evaluate(values), axis=1)) + + # Long tensor to ensure works with multithreading/GPU + values = array_ops.zeros(shape=(193681,), dtype=dtype) + self.assertAllEqual(math_ops.argmin(values), 0) + + +class CastTest(test_util.TensorFlowTestCase): + + def testCastWithFullType(self): + + @def_function.function + def test_fn(): + ta = tensor_array_ops.TensorArray(dtypes.int32, size=1) + h = math_ops.cast(ta.flow, dtypes.variant) + + t = full_type_pb2.FullTypeDef( + type_id=full_type_pb2.TFT_PRODUCT, + args=[full_type_pb2.FullTypeDef(type_id=full_type_pb2.TFT_ARRAY)]) + h.op.experimental_set_type(t) + + ta = tensor_array_ops.TensorArray(dtypes.int32, flow=h) + ta = ta.write(0, _get_weak_tensor(1)) + return ta.stack() + + self.assertAllEqual(self.evaluate(test_fn()), [1]) + +if __name__ == "__main__": + ops.set_dtype_conversion_mode("all") + googletest.main() diff --git a/tensorflow/python/ops/weak_tensor_ops.py b/tensorflow/python/ops/weak_tensor_ops.py index 627768a45a0fff..083dcc8ad580d0 100644 --- a/tensorflow/python/ops/weak_tensor_ops.py +++ b/tensorflow/python/ops/weak_tensor_ops.py @@ -16,44 +16,462 @@ import inspect +from tensorflow.python.framework import flexible_dtypes +from tensorflow.python.framework import ops from tensorflow.python.framework import tensor -from tensorflow.python.framework.weak_tensor import WeakTensor -from tensorflow.python.ops import weak_tensor_ops_list +from tensorflow.python.framework import weak_tensor +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import clip_ops +from tensorflow.python.ops import gen_array_ops +from tensorflow.python.ops import gen_bitwise_ops +from tensorflow.python.ops import gen_math_ops +from tensorflow.python.ops import gen_nn_ops +from tensorflow.python.ops import image_ops_impl +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_impl +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import special_math_ops +from tensorflow.python.ops.numpy_ops import np_array_ops +from tensorflow.python.ops.numpy_ops import np_math_ops +from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import dispatch +from tensorflow.python.util import tf_decorator -# This file must depend on math_ops so that e.g. `__add__` is -# added to the Tensor class. -for operator in tensor.Tensor.OVERLOADABLE_OPERATORS: - tensor_oper = getattr(tensor.Tensor, operator) - setattr(WeakTensor, operator, tensor_oper) - # List of unary ops that have support for WeakTensor. -_TF_UNARY_APIS = weak_tensor_ops_list.ALL_UNARY_OPS +_TF_UNARY_APIS = [] +_TF_BINARY_APIS = [] + + +# ============================================================================== +# Utils to handle WeakTensor inputs and outputs. +# ============================================================================== +# pylint: disable=g-doc-args,g-doc-return-or-yield +def _convert_or_cast(x, dtype, name): + """Converts/casts the input x to dtype.""" + # TODO(b/290216343): remove this branch once we fix the precision loss bug in + # tf.cast. + if isinstance(x, (int, float, complex)): + return ops.convert_to_tensor(x, dtype=dtype, name=name) + else: + return math_ops.cast(x, dtype=dtype, name=name) + + +def weak_tensor_unary_op_wrapper(op): + """Infers input type and adds WeakTensor support to unary ops. + + This wrapper infers input type according to the auto dtype conversion + semantics - Tensor and NumPy inputs as Tensor of corresponding dtype and + WeakTensor and python inputs as WeakTensor of corresponding dtype. If the + inferred input dtype is "weak" and the op doesn't specify a return dtype, + returns WeakTensor. + """ + signature = inspect.signature(op) + arg_names = iter(signature.parameters.keys()) + x_arg_name = next(arg_names) + + def wrapper(*args, **kwargs): + if not ops.is_auto_dtype_conversion_enabled(): + return op(*args, **kwargs) + bound_arguments = signature.bind(*args, **kwargs) + bound_arguments.apply_defaults() + bound_kwargs = bound_arguments.arguments + x = bound_kwargs[x_arg_name] + # No input/output handling needed when input is a Tensor because Tensor + # input in unary op always outputs a Tensor. + if isinstance(x, tensor.Tensor): + return op(**bound_kwargs) + # Infer input type and determine the result promotion type. + try: + target_type, is_weak = flexible_dtypes.result_type(x) + # NotImplementedError is thrown from result_type when x is an + # unsupported input type (e.g. CompositeTensor). + except NotImplementedError: + logging.warning( + "The new dtype semantics do not support this input dtype. Falling" + " back to old semantics." + ) + return op(**bound_kwargs) + bound_kwargs[x_arg_name] = _convert_or_cast(x, target_type, "x") + # Only return WeakTensor when dtype is NOT specified. + if bound_kwargs.get("dtype", None) is not None: + is_weak = False + return weak_tensor.maybe_convert_to_weak_tensor(op(**bound_kwargs), is_weak) + + wrapper = tf_decorator.make_decorator(op, wrapper) + + # Update dispatch dictionary to store monkey-patched op references. + _update_weak_tensor_patched_ops_in_dispatch_dict(wrapper) + # Add the updated function to list of unary ops with WeakTensor support. + _TF_UNARY_APIS.append(wrapper) + return wrapper + + +def weak_tensor_binary_op_wrapper(op): + """Determines result promotion type and adds WeakTensor support to binary ops. + + This wrapper first infers dtype of any Tensor, WeakTensor, python/numpy + inputs. Then, both inputs are promoted to the correct promotion result dtype. + If the result promotion dtype is "weak", returns WeakTensor. + """ -def register_unary_weak_tensor_dispatcher(op): - """Add dispatch for WeakTensor inputs.""" signature = inspect.signature(op) - weak_tensor_arg_name = next(iter(signature.parameters.keys())) + arg_names = iter(signature.parameters.keys()) + x_arg_name = next(arg_names) + y_arg_name = next(arg_names) - @dispatch.dispatch_for_api(op, {weak_tensor_arg_name: WeakTensor}) def wrapper(*args, **kwargs): + if not ops.is_auto_dtype_conversion_enabled(): + return op(*args, **kwargs) bound_arguments = signature.bind(*args, **kwargs) bound_arguments.apply_defaults() bound_kwargs = bound_arguments.arguments - bound_kwargs[weak_tensor_arg_name] = bound_kwargs[ - weak_tensor_arg_name - ].to_tensor() - - # Only return WeakTensor if there is no dtype specified. - if bound_kwargs.get("dtype", None) is None: - return WeakTensor.from_tensor((op(**bound_kwargs))) - else: + x = bound_kwargs[x_arg_name] + y = bound_kwargs[y_arg_name] + # Infer input type and determine the result promotion type. + try: + target_type, is_weak = flexible_dtypes.result_type(x, y) + # NotImplementedError is thrown from result_type when x or y is an + # unsupported input type (e.g. CompositeTensor). + except NotImplementedError: + logging.warning( + "The new dtype semantics do not support this input dtype. Falling" + " back to old semantics." + ) return op(**bound_kwargs) + bound_kwargs[x_arg_name] = _convert_or_cast(x, target_type, "x") + bound_kwargs[y_arg_name] = _convert_or_cast(y, target_type, "y") + return weak_tensor.maybe_convert_to_weak_tensor(op(**bound_kwargs), is_weak) + + wrapper = tf_decorator.make_decorator(op, wrapper) + + # Update dispatch dictionary to store monkey-patched op references. + _update_weak_tensor_patched_ops_in_dispatch_dict(wrapper) + + # Add the updated function to list of binary ops with WeakTensor support. + _TF_BINARY_APIS.append(wrapper) return wrapper -for tf_unary_api in _TF_UNARY_APIS: - register_unary_weak_tensor_dispatcher(tf_unary_api) +# TODO(b/290672237): Investigate if there is a more elegant solution. +def _update_weak_tensor_patched_ops_in_dispatch_dict(patched_op): + """Update dispatch dictionary to store WeakTensor patched op references. + + _TYPE_BASED_DISPATCH_SIGNATURES in dispatch.py stores mappings from op + reference to all the dispatchers it's registered with. We need to update + this dictionary to add a mapping from the patched-op reference to the + signature dictionary the unpatched-op reference is mapped to. This ensures + that dispatch can be reigstered and unregistered with monkey-patched ops. + """ + dispatch_dict = dispatch._TYPE_BASED_DISPATCH_SIGNATURES # pylint: disable=protected-access + unpatched_api = patched_op.__wrapped__ + if unpatched_api in dispatch_dict: + dispatch_dict[patched_op] = dispatch_dict[unpatched_api] + + +# ============================================================================== +# Monkey patching to add WeakTensor Support. +# ============================================================================== +# Elementwise unary ops +math_ops.abs = weak_tensor_unary_op_wrapper(math_ops.abs) +math_ops.softplus = weak_tensor_unary_op_wrapper(math_ops.softplus) +math_ops.sign = weak_tensor_unary_op_wrapper(math_ops.sign) +math_ops.real = weak_tensor_unary_op_wrapper(math_ops.real) +math_ops.imag = weak_tensor_unary_op_wrapper(math_ops.imag) +math_ops.angle = weak_tensor_unary_op_wrapper(math_ops.angle) +math_ops.round = weak_tensor_unary_op_wrapper(math_ops.round) +math_ops.sigmoid = weak_tensor_unary_op_wrapper(math_ops.sigmoid) +math_ops.log_sigmoid = weak_tensor_unary_op_wrapper(math_ops.log_sigmoid) +math_ops.conj = weak_tensor_unary_op_wrapper(math_ops.conj) +math_ops.reciprocal_no_nan = weak_tensor_unary_op_wrapper( + math_ops.reciprocal_no_nan +) +math_ops.erfinv = weak_tensor_unary_op_wrapper(math_ops.erfinv) +math_ops.ndtri = weak_tensor_unary_op_wrapper(math_ops.ndtri) +math_ops.erfcinv = weak_tensor_unary_op_wrapper(math_ops.erfcinv) +math_ops.ceil = weak_tensor_unary_op_wrapper(math_ops.ceil) +math_ops.sqrt = weak_tensor_unary_op_wrapper(math_ops.sqrt) +math_ops.exp = weak_tensor_unary_op_wrapper(math_ops.exp) +math_ops.rsqrt = weak_tensor_unary_op_wrapper(math_ops.rsqrt) +math_ops.acos = weak_tensor_unary_op_wrapper(math_ops.acos) +math_ops.floor = weak_tensor_unary_op_wrapper(math_ops.floor) +gen_bitwise_ops.invert = weak_tensor_unary_op_wrapper(gen_bitwise_ops.invert) +gen_math_ops.acosh = weak_tensor_unary_op_wrapper(gen_math_ops.acosh) +gen_math_ops.asin = weak_tensor_unary_op_wrapper(gen_math_ops.asin) +gen_math_ops.asinh = weak_tensor_unary_op_wrapper(gen_math_ops.asinh) +gen_math_ops.atan = weak_tensor_unary_op_wrapper(gen_math_ops.atan) +gen_math_ops.atanh = weak_tensor_unary_op_wrapper(gen_math_ops.atanh) +gen_math_ops.cos = weak_tensor_unary_op_wrapper(gen_math_ops.cos) +gen_math_ops.cosh = weak_tensor_unary_op_wrapper(gen_math_ops.cosh) +gen_math_ops.digamma = weak_tensor_unary_op_wrapper(gen_math_ops.digamma) +gen_math_ops.erf = weak_tensor_unary_op_wrapper(gen_math_ops.erf) +gen_math_ops.erfc = weak_tensor_unary_op_wrapper(gen_math_ops.erfc) +gen_math_ops.expm1 = weak_tensor_unary_op_wrapper(gen_math_ops.expm1) +gen_math_ops.lgamma = weak_tensor_unary_op_wrapper(gen_math_ops.lgamma) +gen_math_ops.log = weak_tensor_unary_op_wrapper(gen_math_ops.log) +gen_math_ops.log1p = weak_tensor_unary_op_wrapper(gen_math_ops.log1p) +gen_math_ops.neg = weak_tensor_unary_op_wrapper(gen_math_ops.neg) +gen_math_ops.reciprocal = weak_tensor_unary_op_wrapper(gen_math_ops.reciprocal) +gen_math_ops.rint = weak_tensor_unary_op_wrapper(gen_math_ops.rint) +gen_math_ops.sin = weak_tensor_unary_op_wrapper(gen_math_ops.sin) +gen_math_ops.sinh = weak_tensor_unary_op_wrapper(gen_math_ops.sinh) +gen_math_ops.square = weak_tensor_unary_op_wrapper(gen_math_ops.square) +gen_math_ops.tan = weak_tensor_unary_op_wrapper(gen_math_ops.tan) +gen_math_ops.tanh = weak_tensor_unary_op_wrapper(gen_math_ops.tanh) +array_ops.zeros_like = weak_tensor_unary_op_wrapper(array_ops.zeros_like) +array_ops.zeros_like_v2 = weak_tensor_unary_op_wrapper(array_ops.zeros_like_v2) +array_ops.ones_like = weak_tensor_unary_op_wrapper(array_ops.ones_like) +array_ops.ones_like_v2 = weak_tensor_unary_op_wrapper(array_ops.ones_like_v2) +gen_array_ops.check_numerics = weak_tensor_unary_op_wrapper( + gen_array_ops.check_numerics +) +nn_ops.relu6 = weak_tensor_unary_op_wrapper(nn_ops.relu6) +nn_ops.leaky_relu = weak_tensor_unary_op_wrapper(nn_ops.leaky_relu) +nn_ops.gelu = weak_tensor_unary_op_wrapper(nn_ops.gelu) +nn_ops.log_softmax = weak_tensor_unary_op_wrapper(nn_ops.log_softmax) +nn_ops.log_softmax_v2 = weak_tensor_unary_op_wrapper(nn_ops.log_softmax_v2) +nn_impl.swish = weak_tensor_unary_op_wrapper(nn_impl.swish) +gen_nn_ops.elu = weak_tensor_unary_op_wrapper(gen_nn_ops.elu) +gen_nn_ops.relu = weak_tensor_unary_op_wrapper(gen_nn_ops.relu) +gen_nn_ops.selu = weak_tensor_unary_op_wrapper(gen_nn_ops.selu) +gen_nn_ops.softsign = weak_tensor_unary_op_wrapper(gen_nn_ops.softsign) +image_ops_impl.random_brightness = weak_tensor_unary_op_wrapper( + image_ops_impl.random_brightness +) +image_ops_impl.stateless_random_brightness = weak_tensor_unary_op_wrapper( + image_ops_impl.stateless_random_brightness +) +image_ops_impl.adjust_brightness = weak_tensor_unary_op_wrapper( + image_ops_impl.adjust_brightness +) +image_ops_impl.adjust_gamma = weak_tensor_unary_op_wrapper( + image_ops_impl.adjust_gamma +) +clip_ops.clip_by_value = weak_tensor_unary_op_wrapper(clip_ops.clip_by_value) +special_math_ops.dawsn = weak_tensor_unary_op_wrapper(special_math_ops.dawsn) +special_math_ops.expint = weak_tensor_unary_op_wrapper(special_math_ops.expint) +special_math_ops.fresnel_cos = weak_tensor_unary_op_wrapper( + special_math_ops.fresnel_cos +) +special_math_ops.fresnel_sin = weak_tensor_unary_op_wrapper( + special_math_ops.fresnel_sin +) +special_math_ops.spence = weak_tensor_unary_op_wrapper(special_math_ops.spence) +special_math_ops.bessel_i0 = weak_tensor_unary_op_wrapper( + special_math_ops.bessel_i0 +) +special_math_ops.bessel_i0e = weak_tensor_unary_op_wrapper( + special_math_ops.bessel_i0e +) +special_math_ops.bessel_i1 = weak_tensor_unary_op_wrapper( + special_math_ops.bessel_i1 +) +special_math_ops.bessel_i1e = weak_tensor_unary_op_wrapper( + special_math_ops.bessel_i1e +) +special_math_ops.bessel_k0 = weak_tensor_unary_op_wrapper( + special_math_ops.bessel_k0 +) +special_math_ops.bessel_k0e = weak_tensor_unary_op_wrapper( + special_math_ops.bessel_k0e +) +special_math_ops.bessel_k1 = weak_tensor_unary_op_wrapper( + special_math_ops.bessel_k1 +) +special_math_ops.bessel_k1e = weak_tensor_unary_op_wrapper( + special_math_ops.bessel_k1e +) +special_math_ops.bessel_j0 = weak_tensor_unary_op_wrapper( + special_math_ops.bessel_j0 +) +special_math_ops.bessel_j1 = weak_tensor_unary_op_wrapper( + special_math_ops.bessel_j1 +) +special_math_ops.bessel_y0 = weak_tensor_unary_op_wrapper( + special_math_ops.bessel_y0 +) +special_math_ops.bessel_y1 = weak_tensor_unary_op_wrapper( + special_math_ops.bessel_y1 +) + +# TF Non-Elementwise Unary Ops +math_ops.reduce_euclidean_norm = weak_tensor_unary_op_wrapper( + math_ops.reduce_euclidean_norm +) +math_ops.reduce_logsumexp = weak_tensor_unary_op_wrapper( + math_ops.reduce_logsumexp +) +math_ops.reduce_max = weak_tensor_unary_op_wrapper(math_ops.reduce_max) +math_ops.reduce_max_v1 = weak_tensor_unary_op_wrapper(math_ops.reduce_max_v1) +math_ops.reduce_mean = weak_tensor_unary_op_wrapper(math_ops.reduce_mean) +math_ops.reduce_mean_v1 = weak_tensor_unary_op_wrapper(math_ops.reduce_mean_v1) +math_ops.reduce_min = weak_tensor_unary_op_wrapper(math_ops.reduce_min) +math_ops.reduce_min_v1 = weak_tensor_unary_op_wrapper(math_ops.reduce_min_v1) +math_ops.reduce_prod = weak_tensor_unary_op_wrapper(math_ops.reduce_prod) +math_ops.reduce_prod_v1 = weak_tensor_unary_op_wrapper(math_ops.reduce_prod_v1) +math_ops.reduce_std = weak_tensor_unary_op_wrapper(math_ops.reduce_std) +math_ops.reduce_sum = weak_tensor_unary_op_wrapper(math_ops.reduce_sum) +math_ops.reduce_sum_v1 = weak_tensor_unary_op_wrapper(math_ops.reduce_sum_v1) +math_ops.reduce_variance = weak_tensor_unary_op_wrapper( + math_ops.reduce_variance +) +math_ops.trace = weak_tensor_unary_op_wrapper(math_ops.trace) +array_ops.reshape = weak_tensor_unary_op_wrapper(array_ops.reshape) +array_ops.depth_to_space = weak_tensor_unary_op_wrapper( + array_ops.depth_to_space +) +array_ops.depth_to_space_v2 = weak_tensor_unary_op_wrapper( + array_ops.depth_to_space_v2 +) +array_ops.expand_dims = weak_tensor_unary_op_wrapper(array_ops.expand_dims) +array_ops.expand_dims_v2 = weak_tensor_unary_op_wrapper( + array_ops.expand_dims_v2 +) +array_ops.extract_image_patches = weak_tensor_unary_op_wrapper( + array_ops.extract_image_patches +) +array_ops.extract_image_patches_v2 = weak_tensor_unary_op_wrapper( + array_ops.extract_image_patches_v2 +) +array_ops.identity = weak_tensor_unary_op_wrapper(array_ops.identity) +array_ops.matrix_diag = weak_tensor_unary_op_wrapper(array_ops.matrix_diag) +array_ops.matrix_diag_part = weak_tensor_unary_op_wrapper( + array_ops.matrix_diag_part +) +array_ops.matrix_transpose = weak_tensor_unary_op_wrapper( + array_ops.matrix_transpose +) +array_ops.space_to_depth = weak_tensor_unary_op_wrapper( + array_ops.space_to_depth +) +array_ops.space_to_depth_v2 = weak_tensor_unary_op_wrapper( + array_ops.space_to_depth_v2 +) +array_ops.squeeze = weak_tensor_unary_op_wrapper(array_ops.squeeze) +array_ops.squeeze_v2 = weak_tensor_unary_op_wrapper(array_ops.squeeze_v2) +array_ops.stop_gradient = weak_tensor_unary_op_wrapper(array_ops.stop_gradient) +array_ops.tensor_diag_part = weak_tensor_unary_op_wrapper( + array_ops.tensor_diag_part +) +array_ops.transpose = weak_tensor_unary_op_wrapper(array_ops.transpose) +array_ops.transpose_v2 = weak_tensor_unary_op_wrapper(array_ops.transpose_v2) + +# TF NumPy Unary Ops +np_math_ops.abs = weak_tensor_unary_op_wrapper(np_math_ops.abs) +np_math_ops.absolute = weak_tensor_unary_op_wrapper(np_math_ops.absolute) +np_math_ops.angle = weak_tensor_unary_op_wrapper(np_math_ops.angle) +np_math_ops.arccos = weak_tensor_unary_op_wrapper(np_math_ops.arccos) +np_math_ops.arcsin = weak_tensor_unary_op_wrapper(np_math_ops.arcsin) +np_math_ops.arcsinh = weak_tensor_unary_op_wrapper(np_math_ops.arcsinh) +np_math_ops.arctan = weak_tensor_unary_op_wrapper(np_math_ops.arctan) +np_math_ops.arctanh = weak_tensor_unary_op_wrapper(np_math_ops.arctanh) +np_math_ops.bitwise_not = weak_tensor_unary_op_wrapper(np_math_ops.bitwise_not) +np_math_ops.cbrt = weak_tensor_unary_op_wrapper(np_math_ops.cbrt) +np_math_ops.ceil = weak_tensor_unary_op_wrapper(np_math_ops.ceil) +np_math_ops.conj = weak_tensor_unary_op_wrapper(np_math_ops.conj) +np_math_ops.conjugate = weak_tensor_unary_op_wrapper(np_math_ops.conjugate) +np_math_ops.cos = weak_tensor_unary_op_wrapper(np_math_ops.cos) +np_math_ops.cosh = weak_tensor_unary_op_wrapper(np_math_ops.cosh) +np_math_ops.deg2rad = weak_tensor_unary_op_wrapper(np_math_ops.deg2rad) +np_math_ops.exp = weak_tensor_unary_op_wrapper(np_math_ops.exp) +np_math_ops.exp2 = weak_tensor_unary_op_wrapper(np_math_ops.exp2) +np_math_ops.expm1 = weak_tensor_unary_op_wrapper(np_math_ops.expm1) +np_math_ops.fabs = weak_tensor_unary_op_wrapper(np_math_ops.fabs) +np_math_ops.fix = weak_tensor_unary_op_wrapper(np_math_ops.fix) +np_math_ops.floor = weak_tensor_unary_op_wrapper(np_math_ops.floor) +np_math_ops.log = weak_tensor_unary_op_wrapper(np_math_ops.log) +np_math_ops.negative = weak_tensor_unary_op_wrapper(np_math_ops.negative) +np_math_ops.rad2deg = weak_tensor_unary_op_wrapper(np_math_ops.rad2deg) +np_math_ops.reciprocal = weak_tensor_unary_op_wrapper(np_math_ops.reciprocal) +np_math_ops.sin = weak_tensor_unary_op_wrapper(np_math_ops.sin) +np_math_ops.sinh = weak_tensor_unary_op_wrapper(np_math_ops.sinh) +np_math_ops.sqrt = weak_tensor_unary_op_wrapper(np_math_ops.sqrt) +np_math_ops.tan = weak_tensor_unary_op_wrapper(np_math_ops.tan) +np_math_ops.tanh = weak_tensor_unary_op_wrapper(np_math_ops.tanh) +np_math_ops.nanmean = weak_tensor_unary_op_wrapper(np_math_ops.nanmean) +np_math_ops.log2 = weak_tensor_unary_op_wrapper(np_math_ops.log2) +np_math_ops.log10 = weak_tensor_unary_op_wrapper(np_math_ops.log10) +np_math_ops.log1p = weak_tensor_unary_op_wrapper(np_math_ops.log1p) +np_math_ops.positive = weak_tensor_unary_op_wrapper(np_math_ops.positive) +np_math_ops.sinc = weak_tensor_unary_op_wrapper(np_math_ops.sinc) +np_math_ops.square = weak_tensor_unary_op_wrapper(np_math_ops.square) +np_math_ops.diff = weak_tensor_unary_op_wrapper(np_math_ops.diff) +np_math_ops.sort = weak_tensor_unary_op_wrapper(np_math_ops.sort) +np_math_ops.average = weak_tensor_unary_op_wrapper(np_math_ops.average) +np_math_ops.trace = weak_tensor_unary_op_wrapper(np_math_ops.trace) +np_array_ops.amax = weak_tensor_unary_op_wrapper(np_array_ops.amax) +np_array_ops.amin = weak_tensor_unary_op_wrapper(np_array_ops.amin) +np_array_ops.around = weak_tensor_unary_op_wrapper(np_array_ops.around) +np_array_ops.arange = weak_tensor_unary_op_wrapper(np_array_ops.arange) +np_array_ops.array = weak_tensor_unary_op_wrapper(np_array_ops.array) +np_array_ops.asanyarray = weak_tensor_unary_op_wrapper(np_array_ops.asanyarray) +np_array_ops.asarray = weak_tensor_unary_op_wrapper(np_array_ops.asarray) +np_array_ops.ascontiguousarray = weak_tensor_unary_op_wrapper( + np_array_ops.ascontiguousarray +) +np_array_ops.copy = weak_tensor_unary_op_wrapper(np_array_ops.copy) +np_array_ops.cumprod = weak_tensor_unary_op_wrapper(np_array_ops.cumprod) +np_array_ops.cumsum = weak_tensor_unary_op_wrapper(np_array_ops.cumsum) +np_array_ops.diag = weak_tensor_unary_op_wrapper(np_array_ops.diag) +np_array_ops.diagflat = weak_tensor_unary_op_wrapper(np_array_ops.diagflat) +np_array_ops.diagonal = weak_tensor_unary_op_wrapper(np_array_ops.diagonal) +np_array_ops.empty_like = weak_tensor_unary_op_wrapper(np_array_ops.empty_like) +np_array_ops.expand_dims = weak_tensor_unary_op_wrapper( + np_array_ops.expand_dims +) +np_array_ops.flatten = weak_tensor_unary_op_wrapper(np_array_ops.flatten) +np_array_ops.flip = weak_tensor_unary_op_wrapper(np_array_ops.flip) +np_array_ops.fliplr = weak_tensor_unary_op_wrapper(np_array_ops.fliplr) +np_array_ops.flipud = weak_tensor_unary_op_wrapper(np_array_ops.flipud) +np_array_ops.full_like = weak_tensor_unary_op_wrapper(np_array_ops.full_like) +np_array_ops.imag = weak_tensor_unary_op_wrapper(np_array_ops.imag) +np_array_ops.max = weak_tensor_unary_op_wrapper(np_array_ops.max) +np_array_ops.mean = weak_tensor_unary_op_wrapper(np_array_ops.mean) +np_array_ops.min = weak_tensor_unary_op_wrapper(np_array_ops.min) +np_array_ops.moveaxis = weak_tensor_unary_op_wrapper(np_array_ops.moveaxis) +np_array_ops.ones_like = weak_tensor_unary_op_wrapper(np_array_ops.ones_like) +np_array_ops.prod = weak_tensor_unary_op_wrapper(np_array_ops.prod) +np_array_ops.ravel = weak_tensor_unary_op_wrapper(np_array_ops.ravel) +np_array_ops.real = weak_tensor_unary_op_wrapper(np_array_ops.real) +np_array_ops.reshape = weak_tensor_unary_op_wrapper(np_array_ops.reshape) +np_array_ops.rot90 = weak_tensor_unary_op_wrapper(np_array_ops.rot90) +np_array_ops.round = weak_tensor_unary_op_wrapper(np_array_ops.round) +np_array_ops.squeeze = weak_tensor_unary_op_wrapper(np_array_ops.squeeze) +np_array_ops.std = weak_tensor_unary_op_wrapper(np_array_ops.std) +np_array_ops.sum = weak_tensor_unary_op_wrapper(np_array_ops.sum) +np_array_ops.swapaxes = weak_tensor_unary_op_wrapper(np_array_ops.swapaxes) +np_array_ops.transpose = weak_tensor_unary_op_wrapper(np_array_ops.transpose) +np_array_ops.triu = weak_tensor_unary_op_wrapper(np_array_ops.triu) +np_array_ops.vander = weak_tensor_unary_op_wrapper(np_array_ops.vander) +np_array_ops.var = weak_tensor_unary_op_wrapper(np_array_ops.var) +np_array_ops.zeros_like = weak_tensor_unary_op_wrapper(np_array_ops.zeros_like) + +# ============================================================================== +# Update old op references. +# ============================================================================== +# Update Tensor dunder methods. +tensor.Tensor.__add__ = math_ops.add +tensor.Tensor.__sub__ = math_ops.sub +tensor.Tensor.__mul__ = math_ops.multiply +tensor.Tensor.__div__ = math_ops.div +tensor.Tensor.__truediv__ = math_ops.truediv +tensor.Tensor.__floordiv__ = math_ops.floordiv +tensor.Tensor.__mod__ = gen_math_ops.floor_mod +tensor.Tensor.__pow__ = math_ops.pow +tensor.Tensor.__matmul__ = math_ops.matmul + +# Set WeakTensor dunder methods. +weak_tensor.WeakTensor.__invert__ = math_ops.invert_ +weak_tensor.WeakTensor.__neg__ = gen_math_ops.neg +weak_tensor.WeakTensor.__abs__ = math_ops.abs +weak_tensor.WeakTensor.__add__ = math_ops.add +weak_tensor.WeakTensor.__sub__ = math_ops.sub +weak_tensor.WeakTensor.__mul__ = math_ops.multiply +weak_tensor.WeakTensor.__div__ = math_ops.div +weak_tensor.WeakTensor.__truediv__ = math_ops.truediv +weak_tensor.WeakTensor.__floordiv__ = math_ops.floordiv +weak_tensor.WeakTensor.__mod__ = gen_math_ops.floor_mod +weak_tensor.WeakTensor.__pow__ = math_ops.pow +weak_tensor.WeakTensor.__matmul__ = math_ops.matmul diff --git a/tensorflow/python/ops/weak_tensor_ops_list.py b/tensorflow/python/ops/weak_tensor_ops_list.py deleted file mode 100644 index 067feb621e0e61..00000000000000 --- a/tensorflow/python/ops/weak_tensor_ops_list.py +++ /dev/null @@ -1,251 +0,0 @@ -# Copyright 2023 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Lists of ops that support WeakTensor.""" - -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import clip_ops -from tensorflow.python.ops import gen_array_ops -from tensorflow.python.ops import gen_bitwise_ops -from tensorflow.python.ops import gen_math_ops -from tensorflow.python.ops import gen_nn_ops -from tensorflow.python.ops import image_ops_impl -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_impl -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops import special_math_ops -from tensorflow.python.ops.numpy_ops import np_array_ops -from tensorflow.python.ops.numpy_ops import np_math_ops - - -# Below are lists of unary ops that return a WeakTensor when given a WeakTensor -# input. These are some of the reasons why ops may not support WeakTensor. -# (1) The return dtype is specified. (e.g. tofloat(), cast(), is_finite()) -# (2) The list is prioritized to unary elementwise ops, TF-NumPy ops, math_ops, -# and array_ops. -# (3) There is no "weak" string type so any string ops are not supported. -# If you wish to add support to a specific unary op, add the unary op to a -# corresponding list. - -_ELEMENTWISE_UNARY_OPS = [ - math_ops.abs, - math_ops.softplus, - math_ops.sign, - math_ops.real, - math_ops.imag, - math_ops.angle, - math_ops.round, - math_ops.sigmoid, - math_ops.log_sigmoid, - math_ops.conj, - math_ops.reciprocal_no_nan, - math_ops.erfinv, - math_ops.ndtri, - math_ops.erfcinv, - math_ops.ceil, - math_ops.sqrt, - math_ops.exp, - math_ops.rsqrt, - math_ops.acos, - math_ops.floor, - gen_bitwise_ops.invert, - gen_math_ops.acosh, - gen_math_ops.asin, - gen_math_ops.asinh, - gen_math_ops.atan, - gen_math_ops.atanh, - gen_math_ops.cos, - gen_math_ops.cosh, - gen_math_ops.digamma, - gen_math_ops.erf, - gen_math_ops.erfc, - gen_math_ops.expm1, - gen_math_ops.lgamma, - gen_math_ops.log, - gen_math_ops.log1p, - gen_math_ops.neg, - gen_math_ops.reciprocal, - gen_math_ops.rint, - gen_math_ops.sin, - gen_math_ops.sinh, - gen_math_ops.square, - gen_math_ops.tan, - gen_math_ops.tanh, - array_ops.zeros_like, - array_ops.zeros_like_v2, - array_ops.ones_like, - array_ops.ones_like_v2, - gen_array_ops.check_numerics, - nn_ops.relu6, - nn_ops.leaky_relu, - nn_ops.gelu, - nn_ops.log_softmax, - gen_nn_ops.elu, - gen_nn_ops.relu, - gen_nn_ops.selu, - gen_nn_ops.softsign, - image_ops_impl.random_brightness, - image_ops_impl.stateless_random_brightness, - image_ops_impl.adjust_brightness, - image_ops_impl.adjust_gamma, - nn_impl.swish, - clip_ops.clip_by_value, - special_math_ops.dawsn, - special_math_ops.expint, - special_math_ops.fresnel_cos, - special_math_ops.fresnel_sin, - special_math_ops.spence, - special_math_ops.bessel_i0, - special_math_ops.bessel_i0e, - special_math_ops.bessel_i1, - special_math_ops.bessel_i1e, - special_math_ops.bessel_k0, - special_math_ops.bessel_k0e, - special_math_ops.bessel_k1, - special_math_ops.bessel_k1e, - special_math_ops.bessel_j0, - special_math_ops.bessel_j1, - special_math_ops.bessel_y0, - special_math_ops.bessel_y1, -] -_TF_UNARY_OPS = [ - math_ops.reduce_euclidean_norm, - math_ops.reduce_logsumexp, - math_ops.reduce_max, - math_ops.reduce_max_v1, - math_ops.reduce_mean, - math_ops.reduce_mean_v1, - math_ops.reduce_min, - math_ops.reduce_min_v1, - math_ops.reduce_prod, - math_ops.reduce_prod_v1, - math_ops.reduce_std, - math_ops.reduce_sum, - math_ops.reduce_sum_v1, - math_ops.reduce_variance, - math_ops.trace, - array_ops.depth_to_space, - array_ops.depth_to_space_v2, - array_ops.expand_dims, - array_ops.expand_dims_v2, - array_ops.extract_image_patches, - array_ops.extract_image_patches_v2, - array_ops.identity, - array_ops.matrix_diag, - array_ops.matrix_diag_part, - array_ops.matrix_transpose, - array_ops.shape, - array_ops.shape_v2, - array_ops.size, - array_ops.size_v2, - array_ops.space_to_depth, - array_ops.space_to_depth_v2, - array_ops.squeeze, - array_ops.squeeze_v2, - array_ops.stop_gradient, - array_ops.tensor_diag_part, - array_ops.transpose, - array_ops.transpose_v2, -] -_TF_NUMPY_UNARY_OPS = [ - np_math_ops.abs, - np_math_ops.absolute, - np_math_ops.angle, - np_math_ops.arccos, - np_math_ops.arcsin, - np_math_ops.arcsinh, - np_math_ops.arctan, - np_math_ops.arctanh, - np_math_ops.bitwise_not, - np_math_ops.cbrt, - np_math_ops.ceil, - np_math_ops.conj, - np_math_ops.conjugate, - np_math_ops.cos, - np_math_ops.cosh, - np_math_ops.deg2rad, - np_math_ops.exp, - np_math_ops.exp2, - np_math_ops.expm1, - np_math_ops.fabs, - np_math_ops.fix, - np_math_ops.floor, - np_math_ops.log, - np_math_ops.negative, - np_math_ops.rad2deg, - np_math_ops.reciprocal, - np_math_ops.sin, - np_math_ops.sinh, - np_math_ops.sqrt, - np_math_ops.tan, - np_math_ops.tanh, - np_math_ops.nanmean, - np_math_ops.log2, - np_math_ops.log10, - np_math_ops.log1p, - np_math_ops.positive, - np_math_ops.sinc, - np_math_ops.square, - np_math_ops.diff, - np_math_ops.sort, - np_math_ops.average, - np_math_ops.trace, - np_array_ops.amax, - np_array_ops.amin, - np_array_ops.around, - np_array_ops.arange, - np_array_ops.array, - np_array_ops.asanyarray, - np_array_ops.asarray, - np_array_ops.ascontiguousarray, - np_array_ops.copy, - np_array_ops.cumprod, - np_array_ops.cumsum, - np_array_ops.diag, - np_array_ops.diagflat, - np_array_ops.diagonal, - np_array_ops.empty_like, - np_array_ops.expand_dims, - np_array_ops.flatten, - np_array_ops.flip, - np_array_ops.fliplr, - np_array_ops.flipud, - np_array_ops.imag, - np_array_ops.max, - np_array_ops.mean, - np_array_ops.min, - np_array_ops.moveaxis, - np_array_ops.ones_like, - np_array_ops.prod, - np_array_ops.ravel, - np_array_ops.real, - np_array_ops.reshape, - np_array_ops.rot90, - np_array_ops.round, - np_array_ops.squeeze, - np_array_ops.std, - np_array_ops.sum, - np_array_ops.swapaxes, - np_array_ops.transpose, - np_array_ops.triu, - np_array_ops.vander, - np_array_ops.var, - np_array_ops.zeros_like, -] - -# Below are lists of binary ops that have support for WeakTensor input(s). -_ELEMENTWISE_BINARY_OPS = [] - -ALL_UNARY_OPS = _ELEMENTWISE_UNARY_OPS + _TF_UNARY_OPS + _TF_NUMPY_UNARY_OPS -ALL_BINARY_OPS = _ELEMENTWISE_BINARY_OPS diff --git a/tensorflow/python/ops/weak_tensor_ops_test.py b/tensorflow/python/ops/weak_tensor_ops_test.py index c7816bc8092f25..cdcdedd31603bd 100644 --- a/tensorflow/python/ops/weak_tensor_ops_test.py +++ b/tensorflow/python/ops/weak_tensor_ops_test.py @@ -13,10 +13,13 @@ # limitations under the License. # ============================================================================== """Tests for TF ops with WeakTensor input.""" + from absl.testing import parameterized +import numpy as np from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import extension_type from tensorflow.python.framework import ops from tensorflow.python.framework import tensor from tensorflow.python.framework import test_util @@ -27,15 +30,21 @@ from tensorflow.python.ops import gen_bitwise_ops from tensorflow.python.ops import image_ops_impl from tensorflow.python.ops import math_ops -from tensorflow.python.ops import weak_tensor_ops # pylint: disable=unused-import -from tensorflow.python.ops import weak_tensor_ops_list +from tensorflow.python.ops import weak_tensor_ops +from tensorflow.python.ops import weak_tensor_test_util from tensorflow.python.ops.numpy_ops import np_array_ops from tensorflow.python.ops.numpy_ops import np_config from tensorflow.python.ops.numpy_ops import np_math_ops +from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.platform import googletest +from tensorflow.python.util import dispatch + +_get_weak_tensor = weak_tensor_test_util.get_weak_tensor +_convert_to_input_type = weak_tensor_test_util.convert_to_input_type -_TF_UNARY_APIS = weak_tensor_ops_list.ALL_UNARY_OPS + +_TF_UNARY_APIS = weak_tensor_ops._TF_UNARY_APIS _TF_UNARY_APIS_SPECIFIC_DTYPE = [ math_ops.to_float, math_ops.to_double, @@ -52,9 +61,11 @@ image_ops_impl.adjust_brightness, clip_ops.clip_by_value, np_array_ops.expand_dims, + np_array_ops.full_like, np_array_ops.moveaxis, np_array_ops.reshape, np_array_ops.swapaxes, + array_ops.reshape, array_ops.depth_to_space, array_ops.depth_to_space_v2, array_ops.expand_dims, @@ -87,102 +98,245 @@ ] +class MyTensor(extension_type.ExtensionType): + value: tensor.Tensor + + class WeakTensorOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase): # Test unary ops with one input. - @parameterized.parameters( - set(_TF_UNARY_APIS) - set(_TF_UNARY_APIS_WITH_MULT_INPUT) + @parameterized.named_parameters( + (api.__module__ + "." + api.__name__, api) + for api in set(_TF_UNARY_APIS) - set(_TF_UNARY_APIS_WITH_MULT_INPUT) ) def test_unary_ops_return_weak_tensor(self, unary_api): - op_input = _get_test_input(unary_api) - res = unary_api(op_input) - # Check that WeakTensor is returned. + weak_tensor_input, python_input, tensor_input, numpy_input = ( + _get_test_input(unary_api) + ) + + # Check that WeakTensor input outputs a WeakTensor. + res = unary_api(weak_tensor_input) self.assertIsInstance(res, WeakTensor) + expected_result = unary_api(weak_tensor_input.tensor) # Check that the actual result is correct. - expected_result = unary_api(op_input.tensor) self.assertAllEqual(res, expected_result) + # Check that python nested scalar type (weak type) returns a WeakTensor. + res = unary_api(python_input) + self.assertIsInstance(res, WeakTensor) + + # Check that normal Tensor input outputs a Tensor. + res = unary_api(tensor_input) + self.assertIsInstance(res, tensor.Tensor) + + # Check that numpy type input outputs a Tensor. + res = unary_api(numpy_input) + self.assertIsInstance(res, tensor.Tensor) + # Test unary ops with multiple inputs. - def test_multi_arg_unary_ops_return_weak_tensor(self): - a = WeakTensor(constant_op.constant([1, 2, 3], dtypes.float32)) + @parameterized.parameters( + ("WeakTensor", dtypes.float32, WeakTensor), + ("Python", dtypes.float32, WeakTensor), + ("NumPy", np.float32, tensor.Tensor), + ("NumPy", None, tensor.Tensor), + ("Tensor", dtypes.float32, tensor.Tensor), + ) + def test_multi_arg_unary_ops_return_weak_tensor( + self, input_type, input_dtype, result_type + ): + test_input = _convert_to_input_type( + [1.0, 2.0, 3.0], input_type, input_dtype + ) + self.assertIsInstance( + gen_array_ops.check_numerics(test_input, message=""), result_type + ) self.assertIsInstance( - gen_array_ops.check_numerics(a, message=""), WeakTensor + image_ops_impl.random_brightness(test_input, 0.2), result_type ) - self.assertIsInstance(image_ops_impl.random_brightness(a, 0.2), WeakTensor) self.assertIsInstance( image_ops_impl.stateless_random_brightness( - image=a, max_delta=0.2, seed=(1, 2) + image=test_input, max_delta=0.2, seed=(1, 2) + ), + result_type, + ) + self.assertIsInstance( + image_ops_impl.adjust_brightness(test_input, delta=0.2), result_type + ) + self.assertIsInstance( + clip_ops.clip_by_value( + test_input, clip_value_min=1.1, clip_value_max=2.2 ), - WeakTensor, + result_type, ) self.assertIsInstance( - image_ops_impl.adjust_brightness(a, delta=0.2), WeakTensor + np_array_ops.expand_dims(test_input, axis=0), result_type ) self.assertIsInstance( - clip_ops.clip_by_value(a, clip_value_min=1.1, clip_value_max=2.2), - WeakTensor, + np_array_ops.moveaxis(test_input, source=0, destination=0), result_type ) - self.assertIsInstance(np_array_ops.expand_dims(a, axis=0), WeakTensor) self.assertIsInstance( - np_array_ops.moveaxis(a, source=0, destination=0), WeakTensor + np_array_ops.reshape(test_input, newshape=(3,)), result_type ) - self.assertIsInstance(np_array_ops.reshape(a, newshape=(3,)), WeakTensor) self.assertIsInstance( - np_array_ops.swapaxes(a, axis1=0, axis2=0), WeakTensor + np_array_ops.swapaxes(test_input, axis1=0, axis2=0), result_type + ) + self.assertIsInstance( + array_ops.reshape(test_input, shape=(3,)), result_type + ) + self.assertIsInstance( + array_ops.expand_dims(test_input, axis=0), result_type ) - self.assertIsInstance(array_ops.expand_dims(a, axis=0), WeakTensor) # Test unary ops with a specific return dtype. @parameterized.parameters(_TF_UNARY_APIS_SPECIFIC_DTYPE) def test_unary_ops_return_normal_tensor(self, unary_api_specific_dtype): - a = WeakTensor(constant_op.constant([1, 2, 3], dtypes.float32)) - res = unary_api_specific_dtype(a) + # All inputs should output a normal Tensor because return dtype is + # specified. + weak_tensor_input = _get_weak_tensor([1, 2, 3], dtypes.float32) + res = unary_api_specific_dtype(weak_tensor_input) + self.assertIsInstance(res, tensor.Tensor) + + python_input = [1.0, 2.0, 3.0] + res = unary_api_specific_dtype(python_input) + self.assertIsInstance(res, tensor.Tensor) + + tensor_input = constant_op.constant([1.0, 2.0, 3.0], dtypes.float32) + res = unary_api_specific_dtype(tensor_input) + self.assertIsInstance(res, tensor.Tensor) + + tensor_input = np.array([1.0, 2.0, 3.0]) + res = unary_api_specific_dtype(tensor_input) self.assertIsInstance(res, tensor.Tensor) # Test unary ops with optional dtype arg. - def test_elementwise_unary_ops_optional_dtype(self): - a = WeakTensor(constant_op.constant([1, 2, 3], dtypes.float32)) + @parameterized.parameters( + ("WeakTensor", dtypes.float32, WeakTensor), + ("Python", None, WeakTensor), + ("NumPy", np.float32, tensor.Tensor), + ("NumPy", None, tensor.Tensor), + ("Tensor", dtypes.float32, tensor.Tensor), + ) + def test_elementwise_unary_ops_optional_dtype( + self, input_type, input_dtype, result_type + ): + test_input = _convert_to_input_type( + [1.0, 2.0, 3.0], input_type, input_dtype + ) # No dtype specified in the argument. - self.assertIsInstance(array_ops.zeros_like(a), WeakTensor) - self.assertIsInstance(array_ops.ones_like(a), WeakTensor) - self.assertIsInstance(array_ops.ones_like(a, dtype=None), WeakTensor) + self.assertIsInstance(array_ops.zeros_like(test_input), result_type) + self.assertIsInstance(array_ops.ones_like(test_input), result_type) + self.assertIsInstance( + array_ops.ones_like(test_input, dtype=None), result_type + ) # dtype specified in the argument. self.assertIsInstance( - array_ops.zeros_like(a, dtype=dtypes.int32), tensor.Tensor + array_ops.zeros_like(test_input, dtype=dtypes.int32), tensor.Tensor ) self.assertIsInstance( - array_ops.ones_like(a, dtype=dtypes.int32), tensor.Tensor + array_ops.ones_like(test_input, dtype=dtypes.int32), tensor.Tensor ) - self.assertIsInstance(array_ops.zeros_like(a, dtypes.int32), tensor.Tensor) - self.assertIsInstance(array_ops.ones_like(a, dtypes.int32), tensor.Tensor) self.assertIsInstance( - np_array_ops.arange( - WeakTensor(constant_op.constant(5)), 0, 1, dtypes.float32 - ), - tensor.Tensor, + array_ops.zeros_like(test_input, dtypes.int32), tensor.Tensor + ) + self.assertIsInstance( + array_ops.ones_like(test_input, dtypes.int32), tensor.Tensor + ) + + @parameterized.parameters( + ("WeakTensor", dtypes.float32, None, WeakTensor), + ("WeakTensor", dtypes.float32, dtypes.int32, tensor.Tensor), + ("Python", None, None, WeakTensor), + ("Python", None, dtypes.int32, tensor.Tensor), + ("NumPy", None, None, tensor.Tensor), + ("NumPy", None, np.int32, tensor.Tensor), + ("Tensor", dtypes.float32, None, tensor.Tensor), + ("Tensor", dtypes.float32, dtypes.int32, tensor.Tensor), + ) + # Test unary ops with multiple args that includes an optional dtype arg. + def test_elementwise_unary_ops_optional_dtype_with_multi_args( + self, input_type, input_dtype, dtype_arg, result_type + ): + test_input = _convert_to_input_type(5, input_type, input_dtype) + self.assertIsInstance( + np_array_ops.arange(test_input, 10, dtype=dtype_arg), result_type + ) + self.assertIsInstance( + np_array_ops.full_like(test_input, 1, dtype=dtype_arg), result_type ) # Test unary ops that require dtype arg. def test_unary_ops_explicit_dtype_return(self): - a = WeakTensor(constant_op.constant([1, 2, 3], dtypes.float32)) - self.assertIsInstance(math_ops.cast(a, dtypes.int32), tensor.Tensor) + wt_input = _get_weak_tensor([1, 2, 3], dtypes.float32) + self.assertIsInstance(math_ops.cast(wt_input, dtypes.int32), tensor.Tensor) + self.assertIsInstance( + math_ops.saturate_cast(wt_input, dtypes.int32), tensor.Tensor + ) + + python_input = [1.0, 2.0, 3.0] self.assertIsInstance( - math_ops.saturate_cast(a, dtypes.int32), tensor.Tensor) + math_ops.cast(python_input, dtypes.int32), tensor.Tensor + ) + self.assertIsInstance( + math_ops.saturate_cast(python_input, dtypes.int32), tensor.Tensor + ) + + def test_unsupported_input_type_in_weak_tensor_ops(self): + rt = ragged_tensor.RaggedTensor.from_row_splits( + values=[3, 1, 4, 1, 5, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8] + ) + # Any unsupported type should be ignored in WeakTensor wrapper. + self.assertIsInstance(math_ops.abs(rt), ragged_tensor.RaggedTensor) + + def test_update_weak_tensor_patched_ops_in_dispatch_dict(self): + dispatch_dict = dispatch._TYPE_BASED_DISPATCH_SIGNATURES + # Test that we can use the updated op reference as a key to the dispatch + # dictionary. + self.assertTrue(hasattr(math_ops.abs, "_tf_decorator")) + self.assertNotEmpty(dispatch_dict[math_ops.abs]) + + def test_weak_tensor_ops_dispatch(self): + @dispatch.dispatch_for_api(math_ops.abs) + def my_abs(x: MyTensor): + return MyTensor(math_ops.abs(x.value)) + + self.assertIsInstance(my_abs(MyTensor(constant_op.constant(1.0))), MyTensor) + # Test unregistering dispatch with patched op reference. + dispatch.unregister_dispatch_for(my_abs) + with self.assertRaises(ValueError): + math_ops.abs(MyTensor(constant_op.constant(1.0))) + +# TODO(b/289333658): Add tf.constant(x) with no dtype arg as a "weak" input +# after adding WeakTensor construction logic to tf.constant. def _get_test_input(op): if op in _TF_UNARY_APIS_WITH_INT_INPUT: - return WeakTensor(constant_op.constant(5, dtypes.int32)) + return ( + _get_weak_tensor(5, dtypes.int32), + 5, + constant_op.constant(5, dtypes.int32), + np.array(5), + ) elif op in _TF_UNARY_APIS_WITH_2D_INPUT: - return WeakTensor(constant_op.constant([[1, 2], [3, 4]], dtypes.int32)) + return ( + _get_weak_tensor([[1, 2], [3, 4]], dtypes.int32), + [[1, 2], [3, 4]], + constant_op.constant([[1, 2], [3, 4]], dtypes.int32), + np.array([[1, 2], [3, 4]]), + ) else: - return WeakTensor(constant_op.constant([1, 2, 3], dtypes.float32)) + return ( + _get_weak_tensor([1.0, 2.0, 3.0], dtype=dtypes.float32), + [1.0, 2.0, 3.0], + constant_op.constant([1.0, 2.0, 3.0], dtype=dtypes.float32), + np.array([1.0, 2.0, 3.0]), + ) if __name__ == "__main__": ops.enable_eager_execution() # Enabling numpy behavior adds some NumPy methods to the Tensor class, which # TF-NumPy ops depend on. - np_config.enable_numpy_behavior() + np_config.enable_numpy_behavior(dtype_conversion_mode="all") googletest.main() diff --git a/tensorflow/python/ops/weak_tensor_test_util.py b/tensorflow/python/ops/weak_tensor_test_util.py index eec7b936a24015..aa117def50c086 100644 --- a/tensorflow/python/ops/weak_tensor_test_util.py +++ b/tensorflow/python/ops/weak_tensor_test_util.py @@ -14,7 +14,28 @@ # ============================================================================== """Utils for WeakTensor related tests.""" +import numpy as np + +from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops +from tensorflow.python.framework.weak_tensor import WeakTensor + + +def convert_to_input_type(base_input, input_type, dtype=None): + if input_type == "WeakTensor": + return WeakTensor(constant_op.constant(base_input, dtype=dtype)) + elif input_type == "Tensor": + return constant_op.constant(base_input, dtype=dtype) + elif input_type == "NumPy": + return np.array(base_input, dtype=dtype) + elif input_type == "Python": + return base_input + else: + raise ValueError(f"The provided input_type {input_type} is not supported.") + + +def get_weak_tensor(*args, **kwargs): + return WeakTensor(constant_op.constant(*args, **kwargs)) class DtypeConversionTestEnv: diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD index a8ef26478df3a3..08033624332047 100644 --- a/tensorflow/tools/pip_package/BUILD +++ b/tensorflow/tools/pip_package/BUILD @@ -160,7 +160,6 @@ COMMON_PIP_DEPS = [ "//tensorflow/python/kernel_tests/signal:test_util", "//tensorflow/python/kernel_tests/sparse_ops:sparse_xent_op_test_base", "//tensorflow/python/lib:__init__", - "//tensorflow/python/ops:weak_tensor_ops", "//tensorflow/python/ops/parallel_for:test_util", "//tensorflow/python/ops/structured:structured_tensor_dynamic", "//tensorflow/python/platform:resource_loader", From 63bfa44151a113a605ab17a9a6903e970a233ff4 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 12 Jul 2023 12:16:36 -0700 Subject: [PATCH 209/376] Integrate LLVM at llvm/llvm-project@b10899d86995 Updates LLVM usage to match [b10899d86995](https://github.com/llvm/llvm-project/commit/b10899d86995) PiperOrigin-RevId: 547566590 --- third_party/llvm/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index 25289428d729d0..843f993f9edec6 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "5671f023042b558d38c3b777ee4ae0ad037b1867" - LLVM_SHA256 = "353607dd4ca5b20e6a2ec6650353dd5de006829e5be502716383624152bb1f0f" + LLVM_COMMIT = "b10899d869954e1426684cbc20a43d7303075d49" + LLVM_SHA256 = "62df1d4c4a10d9fa1c805b8eeddd5448e819ee98cf2ac8306b63b68d67656568" tf_http_archive( name = name, From 0108f4450922ea810b4569c2df02049e39a67765 Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Wed, 12 Jul 2023 12:32:17 -0700 Subject: [PATCH 210/376] Delete get_compatible_with_cloud. Update all users to get_compatible_with_portable PiperOrigin-RevId: 547570705 --- tensorflow/tensorflow.bzl | 3 --- tensorflow/tensorflow.default.bzl | 2 -- tensorflow/tsl/tsl.default.bzl | 3 --- 3 files changed, 8 deletions(-) diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index 7f32fa50914c7c..af8e82b0adce41 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -3433,9 +3433,6 @@ def tf_grpc_cc_dependencies(): def get_compatible_with_portable(): return [] -def get_compatible_with_cloud(): - return [] - def filegroup(**kwargs): native.filegroup(**kwargs) diff --git a/tensorflow/tensorflow.default.bzl b/tensorflow/tensorflow.default.bzl index 017268250c3c6c..9c6515f9798e5e 100644 --- a/tensorflow/tensorflow.default.bzl +++ b/tensorflow/tensorflow.default.bzl @@ -8,7 +8,6 @@ load( _cuda_py_test = "cuda_py_test", _filegroup = "filegroup", _genrule = "genrule", - _get_compatible_with_cloud = "get_compatible_with_cloud", _get_compatible_with_portable = "get_compatible_with_portable", _if_indexing_source_code = "if_indexing_source_code", _if_not_mobile_or_arm_or_lgpl_restricted = "if_not_mobile_or_arm_or_lgpl_restricted", @@ -81,7 +80,6 @@ tf_external_workspace_visible = _tf_external_workspace_visible tf_grpc_dependencies = _tf_grpc_dependencies tf_grpc_cc_dependencies = _tf_grpc_cc_dependencies get_compatible_with_portable = _get_compatible_with_portable -get_compatible_with_cloud = _get_compatible_with_cloud cc_header_only_library = _cc_header_only_library tf_gen_op_libs = _tf_gen_op_libs tf_gen_op_wrapper_cc = _tf_gen_op_wrapper_cc diff --git a/tensorflow/tsl/tsl.default.bzl b/tensorflow/tsl/tsl.default.bzl index 37b772a9f3b12c..1d339f95d1a6c8 100644 --- a/tensorflow/tsl/tsl.default.bzl +++ b/tensorflow/tsl/tsl.default.bzl @@ -31,9 +31,6 @@ internal_hlo_deps = _internal_hlo_deps tsl_grpc_cc_dependencies = _tsl_grpc_cc_dependencies tsl_pybind_extension = _tsl_pybind_extension -def get_compatible_with_cloud(): - return [] - def tsl_gpu_cc_test( name, srcs = [], From b9e1aff0c19f4cc6ef5133e67333cfcc2e201534 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 12 Jul 2023 12:54:35 -0700 Subject: [PATCH 211/376] Update TFRT dependency to use revision http://github.com/tensorflow/runtime/commit/b325dedfa3a47df75e06da3640424c1bdb28dd3a. PiperOrigin-RevId: 547576359 --- third_party/tf_runtime/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/tf_runtime/workspace.bzl b/third_party/tf_runtime/workspace.bzl index b9bb8c2fba04f5..8d575144823d6c 100644 --- a/third_party/tf_runtime/workspace.bzl +++ b/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "cf061a8afb57bad642bcc01442c414bf76fc3074" - TFRT_SHA256 = "d4d05de303b5126d0e648a40e1c74091013d022179a5a0fa3fe2d13f7a73f2de" + TFRT_COMMIT = "b325dedfa3a47df75e06da3640424c1bdb28dd3a" + TFRT_SHA256 = "bc4341e8c6d0deed35b662903a82008b21b88127aec053b1a250b92219f4f0c9" tf_http_archive( name = "tf_runtime", From 3ce365a2b9356ec16f4c81de83c1d6194429b6a5 Mon Sep 17 00:00:00 2001 From: Weiyi Wang Date: Wed, 12 Jul 2023 12:55:47 -0700 Subject: [PATCH 212/376] Add experimental APIs to signature runner to set custom allocation for IO tensors. PiperOrigin-RevId: 547576673 --- tensorflow/lite/signature_runner.cc | 22 +++++++++++++ tensorflow/lite/signature_runner.h | 50 +++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+) diff --git a/tensorflow/lite/signature_runner.cc b/tensorflow/lite/signature_runner.cc index 3223aefa3cce36..521f5abaaf3eb4 100644 --- a/tensorflow/lite/signature_runner.cc +++ b/tensorflow/lite/signature_runner.cc @@ -85,4 +85,26 @@ TfLiteStatus SignatureRunner::Invoke() { return kTfLiteOk; } +TfLiteStatus SignatureRunner::SetCustomAllocationForInputTensor( + const char* input_name, const TfLiteCustomAllocation& allocation, + int64_t flags) { + const auto& it = signature_def_->inputs.find(input_name); + if (it == signature_def_->inputs.end()) { + subgraph_->ReportError("Input name %s was not found", input_name); + return kTfLiteError; + } + return subgraph_->SetCustomAllocationForTensor(it->second, allocation, flags); +} + +TfLiteStatus SignatureRunner::SetCustomAllocationForOutputTensor( + const char* output_name, const TfLiteCustomAllocation& allocation, + int64_t flags) { + const auto& it = signature_def_->outputs.find(output_name); + if (it == signature_def_->outputs.end()) { + subgraph_->ReportError("Output name %s was not found", output_name); + return kTfLiteError; + } + return subgraph_->SetCustomAllocationForTensor(it->second, allocation, flags); +} + } // namespace tflite diff --git a/tensorflow/lite/signature_runner.h b/tensorflow/lite/signature_runner.h index ae904e99edd1ca..165c98ef82bca7 100644 --- a/tensorflow/lite/signature_runner.h +++ b/tensorflow/lite/signature_runner.h @@ -145,6 +145,56 @@ class SignatureRunner { /// WARNING: This is an experimental API and subject to change. TfLiteStatus Cancel() { return subgraph_->Cancel(); } + /// \brief Assigns (or reassigns) a custom memory allocation for the given + /// tensor name. `flags` is a bitmask, see TfLiteCustomAllocationFlags. + /// The runtime does NOT take ownership of the underlying memory. + /// + /// NOTE: User needs to call AllocateTensors() after this. + /// Invalid/insufficient buffers will cause an error during AllocateTensors or + /// Invoke (in case of dynamic shapes in the graph). + /// + /// Parameters should satisfy the following conditions: + /// 1. tensor->allocation_type == kTfLiteArenaRw or kTfLiteArenaRwPersistent + /// In general, this is true for I/O tensors & variable tensors. + /// 2. allocation->data has the appropriate permissions for runtime access + /// (Read-only for inputs, Read-Write for others), and outlives + /// Interpreter. + /// 3. allocation->bytes >= tensor->bytes. + /// This condition is checked again if any tensors are resized. + /// 4. allocation->data should be aligned to kDefaultTensorAlignment + /// defined in lite/util.h. (Currently 64 bytes) + /// This check is skipped if kTfLiteCustomAllocationFlagsSkipAlignCheck is + /// set through `flags`. + /// \warning This is an experimental API and subject to change. \n + TfLiteStatus SetCustomAllocationForInputTensor( + const char* input_name, const TfLiteCustomAllocation& allocation, + int64_t flags = kTfLiteCustomAllocationFlagsNone); + + /// \brief Assigns (or reassigns) a custom memory allocation for the given + /// tensor name. `flags` is a bitmask, see TfLiteCustomAllocationFlags. + /// The runtime does NOT take ownership of the underlying memory. + /// + /// NOTE: User needs to call AllocateTensors() after this. + /// Invalid/insufficient buffers will cause an error during AllocateTensors or + /// Invoke (in case of dynamic shapes in the graph). + /// + /// Parameters should satisfy the following conditions: + /// 1. tensor->allocation_type == kTfLiteArenaRw or kTfLiteArenaRwPersistent + /// In general, this is true for I/O tensors & variable tensors. + /// 2. allocation->data has the appropriate permissions for runtime access + /// (Read-only for inputs, Read-Write for others), and outlives + /// Interpreter. + /// 3. allocation->bytes >= tensor->bytes. + /// This condition is checked again if any tensors are resized. + /// 4. allocation->data should be aligned to kDefaultTensorAlignment + /// defined in lite/util.h. (Currently 64 bytes) + /// This check is skipped if kTfLiteCustomAllocationFlagsSkipAlignCheck is + /// set through `flags`. + /// \warning This is an experimental API and subject to change. \n + TfLiteStatus SetCustomAllocationForOutputTensor( + const char* output_name, const TfLiteCustomAllocation& allocation, + int64_t flags = kTfLiteCustomAllocationFlagsNone); + private: // The life cycle of SignatureRunner depends on the life cycle of Subgraph, // which is owned by an Interpreter. Therefore, the Interpreter will takes the From 59db07174f496e8e2c58db0f07d4ed36a4d85e57 Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Wed, 12 Jul 2023 12:58:07 -0700 Subject: [PATCH 213/376] [XLA:GPU] Enable relu6 fusion on Turing. This was disabled because nvidia had not done much testing on Turing. But the benchmark results look good, so this seems good to try. PiperOrigin-RevId: 547577240 --- .../xla/service/gpu/cudnn_fused_conv_rewriter.cc | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc index 6fa52445f637fa..3bb00f58de0b07 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc @@ -67,12 +67,13 @@ bool IsNonDepthwiseConvCustomCall(const HloInstruction* instr) { // elu, relu6, and leaky-relu activations are supported in cudnn via the // "runtime fusion" engine, which JIT compiles C++ code. This can be slow to -// compile, so we guard it with a debug option. Also nvidia currently -// recommends that we enable this only on Ampere+. +// compile, so we guard it with a debug option. +// +// nvidia currently recommends that we enable this only on Ampere+, but we've +// tested on Turing (sm75) and it seems to work fine. bool ShouldUseCudnnRuntimeFusion(const DebugOptions& debug_opts, se::CudaComputeCapability cc) { - return debug_opts.xla_gpu_use_runtime_fusion() && - cc.IsAtLeast(se::CudaComputeCapability::AMPERE); + return debug_opts.xla_gpu_use_runtime_fusion() && cc.IsAtLeast(7, 5); } // Can instr be converted to type `dst_ty` without losing any precision? For From 45a552b637b9ca634bf4f881d93583cbfe58d1a6 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 12 Jul 2023 13:03:18 -0700 Subject: [PATCH 214/376] [XLA:CPU] Shard runtime matmul kernels. These can be very slow to build (I've seen 5+ minutes in an OSS build). Shard them into separate files to allow parallelism. PiperOrigin-RevId: 547578695 --- tensorflow/compiler/xla/service/cpu/BUILD | 36 ++++- .../xla/service/cpu/runtime_matmul_c128.cc | 30 ++++ .../xla/service/cpu/runtime_matmul_c64.cc | 30 ++++ ...time_matmul.cc => runtime_matmul_common.h} | 66 +------- .../xla/service/cpu/runtime_matmul_f16.cc | 30 ++++ .../xla/service/cpu/runtime_matmul_f32.cc | 36 +++++ .../xla/service/cpu/runtime_matmul_f64.cc | 29 ++++ .../xla/service/cpu/runtime_matmul_s32.cc | 29 ++++ .../cpu/runtime_single_threaded_matmul.cc | 141 ------------------ .../cpu/runtime_single_threaded_matmul.h | 1 + .../runtime_single_threaded_matmul_c128.cc | 31 ++++ .../cpu/runtime_single_threaded_matmul_c64.cc | 31 ++++ .../runtime_single_threaded_matmul_common.h | 88 +++++++++++ .../cpu/runtime_single_threaded_matmul_f16.cc | 31 ++++ .../cpu/runtime_single_threaded_matmul_f32.cc | 31 ++++ .../cpu/runtime_single_threaded_matmul_f64.cc | 32 ++++ .../cpu/runtime_single_threaded_matmul_s32.cc | 32 ++++ 17 files changed, 501 insertions(+), 203 deletions(-) create mode 100644 tensorflow/compiler/xla/service/cpu/runtime_matmul_c128.cc create mode 100644 tensorflow/compiler/xla/service/cpu/runtime_matmul_c64.cc rename tensorflow/compiler/xla/service/cpu/{runtime_matmul.cc => runtime_matmul_common.h} (66%) create mode 100644 tensorflow/compiler/xla/service/cpu/runtime_matmul_f16.cc create mode 100644 tensorflow/compiler/xla/service/cpu/runtime_matmul_f32.cc create mode 100644 tensorflow/compiler/xla/service/cpu/runtime_matmul_f64.cc create mode 100644 tensorflow/compiler/xla/service/cpu/runtime_matmul_s32.cc delete mode 100644 tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc create mode 100644 tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul_c128.cc create mode 100644 tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul_c64.cc create mode 100644 tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul_common.h create mode 100644 tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul_f16.cc create mode 100644 tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul_f32.cc create mode 100644 tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul_f64.cc create mode 100644 tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul_s32.cc diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index c0d0f1908a03a8..6c434c35acf17e 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -74,13 +74,25 @@ filegroup( "runtime_single_threaded_conv2d.cc", "runtime_single_threaded_conv3d.cc", "runtime_single_threaded_fft.cc", - "runtime_single_threaded_matmul.cc", + "runtime_single_threaded_matmul_c128.cc", + "runtime_single_threaded_matmul_c64.cc", + "runtime_single_threaded_matmul_common.h", + "runtime_single_threaded_matmul_f16.cc", + "runtime_single_threaded_matmul_f32.cc", + "runtime_single_threaded_matmul_f64.cc", + "runtime_single_threaded_matmul_s32.cc", "runtime_topk.cc", # Multi-threaded support. "runtime_conv2d.cc", "runtime_conv3d.cc", "runtime_fft.cc", - "runtime_matmul.cc", + "runtime_matmul_c128.cc", + "runtime_matmul_c64.cc", + "runtime_matmul_common.h", + "runtime_matmul_f16.cc", + "runtime_matmul_f32.cc", + "runtime_matmul_f64.cc", + "runtime_matmul_s32.cc", "runtime_fork_join.cc", ], visibility = [":friends"], @@ -946,7 +958,15 @@ cc_library( cc_library( name = "runtime_matmul", - srcs = ["runtime_matmul.cc"], + srcs = [ + "runtime_matmul_c128.cc", + "runtime_matmul_c64.cc", + "runtime_matmul_common.h", + "runtime_matmul_f16.cc", + "runtime_matmul_f32.cc", + "runtime_matmul_f64.cc", + "runtime_matmul_s32.cc", + ], hdrs = ["runtime_matmul.h"], copts = runtime_copts(), visibility = ["//visibility:public"], @@ -1068,7 +1088,15 @@ cc_library( cc_library( name = "runtime_single_threaded_matmul_impl", - srcs = ["runtime_single_threaded_matmul.cc"], + srcs = [ + "runtime_single_threaded_matmul_c128.cc", + "runtime_single_threaded_matmul_c64.cc", + "runtime_single_threaded_matmul_common.h", + "runtime_single_threaded_matmul_f16.cc", + "runtime_single_threaded_matmul_f32.cc", + "runtime_single_threaded_matmul_f64.cc", + "runtime_single_threaded_matmul_s32.cc", + ], hdrs = ["runtime_single_threaded_matmul.h"], compatible_with = get_compatible_with_portable(), copts = runtime_copts(), diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul_c128.cc b/tensorflow/compiler/xla/service/cpu/runtime_matmul_c128.cc new file mode 100644 index 00000000000000..c53692de7f3a2b --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul_c128.cc @@ -0,0 +1,30 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h" + +#include +#include + +#include "absl/base/attributes.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_matmul_common.h" + +ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulC128( + const void* run_options_ptr, std::complex* out, + std::complex* lhs, std::complex* rhs, int64_t m, int64_t n, + int64_t k, int32_t transpose_lhs, int32_t transpose_rhs) { + xla::MatMulDispatch>(run_options_ptr, out, lhs, rhs, m, + n, k, transpose_lhs, transpose_rhs); +} diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul_c64.cc b/tensorflow/compiler/xla/service/cpu/runtime_matmul_c64.cc new file mode 100644 index 00000000000000..9c3482d6ef5049 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul_c64.cc @@ -0,0 +1,30 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h" + +#include +#include + +#include "absl/base/attributes.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_matmul_common.h" + +ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulC64( + const void* run_options_ptr, std::complex* out, + std::complex* lhs, std::complex* rhs, int64_t m, int64_t n, + int64_t k, int32_t transpose_lhs, int32_t transpose_rhs) { + xla::MatMulDispatch>(run_options_ptr, out, lhs, rhs, m, n, + k, transpose_lhs, transpose_rhs); +} diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc b/tensorflow/compiler/xla/service/cpu/runtime_matmul_common.h similarity index 66% rename from tensorflow/compiler/xla/service/cpu/runtime_matmul.cc rename to tensorflow/compiler/xla/service/cpu/runtime_matmul_common.h index 21ca5ed6402578..6acefadba6f5d4 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul_common.h @@ -13,7 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h" +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATMUL_COMMON_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATMUL_COMMON_H_ + +#include #define EIGEN_USE_THREADS @@ -26,9 +29,9 @@ limitations under the License. #include "tensorflow/tsl/framework/contraction/eigen_contraction_kernel.h" #endif -namespace { +namespace xla { -bool Is16BytesAligned(void* ptr) { +static inline bool Is16BytesAligned(void* ptr) { return reinterpret_cast(ptr) % 16 == 0; } @@ -146,59 +149,6 @@ void BatchMatMulDispatch(const void* run_options_ptr, T* out, T* lhs, T* rhs, batch_size, transpose_lhs, transpose_rhs); } -} // namespace - -ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulF16( - const void* run_options_ptr, Eigen::half* out, Eigen::half* lhs, - Eigen::half* rhs, int64_t m, int64_t n, int64_t k, int32_t transpose_lhs, - int32_t transpose_rhs) { - MatMulDispatch(run_options_ptr, out, lhs, rhs, m, n, k, - transpose_lhs, transpose_rhs); -} - -ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulF32( - const void* run_options_ptr, float* out, float* lhs, float* rhs, int64_t m, - int64_t n, int64_t k, int32_t transpose_lhs, int32_t transpose_rhs) { - MatMulDispatch(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, - transpose_rhs); -} - -ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulF64( - const void* run_options_ptr, double* out, double* lhs, double* rhs, - int64_t m, int64_t n, int64_t k, int32_t transpose_lhs, - int32_t transpose_rhs) { - MatMulDispatch(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, - transpose_rhs); -} +} // namespace xla -ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulC64( - const void* run_options_ptr, std::complex* out, - std::complex* lhs, std::complex* rhs, int64_t m, int64_t n, - int64_t k, int32_t transpose_lhs, int32_t transpose_rhs) { - MatMulDispatch>(run_options_ptr, out, lhs, rhs, m, n, k, - transpose_lhs, transpose_rhs); -} - -ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulC128( - const void* run_options_ptr, std::complex* out, - std::complex* lhs, std::complex* rhs, int64_t m, int64_t n, - int64_t k, int32_t transpose_lhs, int32_t transpose_rhs) { - MatMulDispatch>(run_options_ptr, out, lhs, rhs, m, n, k, - transpose_lhs, transpose_rhs); -} - -ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulS32( - const void* run_options_ptr, int32_t* out, int32_t* lhs, int32_t* rhs, - int64_t m, int64_t n, int64_t k, int32_t transpose_lhs, - int32_t transpose_rhs) { - MatMulDispatch(run_options_ptr, out, lhs, rhs, m, n, k, - transpose_lhs, transpose_rhs); -} - -ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenBatchMatMulF32( - const void* run_options_ptr, float* out, float* lhs, float* rhs, int64_t m, - int64_t n, int64_t k, int64_t batch_size, int32_t transpose_lhs, - int32_t transpose_rhs) { - BatchMatMulDispatch(run_options_ptr, out, lhs, rhs, m, n, k, - batch_size, transpose_lhs, transpose_rhs); -} +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATMUL_COMMON_H_ diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul_f16.cc b/tensorflow/compiler/xla/service/cpu/runtime_matmul_f16.cc new file mode 100644 index 00000000000000..d18516805bb45d --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul_f16.cc @@ -0,0 +1,30 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h" + +#include + +#include "absl/base/attributes.h" +#include "third_party/eigen3/Eigen/Core" +#include "tensorflow/compiler/xla/service/cpu/runtime_matmul_common.h" + +ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulF16( + const void* run_options_ptr, Eigen::half* out, Eigen::half* lhs, + Eigen::half* rhs, int64_t m, int64_t n, int64_t k, int32_t transpose_lhs, + int32_t transpose_rhs) { + xla::MatMulDispatch(run_options_ptr, out, lhs, rhs, m, n, k, + transpose_lhs, transpose_rhs); +} diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul_f32.cc b/tensorflow/compiler/xla/service/cpu/runtime_matmul_f32.cc new file mode 100644 index 00000000000000..6d84a3ac5c8193 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul_f32.cc @@ -0,0 +1,36 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h" + +#include + +#include "absl/base/attributes.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_matmul_common.h" + +ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulF32( + const void* run_options_ptr, float* out, float* lhs, float* rhs, int64_t m, + int64_t n, int64_t k, int32_t transpose_lhs, int32_t transpose_rhs) { + xla::MatMulDispatch(run_options_ptr, out, lhs, rhs, m, n, k, + transpose_lhs, transpose_rhs); +} + +ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenBatchMatMulF32( + const void* run_options_ptr, float* out, float* lhs, float* rhs, int64_t m, + int64_t n, int64_t k, int64_t batch_size, int32_t transpose_lhs, + int32_t transpose_rhs) { + xla::BatchMatMulDispatch(run_options_ptr, out, lhs, rhs, m, n, k, + batch_size, transpose_lhs, transpose_rhs); +} diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul_f64.cc b/tensorflow/compiler/xla/service/cpu/runtime_matmul_f64.cc new file mode 100644 index 00000000000000..1424d17fa5f6e6 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul_f64.cc @@ -0,0 +1,29 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h" + +#include + +#include "absl/base/attributes.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_matmul_common.h" + +ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulF64( + const void* run_options_ptr, double* out, double* lhs, double* rhs, + int64_t m, int64_t n, int64_t k, int32_t transpose_lhs, + int32_t transpose_rhs) { + xla::MatMulDispatch(run_options_ptr, out, lhs, rhs, m, n, k, + transpose_lhs, transpose_rhs); +} diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul_s32.cc b/tensorflow/compiler/xla/service/cpu/runtime_matmul_s32.cc new file mode 100644 index 00000000000000..6c93cb53f5eb31 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul_s32.cc @@ -0,0 +1,29 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h" + +#include + +#include "absl/base/attributes.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_matmul_common.h" + +ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulS32( + const void* run_options_ptr, int32_t* out, int32_t* lhs, int32_t* rhs, + int64_t m, int64_t n, int64_t k, int32_t transpose_lhs, + int32_t transpose_rhs) { + xla::MatMulDispatch(run_options_ptr, out, lhs, rhs, m, n, k, + transpose_lhs, transpose_rhs); +} diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc deleted file mode 100644 index d5f0b6b93a6258..00000000000000 --- a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc +++ /dev/null @@ -1,141 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" - -#include "absl/base/attributes.h" -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" - -#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL) -#include "tensorflow/tsl/framework/contraction/eigen_contraction_kernel.h" -#endif - -namespace { - -bool Is16BytesAligned(void* ptr) { - return reinterpret_cast(ptr) % 16 == 0; -} - -template -void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64_t m, - int64_t n, int64_t k, int32_t transpose_lhs, - int32_t transpose_rhs) { - int64_t lhs_rows = m; - int64_t lhs_cols = k; - if (transpose_lhs) { - std::swap(lhs_rows, lhs_cols); - } - - int64_t rhs_rows = k; - int64_t rhs_cols = n; - if (transpose_rhs) { - std::swap(rhs_rows, rhs_cols); - } - - const Eigen::TensorMap, Alignment> A(lhs, lhs_rows, - lhs_cols); - const Eigen::TensorMap, Alignment> B(rhs, rhs_rows, - rhs_cols); - Eigen::TensorMap, Alignment> C(out, m, n); - - typedef typename Eigen::Tensor::DimensionPair DimPair; - int lhs_contract_dim = transpose_lhs ? 0 : 1; - int rhs_contract_dim = transpose_rhs ? 1 : 0; - const Eigen::array dims( - {DimPair(lhs_contract_dim, rhs_contract_dim)}); - - // Matrix multiply is a special case of the "contract" operation where - // the contraction is performed along dimension 1 of the lhs and dimension - // 0 of the rhs. - C = A.contract(B, dims); -} - -template -void SingleThreadedMatMulDispatch(const void* run_options_ptr, T* out, T* lhs, - T* rhs, int64_t m, int64_t n, int64_t k, - int32_t transpose_lhs, - int32_t transpose_rhs) { - bool all_buffers_16b_aligned = - Is16BytesAligned(out) && Is16BytesAligned(lhs) && Is16BytesAligned(rhs); - - if (!all_buffers_16b_aligned) { - MatMul(run_options_ptr, out, lhs, rhs, m, n, k, - transpose_lhs, transpose_rhs); - } - - MatMul(run_options_ptr, out, lhs, rhs, m, n, k, - transpose_lhs, transpose_rhs); -} - -} // namespace - -ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void -__xla_cpu_runtime_EigenSingleThreadedMatMulF16( - const void* run_options_ptr, Eigen::half* out, Eigen::half* lhs, - Eigen::half* rhs, int64_t m, int64_t n, int64_t k, int32_t transpose_lhs, - int32_t transpose_rhs) { - SingleThreadedMatMulDispatch(run_options_ptr, out, lhs, rhs, m, - n, k, transpose_lhs, transpose_rhs); -} - -ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void -__xla_cpu_runtime_EigenSingleThreadedMatMulF32(const void* run_options_ptr, - float* out, float* lhs, - float* rhs, int64_t m, int64_t n, - int64_t k, int32_t transpose_lhs, - int32_t transpose_rhs) { - SingleThreadedMatMulDispatch(run_options_ptr, out, lhs, rhs, m, n, k, - transpose_lhs, transpose_rhs); -} - -ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void -__xla_cpu_runtime_EigenSingleThreadedMatMulF64(const void* run_options_ptr, - double* out, double* lhs, - double* rhs, int64_t m, - int64_t n, int64_t k, - int32_t transpose_lhs, - int32_t transpose_rhs) { - SingleThreadedMatMulDispatch(run_options_ptr, out, lhs, rhs, m, n, k, - transpose_lhs, transpose_rhs); -} - -ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void -__xla_cpu_runtime_EigenSingleThreadedMatMulC64( - const void* run_options_ptr, std::complex* out, - std::complex* lhs, std::complex* rhs, int64_t m, int64_t n, - int64_t k, int32_t transpose_lhs, int32_t transpose_rhs) { - SingleThreadedMatMulDispatch>( - run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); -} - -ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void -__xla_cpu_runtime_EigenSingleThreadedMatMulC128( - const void* run_options_ptr, std::complex* out, - std::complex* lhs, std::complex* rhs, int64_t m, int64_t n, - int64_t k, int32_t transpose_lhs, int32_t transpose_rhs) { - SingleThreadedMatMulDispatch>( - run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); -} - -ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void -__xla_cpu_runtime_EigenSingleThreadedMatMulS32(const void* run_options_ptr, - int32_t* out, int32_t* lhs, - int32_t* rhs, int64_t m, - int64_t n, int64_t k, - int32_t transpose_lhs, - int32_t transpose_rhs) { - SingleThreadedMatMulDispatch(run_options_ptr, out, lhs, rhs, m, n, k, - transpose_lhs, transpose_rhs); -} diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h index 9473eb7f56fc52..1ac85a4f125404 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_MATMUL_H_ #include +#include #include "third_party/eigen3/Eigen/Core" diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul_c128.cc b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul_c128.cc new file mode 100644 index 00000000000000..81199c14daf7f8 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul_c128.cc @@ -0,0 +1,31 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" + +#include +#include + +#include "absl/base/attributes.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul_common.h" + +ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void +__xla_cpu_runtime_EigenSingleThreadedMatMulC128( + const void* run_options_ptr, std::complex* out, + std::complex* lhs, std::complex* rhs, int64_t m, int64_t n, + int64_t k, int32_t transpose_lhs, int32_t transpose_rhs) { + xla::SingleThreadedMatMulDispatch>( + run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); +} diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul_c64.cc b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul_c64.cc new file mode 100644 index 00000000000000..6a176435912403 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul_c64.cc @@ -0,0 +1,31 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" + +#include +#include + +#include "absl/base/attributes.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul_common.h" + +ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void +__xla_cpu_runtime_EigenSingleThreadedMatMulC64( + const void* run_options_ptr, std::complex* out, + std::complex* lhs, std::complex* rhs, int64_t m, int64_t n, + int64_t k, int32_t transpose_lhs, int32_t transpose_rhs) { + xla::SingleThreadedMatMulDispatch>( + run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); +} diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul_common.h b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul_common.h new file mode 100644 index 00000000000000..d91d8f5258c71e --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul_common.h @@ -0,0 +1,88 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_MATMUL_COMMON_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_MATMUL_COMMON_H_ + +#include + +#include "absl/base/attributes.h" +#include "third_party/eigen3/Eigen/Core" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL) +#include "tensorflow/tsl/framework/contraction/eigen_contraction_kernel.h" +#endif + +namespace xla { + +static inline bool Is16BytesAligned(void* ptr) { + return reinterpret_cast(ptr) % 16 == 0; +} + +template +void SingleThreadedMatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, + int64_t m, int64_t n, int64_t k, + int32_t transpose_lhs, int32_t transpose_rhs) { + int64_t lhs_rows = m; + int64_t lhs_cols = k; + if (transpose_lhs) { + std::swap(lhs_rows, lhs_cols); + } + + int64_t rhs_rows = k; + int64_t rhs_cols = n; + if (transpose_rhs) { + std::swap(rhs_rows, rhs_cols); + } + + const Eigen::TensorMap, Alignment> A(lhs, lhs_rows, + lhs_cols); + const Eigen::TensorMap, Alignment> B(rhs, rhs_rows, + rhs_cols); + Eigen::TensorMap, Alignment> C(out, m, n); + + typedef typename Eigen::Tensor::DimensionPair DimPair; + int lhs_contract_dim = transpose_lhs ? 0 : 1; + int rhs_contract_dim = transpose_rhs ? 1 : 0; + const Eigen::array dims( + {DimPair(lhs_contract_dim, rhs_contract_dim)}); + + // Matrix multiply is a special case of the "contract" operation where + // the contraction is performed along dimension 1 of the lhs and dimension + // 0 of the rhs. + C = A.contract(B, dims); +} + +template +void SingleThreadedMatMulDispatch(const void* run_options_ptr, T* out, T* lhs, + T* rhs, int64_t m, int64_t n, int64_t k, + int32_t transpose_lhs, + int32_t transpose_rhs) { + bool all_buffers_16b_aligned = + Is16BytesAligned(out) && Is16BytesAligned(lhs) && Is16BytesAligned(rhs); + + if (!all_buffers_16b_aligned) { + SingleThreadedMatMul( + run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); + } + + SingleThreadedMatMul(run_options_ptr, out, lhs, rhs, m, + n, k, transpose_lhs, transpose_rhs); +} + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_MATMUL_COMMON_H_ diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul_f16.cc b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul_f16.cc new file mode 100644 index 00000000000000..76a5b93af75e5e --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul_f16.cc @@ -0,0 +1,31 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" + +#include + +#include "absl/base/attributes.h" +#include "third_party/eigen3/Eigen/Core" +#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul_common.h" + +ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void +__xla_cpu_runtime_EigenSingleThreadedMatMulF16( + const void* run_options_ptr, Eigen::half* out, Eigen::half* lhs, + Eigen::half* rhs, int64_t m, int64_t n, int64_t k, int32_t transpose_lhs, + int32_t transpose_rhs) { + xla::SingleThreadedMatMulDispatch( + run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); +} diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul_f32.cc b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul_f32.cc new file mode 100644 index 00000000000000..6f3271180e9b2a --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul_f32.cc @@ -0,0 +1,31 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" + +#include + +#include "absl/base/attributes.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul_common.h" + +ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void +__xla_cpu_runtime_EigenSingleThreadedMatMulF32(const void* run_options_ptr, + float* out, float* lhs, + float* rhs, int64_t m, int64_t n, + int64_t k, int32_t transpose_lhs, + int32_t transpose_rhs) { + xla::SingleThreadedMatMulDispatch(run_options_ptr, out, lhs, rhs, m, n, + k, transpose_lhs, transpose_rhs); +} diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul_f64.cc b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul_f64.cc new file mode 100644 index 00000000000000..15191c7f151dc8 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul_f64.cc @@ -0,0 +1,32 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" + +#include + +#include "absl/base/attributes.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul_common.h" + +ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void +__xla_cpu_runtime_EigenSingleThreadedMatMulF64(const void* run_options_ptr, + double* out, double* lhs, + double* rhs, int64_t m, + int64_t n, int64_t k, + int32_t transpose_lhs, + int32_t transpose_rhs) { + xla::SingleThreadedMatMulDispatch(run_options_ptr, out, lhs, rhs, m, + n, k, transpose_lhs, transpose_rhs); +} diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul_s32.cc b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul_s32.cc new file mode 100644 index 00000000000000..cf854e5c8f3527 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul_s32.cc @@ -0,0 +1,32 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" + +#include + +#include "absl/base/attributes.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul_common.h" + +ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void +__xla_cpu_runtime_EigenSingleThreadedMatMulS32(const void* run_options_ptr, + int32_t* out, int32_t* lhs, + int32_t* rhs, int64_t m, + int64_t n, int64_t k, + int32_t transpose_lhs, + int32_t transpose_rhs) { + xla::SingleThreadedMatMulDispatch( + run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); +} From 30da06db808dca408d91f2888cdb3748be6c9ed0 Mon Sep 17 00:00:00 2001 From: Justin Szaday Date: Wed, 12 Jul 2023 13:11:44 -0700 Subject: [PATCH 215/376] Handle multiple sets of inferred resource indices and layouts. PiperOrigin-RevId: 547580913 --- .../mlir/dtensor_multi_device_expansion.cc | 41 +++++++++++-------- .../mlir/tests/multi_device_expansion.mlir | 21 ++++++++++ 2 files changed, 44 insertions(+), 18 deletions(-) diff --git a/tensorflow/dtensor/mlir/dtensor_multi_device_expansion.cc b/tensorflow/dtensor/mlir/dtensor_multi_device_expansion.cc index d0793e348cf7f6..1080ca93f6c434 100644 --- a/tensorflow/dtensor/mlir/dtensor_multi_device_expansion.cc +++ b/tensorflow/dtensor/mlir/dtensor_multi_device_expansion.cc @@ -193,7 +193,7 @@ StatusOr> FindResourceLayout(mlir::BlockArgument arg) { } for (auto [i, index] : llvm::enumerate(resource_indices)) { - uint64_t index_value = index.getZExtValue(); + int64_t index_value = index.getSExtValue(); if (index_value == arg_num) { return (resource_layouts.value())->at(i); } @@ -433,7 +433,8 @@ StatusOr> GetExpandedArguments( mesh = layout->mesh(); } else { return absl::InvalidArgumentError( - "Could not find resource layout!"); + absl::StrCat("Could not find resource layout for %arg", + arg.getArgNumber(), "!")); } } } @@ -494,26 +495,29 @@ struct InferredResourceAttributes { template mlir::LogicalResult GetInferredResourceAttributes( - const Operations& call_ops, + mlir::OpBuilder& builder, const Operations& call_ops, std::optional* resource_attrs) { - for (auto call_op : call_ops) { - // Set the resource layouts. - mlir::Attribute resource_layouts_attr = - call_op->getAttr(kNewResourceArgLayouts); - mlir::Attribute resource_indices_attr = - call_op->getAttr(kNewResourceLayoutIndices); + llvm::SmallVector resource_layouts; + llvm::SmallVector resource_indices; + for (mlir::Operation* call_op : call_ops) { + const auto resource_layouts_attr = + call_op->getAttrOfType(kNewResourceArgLayouts); + const auto resource_indices_attr = + call_op->getAttrOfType( + kNewResourceLayoutIndices); if (resource_indices_attr && resource_layouts_attr) { - if (resource_attrs->has_value()) { - // TODO(twelve): Determine how to merge inferred resource attrs if there - // are multiple sets of them. (when can that happen?) - call_op.emitOpError() - << "Multiple sets of inferred resource attributes!"; - return mlir::failure(); - } else { - resource_attrs->emplace(resource_layouts_attr, resource_indices_attr); + for (auto [index, layout] : + llvm::zip(resource_indices_attr, resource_layouts_attr)) { + // Build up the lists of resource indices and layouts. + resource_indices.emplace_back(index.getSExtValue()); + resource_layouts.emplace_back(layout); } } } + if (!resource_layouts.empty()) { + resource_attrs->emplace(builder.getArrayAttr(resource_layouts), + builder.getI32VectorAttr(resource_indices)); + } return mlir::success(); } @@ -565,7 +569,8 @@ mlir::LogicalResult BuildOuterMainFunc( expanded_call_op->setAttr(kNumLocalOutputsAttr, num_local_outputs_attr); std::optional resource_attrs; - if (failed(GetInferredResourceAttributes(call_ops, &resource_attrs))) { + if (failed( + GetInferredResourceAttributes(builder, call_ops, &resource_attrs))) { return mlir::failure(); } diff --git a/tensorflow/dtensor/mlir/tests/multi_device_expansion.mlir b/tensorflow/dtensor/mlir/tests/multi_device_expansion.mlir index a2f33ae33bc654..4de4f74ff094f0 100644 --- a/tensorflow/dtensor/mlir/tests/multi_device_expansion.mlir +++ b/tensorflow/dtensor/mlir/tests/multi_device_expansion.mlir @@ -162,3 +162,24 @@ module @test_tpu_with_inputs attributes {dtensor.enable_multi_device_mode = true return %arg0 : tensor<4xf32> } } + +// ----- + +// CHECK-LABEL: module @test_inferred_resource_attributes +// CHECK-LABEL: func.func @main +// CHECK: "tf.StatefulPartitionedCall" +// CHECK-SAME: _inferred_resource_indices = dense<[1, 2]> +// CHECK-SAME: _inferred_resource_layouts = ["sharding_specs:x,unsharded +// CHECK-SAME , "sharding_specs:unsharded,y + +module @test_inferred_resource_attributes attributes {dtensor.all_reduce_combiner.num_ops_in_group = 0 : i64, dtensor.all_reduce_combiner.topological_distance = 0 : i64, dtensor.eager_operation_name = "AssignVariableOp", dtensor.enable_multi_device_mode = true, tf._default_mesh = "|x=2,y=1|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1", tf.devices = {"/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:CPU:1"}, tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 1555 : i32}} { + func.func @main(%arg0: tensor {tf._global_shape = #tf_type.shape<>}, %arg1: tensor>> {tf._assigned_resource_local_shape = #tf_type.shape<>, tf._global_shape = #tf_type.shape<>, tf._layout = "empty_layout", tf._mesh = "empty_mesh"}, %arg2: tensor>> {tf._assigned_resource_local_shape = #tf_type.shape<>, tf._global_shape = #tf_type.shape<>, tf._layout = "empty_layout", tf._mesh = "empty_mesh"}) attributes {allow_soft_placement = false, tf.entry_function = {control_outputs = "eager_operation", inputs = "device_id,op_input_0,op_input_1", outputs = ""}} { + "tf.StatefulPartitionedCall"(%arg0, %arg1) {_inferred_resource_indices = dense<1> : vector<1xi32>, _inferred_resource_layouts = ["sharding_specs:x,unsharded, mesh:|x=2,y=1|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1"], _layout = [], _mesh = "|x=2,y=1|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1", config = "|x=2,y=1|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1", config_proto = "", executor_type = "", f = @_func} : (tensor, tensor>>) -> () + "tf.StatefulPartitionedCall"(%arg0, %arg2) {_inferred_resource_indices = dense<2> : vector<1xi32>, _inferred_resource_layouts = ["sharding_specs:unsharded,y, mesh:|x=1,y=2|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1"], _layout = [], _mesh = "|x=1,y=2|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1", config = "|x=1,y=2|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1", config_proto = "", executor_type = "", f = @_func} : (tensor, tensor>>) -> () + return + } + func.func private @_func(%arg0: tensor, %arg1: tensor>>) { + "tf.AssignVariableOp"(%arg1, %arg0) {_global_shape = [], _layout = [], device = "", validate_shape = false} : (tensor>>, tensor) -> () + return + } +} From def8728a883dc400cc138820017293512e7ad243 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 12 Jul 2023 13:16:13 -0700 Subject: [PATCH 216/376] Remove autoclustering in TFRT. PiperOrigin-RevId: 547582036 --- .../compiler/mlir/tfrt/jit/transforms/BUILD | 1 - .../transforms/tf_jitrt_clustering_pass.cc | 171 ---------------- .../tfrt/jit/transforms/tf_jitrt_passes.h | 8 - .../tfrt/jit/transforms/tf_jitrt_passes.td | 21 -- .../tests/jit/tf_jitrt_clustering_oplist.mlir | 23 --- .../jit/tf_jitrt_clustering_oplist_all.mlir | 54 ----- .../jit/tf_jitrt_clustering_oplist_tier1.mlir | 54 ----- .../tfrt/tests/tf_to_corert/auto-fusion.mlir | 67 ------ .../tf_to_corert_pipeline_cpurt.mlir | 190 ------------------ .../mlir/tfrt/transforms/tfrt_jitrt_passes.cc | 7 - .../tfrt/transforms/tfrt_pipeline_options.h | 20 -- .../mlir/tfrt/translate/import_model.cc | 3 - .../tfrt/translate/tfrt_compile_options.cc | 4 - .../tfrt/translate/tfrt_compile_options.h | 14 -- tensorflow/compiler/xla/mlir/runtime/BUILD | 1 - 15 files changed, 638 deletions(-) delete mode 100644 tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_clustering_pass.cc delete mode 100644 tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_clustering_oplist.mlir delete mode 100644 tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_clustering_oplist_all.mlir delete mode 100644 tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_clustering_oplist_tier1.mlir delete mode 100644 tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/auto-fusion.mlir delete mode 100644 tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/tf_to_corert_pipeline_cpurt.mlir diff --git a/tensorflow/compiler/mlir/tfrt/jit/transforms/BUILD b/tensorflow/compiler/mlir/tfrt/jit/transforms/BUILD index cffa4511b88d13..edbdcb36a6b765 100644 --- a/tensorflow/compiler/mlir/tfrt/jit/transforms/BUILD +++ b/tensorflow/compiler/mlir/tfrt/jit/transforms/BUILD @@ -46,7 +46,6 @@ cc_library( name = "tf_jitrt_passes", srcs = [ "tf_jitrt_buffer_forwarding.cc", - "tf_jitrt_clustering_pass.cc", "tf_jitrt_copy_removal.cc", "tf_jitrt_fission.cc", "tf_jitrt_fusion.cc", diff --git a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_clustering_pass.cc b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_clustering_pass.cc deleted file mode 100644 index c1416bb5c45c8c..00000000000000 --- a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_clustering_pass.cc +++ /dev/null @@ -1,171 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include - -#include "llvm/ADT/STLExtras.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h" -#include "tensorflow/compiler/mlir/tensorflow/transforms/cluster_ops_by_policy.h" -#include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_clustering.h" -#include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h" - -namespace tensorflow { -namespace { - -#define GEN_PASS_DEF_CLUSTERING -#include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h.inc" - -using llvm::ArrayRef; - -using mlir::TF::ConstOp; -using mlir::TF::HashTableV2Op; -using mlir::TF::ReadVariableOp; - -using mlir::TFDevice::Cluster; -using mlir::TFDevice::ClusteringPolicySet; -using mlir::TFDevice::CreateClusterOp; -using mlir::TFDevice::FindClustersInTheBlock; - -// -------------------------------------------------------------------------- // -// Cluster operations based on the TF JitRt clustering policy. -// -------------------------------------------------------------------------- // -struct ClusteringPass : public impl::ClusteringBase { - ClusteringPass() = default; - ClusteringPass(ArrayRef cluster_oplist, int cluster_min_size) { - oplist = cluster_oplist; - min_cluster_size = cluster_min_size; - } - - void runOnOperation() override { - ClusteringPolicySet policies; - - // Parse clustering tier and operations filter from the oplist. - llvm::DenseSet opset; - std::optional tier; - - for (const auto& op : oplist) { - if (op == "tier0") { - tier = JitRtClusteringTier::kTier0; - } else if (op == "tier1") { - tier = JitRtClusteringTier::kTier1; - } else if (op == "tier1metadata") { - tier = JitRtClusteringTier::kTier1Metadata; - } else if (op == "tier1reductions") { - tier = JitRtClusteringTier::kTier1Reductions; - } else if (op == "all") { - tier = JitRtClusteringTier::kAll; - } else { - opset.insert(op); - } - } - - // Run clustering only if the clustering tier or supported operations are - // explicitly defined by the oplist. - if (!tier.has_value() && opset.empty()) return; - - // If the clustering tier is not defined, it means that the opset will later - // filter supported operations, so it's ok to use `all` tier. - populateTfJitRtClusteringPolicies(policies, - tier.value_or(JitRtClusteringTier::kAll)); - - // If opset is not empty restrict operations that are enabled for - // clustering. - auto opset_filter = [&](mlir::Operation* op) -> bool { - return opset.empty() || opset.contains(op->getName().getStringRef()); - }; - - // Find operations that could be hoisted from the function body into the - // TFRT resource initialization function. Currently it is an approximation - // of hoisting rules in the TFRT, we just find all the operations that - // depend only on ConstOp, ReadVariableOp or HashTableV2Op operations. We - // don't do any side effects analysis and conservatively can mark as - // hoistable operations that will not be hoisted by TFRT because of side - // effect dependencies. - // - // TODO(ezhulenev): This should be shared with TFRT hoisting implementation. - - // Initialize a set of operations that we assume we will hoist. - llvm::DenseSet hoisted_ops; - getOperation().walk([&](mlir::Operation* op) { - if (mlir::isa(op)) - hoisted_ops.insert(op); - }); - - // Initialize work list with users of ReadVariableOp results. - llvm::SmallVector work_list; - for (mlir::Operation* hoisted : hoisted_ops) - work_list.append(hoisted->user_begin(), hoisted->user_end()); - - // Traverse all users until we find all operations that could be hoisted. - while (!work_list.empty()) { - mlir::Operation* op = work_list.pop_back_val(); - - // Skip operations that are already in the hoisted set. - if (hoisted_ops.contains(op)) continue; - - // Add operation to hoisted ops set if all operands can be hoisted. - bool all_operands_hoisted = - llvm::all_of(op->getOperands(), [&](mlir::Value value) { - return hoisted_ops.contains(value.getDefiningOp()); - }); - if (!all_operands_hoisted) continue; - - hoisted_ops.insert(op); - work_list.append(op->user_begin(), op->user_end()); - } - - auto hoist_filter = [&](mlir::Operation* op) { - return !hoisted_ops.contains(op); - }; - - // Combine together opset and hoist filters. - auto filter = [&](mlir::Operation* op) -> bool { - return opset_filter(op) && hoist_filter(op); - }; - - // Annotate all formed clusters with an attribute. - auto policy = mlir::StringAttr::get(&getContext(), "tfrt.auto-fusion"); - - getOperation().walk([&](mlir::Block* block) { - for (Cluster& cluster : FindClustersInTheBlock(block, policies, filter)) { - // Do not create too small clusters. - if (cluster.operations.size() < min_cluster_size) continue; - // Verify that JIT runtime can compile the cluster. - if (failed(VerifyCluster(cluster))) continue; - - CreateClusterOp(cluster, policy); - } - }); - } -}; - -} // namespace - -std::unique_ptr> -CreateTfJitRtClusteringPass() { - return std::make_unique(); -} - -std::unique_ptr> -CreateTfJitRtClusteringPass(llvm::ArrayRef oplist, - int min_cluster_size) { - return std::make_unique(oplist, min_cluster_size); -} - -} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h index ae7a7b8da17f06..b50dad8498b90f 100644 --- a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h +++ b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h @@ -57,14 +57,6 @@ std::unique_ptr> CreateFissionPass(); // Pass to fuse Linalg generic operations on Tensors. std::unique_ptr> CreateFusionPass(); -// Creates `tf_device.cluster` operations according to the TF JitRt clustering -// policy. -std::unique_ptr> -CreateTfJitRtClusteringPass(); -std::unique_ptr> -CreateTfJitRtClusteringPass(llvm::ArrayRef oplist, - int min_cluster_size); - // Pass to replace math ops with approximations. std::unique_ptr> CreateMathApproximationPass(llvm::ArrayRef oplist = {}); diff --git a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.td b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.td index 5e2578bfb50a4d..dbf765dec9bc31 100644 --- a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.td +++ b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.td @@ -85,27 +85,6 @@ def JitRtLegalizeI1Types ]; } -def Clustering : Pass<"tf-jitrt-clustering", "mlir::func::FuncOp"> { - let summary = "Creates `tf_device.cluster` operations according to the TF " - "JitRt clustering policy"; - - let constructor = "tensorflow::CreateTfJitRtClusteringPass()"; - - let dependentDialects = ["mlir::tf_device::TensorFlowDeviceDialect"]; - - let options = [ - Option<"min_cluster_size", "min-cluster-size", "int" , /*default=*/"1", - "Do not form clusters smaller of the given size.">, - // TODO(ezhulenev): This is a temporary workaround to control TF->JitRt - // clustering policy at runtime. - ListOption<"oplist", "oplist", "std::string", - "Explicitly allow operations for clustering. Only operations in " - "this list will be passed to the TF->JitRt clustering policy. " - "Alternatively use 'tier1', ..., 'all' to allow clustering for " - "all operations included in the given clustering tier."> - ]; -} - def MathApproximation : Pass<"tf-jitrt-math-approximation", "mlir::func::FuncOp"> { let summary = "Approximate math operations with an implementation meant to " "match Eigen's results. This is a useful property to have when " diff --git a/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_clustering_oplist.mlir b/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_clustering_oplist.mlir deleted file mode 100644 index cde0cef4e38ed7..00000000000000 --- a/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_clustering_oplist.mlir +++ /dev/null @@ -1,23 +0,0 @@ -// RUN: tf-tfrt-opt %s \ -// RUN: -tf-jitrt-clustering="oplist=tf.Add,tf.Sub,tf.Neg min-cluster-size=2"\ -// RUN: | FileCheck %s - -// CHECK-LABEL: func @single_cluster_one_result -func.func @single_cluster_one_result(%arg0 : tensor, %arg1 : tensor) - -> tensor { - // CHECK: %[[CLUSTER:.*]] = "tf_device.cluster"() - // CHECK: "tf.Add" - // CHECK: "tf.Neg" - // CHECK: "tf.Sub" - // CHECK: "tf.Neg" - // CHECK: %[[RET:.*]] = "tf.Add" - // CHECK: tf_device.return %[[RET]] - %0 = "tf.Add"(%arg0, %arg1) : (tensor, tensor) -> tensor - %1 = "tf.Neg"(%0) : (tensor) -> tensor - %2 = "tf.Sub"(%arg0, %arg1) : (tensor, tensor) -> tensor - %3 = "tf.Neg"(%2) : (tensor) -> tensor - %4 = "tf.Add"(%1, %3) : (tensor, tensor) -> tensor - // CHECK: }) {policy = "tfrt.auto-fusion"} - // CHECK: return %[[CLUSTER]] - func.return %4 : tensor -} diff --git a/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_clustering_oplist_all.mlir b/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_clustering_oplist_all.mlir deleted file mode 100644 index d6f6ed6da92768..00000000000000 --- a/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_clustering_oplist_all.mlir +++ /dev/null @@ -1,54 +0,0 @@ -// RUN: tf-tfrt-opt %s \ -// RUN: -tf-jitrt-clustering="oplist=all min-cluster-size=2" \ -// RUN: | FileCheck %s - -// CHECK-LABEL: func @single_cluster_one_result -func.func @single_cluster_one_result(%arg0 : tensor, %arg1 : tensor) - -> tensor { - // CHECK: %[[CLUSTER:.*]] = "tf_device.cluster"() - // CHECK: "tf.Add" - // CHECK: "tf.Neg" - // CHECK: "tf.Sub" - // CHECK: "tf.Neg" - // CHECK: %[[RET:.*]] = "tf.Add" - // CHECK: tf_device.return %[[RET]] - %0 = "tf.Add"(%arg0, %arg1) : (tensor, tensor) -> tensor - %1 = "tf.Neg"(%0) : (tensor) -> tensor - %2 = "tf.Sub"(%arg0, %arg1) : (tensor, tensor) -> tensor - %3 = "tf.Neg"(%2) : (tensor) -> tensor - %4 = "tf.Add"(%1, %3) : (tensor, tensor) -> tensor - // CHECK: }) {policy = "tfrt.auto-fusion"} - // CHECK: return %[[CLUSTER]] - func.return %4 : tensor -} - -// CHECK-LABEL: func @do_not_cluster_hoistable_ops -func.func @do_not_cluster_hoistable_ops( - %arg0 : tensor, - %arg1 : tensor<*x!tf_type.resource>, - %arg2 : tensor<*x!tf_type.resource> - ) -> tensor { - // CHECK: "tf.Const" - // CHECK: "tf.ReadVariableOp" - // CHECK: "tf.ReadVariableOp" - // CHECK: "tf.Add" - // CHECK: "tf.Neg" - // CHECK: "tf.Sub" - %c = "tf.Const"() { value = dense<1> : tensor } : () -> tensor - %x = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf_type.resource>) -> tensor - %y = "tf.ReadVariableOp"(%arg2) : (tensor<*x!tf_type.resource>) -> tensor - %0 = "tf.Add"(%x, %y) : (tensor, tensor) -> tensor - %1 = "tf.Neg"(%0) : (tensor) -> tensor - %2 = "tf.Sub"(%0, %c) : (tensor, tensor) -> tensor - // CHECK: %[[CLUSTER:.*]] = "tf_device.cluster"() - // CHECK: "tf.Sub" - // CHECK: "tf.Neg" - // CHECK: %[[RET:.*]] = "tf.Add" - // CHECK: tf_device.return %[[RET]] - %3 = "tf.Sub"(%arg0, %2) : (tensor, tensor) -> tensor - %4 = "tf.Neg"(%3) : (tensor) -> tensor - %5 = "tf.Add"(%2, %4) : (tensor, tensor) -> tensor - // CHECK: }) {policy = "tfrt.auto-fusion"} - // CHECK: return %[[CLUSTER]] - func.return %5 : tensor -} diff --git a/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_clustering_oplist_tier1.mlir b/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_clustering_oplist_tier1.mlir deleted file mode 100644 index 026828ae368718..00000000000000 --- a/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_clustering_oplist_tier1.mlir +++ /dev/null @@ -1,54 +0,0 @@ -// RUN: tf-tfrt-opt %s \ -// RUN: -tf-jitrt-clustering="oplist=tier1 min-cluster-size=2" \ -// RUN: | FileCheck %s --check-prefix CHECK --check-prefix=TIER1 -// RUN: tf-tfrt-opt %s \ -// RUN: -tf-jitrt-clustering="oplist=tier1metadata min-cluster-size=2" \ -// RUN: | FileCheck %s --check-prefix CHECK --check-prefix=METADATA -// RUN: tf-tfrt-opt %s \ -// RUN: -tf-jitrt-clustering="oplist=tier1reductions min-cluster-size=2" \ -// RUN: | FileCheck %s --check-prefix CHECK --check-prefix=REDUCTIONS - -// CHECK-LABEL: func @single_cluster_one_result -func.func @single_cluster_one_result(%arg0 : tensor, %arg1 : tensor) - -> tensor { - // CHECK: %[[CLUSTER:.*]] = "tf_device.cluster"() - // TIER1-NOT: "tf.Sum" - // TIER1: "tf.Add" - // TIER1: "tf.Neg" - // TIER1: "tf.Sub" - // TIER1: "tf.Neg" - // TIER1: %[[RET:.*]] = "tf.Add" - // TIER1: tf_device.return %[[RET]] - - // METADATA-NOT: "tf.Sum" - // METADATA: "tf.Add" - // METADATA: "tf.Neg" - // METADATA: "tf.Sub" - // METADATA: "tf.Neg" - // METADATA: "tf.Add" - // METADATA: %[[RET:.*]] = "tf.Shape" - // METADATA: tf_device.return %[[RET]] - - // REDUCTIONS: "tf.Sum" - // REDUCTIONS: "tf.Add" - // REDUCTIONS: "tf.Neg" - // REDUCTIONS: "tf.Sub" - // REDUCTIONS: "tf.Neg" - // REDUCTIONS: %[[RET:.*]] = "tf.Add" - // REDUCTIONS: tf_device.return %[[RET]] - %dimension = "tf.Const"() { value = dense<0> : tensor<1xi64> } : () -> tensor<1xi64> - %s = "tf.Sum"(%arg0, %dimension) { keep_dims = false }: (tensor, tensor<1xi64>) -> tensor - %0 = "tf.Add"(%s, %arg1) : (tensor, tensor) -> tensor - %1 = "tf.Neg"(%0) : (tensor) -> tensor - %2 = "tf.Sub"(%s, %arg1) : (tensor, tensor) -> tensor - %3 = "tf.Neg"(%2) : (tensor) -> tensor - %4 = "tf.Add"(%1, %3) : (tensor, tensor) -> tensor - %5 = "tf.Shape"(%4) : (tensor) -> tensor - // CHECK: }) {policy = "tfrt.auto-fusion"} - // TIER1: %[[SHAPE:.*]] = "tf.Shape"(%[[CLUSTER]]) - // TIER1: return %[[SHAPE]] - - // REDUCTIONS: %[[SHAPE:.*]] = "tf.Shape"(%[[CLUSTER]]) - // REDUCTIONS: return %[[SHAPE]] - func.return %5 : tensor -} diff --git a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/auto-fusion.mlir b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/auto-fusion.mlir deleted file mode 100644 index b6c63f3f560bcf..00000000000000 --- a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/auto-fusion.mlir +++ /dev/null @@ -1,67 +0,0 @@ -// RUN: tf-tfrt-opt -tf-executor-to-tfrt-pipeline="auto-fusion-oplist=tf.Rsqrt,tf.Tanh auto-fusion-min-cluster-size=1" -split-input-file %s \ -// RUN: | FileCheck %s --dump-input=always - -// CHECK-LABEL: func @single_op_cluster -// CHECK: %[[ARG0:.*]]: !tfrt.chain -// CHECK: %[[ARG1:.*]]: !corert.tensorhandle -func.func @single_op_cluster(%arg0: tensor) -> tensor { - // CHECK: %[[ARG:.*]] = tfrt_fallback_async.corert_tensorhandle_to_fallback_tensor - // CHECK-SAME: %[[ARG1]] - // CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:CPU:0" - // CHECK: %[[RES:.*]] = tf_jitrt.fallback.execute @kernel::@compute(%[[ARG]]) - // CHECK: %[[OUT:.*]] = tfrt_fallback_async.fallback_tensor_to_corert_tensorhandle - // CHECK-SAME: %[[RES]] - // CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:CPU:0" - // CHECK: tfrt.return %[[ARG0]], %[[OUT]] : !tfrt.chain, !corert.tensorhandle - %0 = "tf.Rsqrt"(%arg0) {T = f32, device="/device:CPU:0"} : (tensor) -> tensor - func.return %0 : tensor -} - -// CHECK: module @kernel attributes { -// CHECK-SAME: tfrt.compiled -// CHECK-SAME: "tfrt.max-arg-size" = 1 : i64 -// CHECK-SAME: } -// CHECK-LABEL: func @compute -// CHECK-SAME: %[[ARG0:.*]]: tensor -// CHECK: %[[RES:.*]] = "tf.Rsqrt"(%[[ARG0]]) -// CHECK: return %[[RES]] - -// ----- - -// CHECK-LABEL: func @one_compiled_cluster -func.func @one_compiled_cluster(%arg0: tensor) -> tensor { - // CHECK: %[[RES:.*]] = tf_jitrt.fallback.execute @kernel::@compute - // CHECK-NOT: Rsqrt - // CHECK-NOT: Tanh - %0 = "tf.Rsqrt"(%arg0) {T = f32, device="/device:CPU:0"} : (tensor) -> tensor - %1 = "tf.Tanh"(%0) {T = f32, device="/device:CPU:0"} : (tensor) -> tensor - func.return %1 : tensor -} - -// CHECK: module @kernel attributes { -// CHECK-SAME: tfrt.compiled -// CHECK-SAME: "tfrt.max-arg-size" = 1 : i64 -// CHECK-SAME: } -// CHECK-LABEL: func @compute -// CHECK-SAME: %[[ARG0:.*]]: tensor -// CHECK: %[[RES0:.*]] = "tf.Rsqrt"(%[[ARG0]]) -// CHECK: %[[RES1:.*]] = "tf.Tanh"(%[[RES0]]) -// CHECK: return %[[RES1]] - -// ----- - -// CHECK-LABEL: func @two_compiled_clusters -func.func @two_compiled_clusters(%arg0: tensor) -> tensor { - // CHECK: tf_jitrt.fallback.execute @kernel::@compute - %0 = "tf.Rsqrt"(%arg0) {T = f32, device="/device:CPU:0"} : (tensor) -> tensor - // CHECK: tfrt_fallback_async.executeop {{.*}} "tf.Sqrt" - %1 = "tf.Sqrt"(%0) {T = f32, device="/device:CPU:0"} : (tensor) -> tensor - // CHECK: tf_jitrt.fallback.execute @kernel_0::@compute - %2 = "tf.Tanh"(%1) {T = f32, device="/device:CPU:0"} : (tensor) -> tensor - func.return %2 : tensor -} - -// CHECK: module @kernel -// CHECK: tf.Rsqrt -// CHECK: module @kernel_0 -// CHECK: tf.Tanh diff --git a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/tf_to_corert_pipeline_cpurt.mlir b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/tf_to_corert_pipeline_cpurt.mlir deleted file mode 100644 index 43267d6ec2264e..00000000000000 --- a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/tf_to_corert_pipeline_cpurt.mlir +++ /dev/null @@ -1,190 +0,0 @@ -// RUN: tf-tfrt-opt %s \ -// RUN: -split-input-file \ -// RUN: -tf-executor-to-tfrt-pipeline=" \ -// RUN: enable-optimizer=true \ -// RUN: tfrt-cost-threshold=1024 \ -// RUN: auto-fusion-oplist=tf.Relu,tf.Transpose,tf.Const \ -// RUN: auto-fusion-min-cluster-size=1" \ -// RUN: | FileCheck %s --dump-input=always - -// Check TF->JitRT JIT compiled operations clustering and outlining starting -// from the Tensorflow executor dialect. - -// ----- -// Simple cluster consisting of a single operation. - -module attributes {tf.versions = {producer = 462 : i32}} { - // CHECK: func @_tfrt_fallback_init(%[[ARG:.*]]: !tfrt.chain) - // CHECK: %[[COMPILED:.*]] = tf_jitrt.fallback.compile @kernel::@compute - // CHECK: %[[CHAIN:.*]] = tfrt.merge.chains %[[ARG]], %[[COMPILED]] - // CHECK: tfrt.return %[[CHAIN]] - - // CHECK: func @call - func.func @call(%arg0: tensor) -> (tensor) - attributes { tf.entry_function = {control_outputs = "", - inputs = "input_0", - outputs = "output_0"}} { - // CHECK: tf_jitrt.fallback.execute @kernel::@compute - %0 = tf_executor.graph { - %outs, %control = tf_executor.island wraps "tf.Relu"(%arg0) - {device = ""} : (tensor) -> tensor - tf_executor.fetch %outs: tensor - } - func.return %0 : tensor - } -} - -// CHECK: module @kernel attributes { -// CHECK-SAME: tfrt.compiled -// CHECK-SAME: "tfrt.max-arg-size" = 1 : i64 -// CHECK-SAME: } -// CHECK: func @compute( -// CHECK-SAME: %[[ARG0:.*]]: tensor -// CHECK-SAME: ) -> tensor { -// CHECK: %[[RELU:.*]] = "tf.Relu"(%[[ARG0]]) -// CHECK: return %[[RELU]] -// CHECK: } - -// ----- -// Two identical clusters (except the _class attribute) consisting of a single -// `Relu` operation. Check that outlined clusters are deduplicated and we -// compile only once. - -module attributes {tf.versions = {producer = 462 : i32}} { - // CHECK: func @_tfrt_fallback_init - // CHECK: tf_jitrt.fallback.compile @kernel::@compute - // CHECK-NOT: tf_jitrt.fallback.compile - - // CHECK: func @call - func.func @call(%arg0: tensor) -> (tensor) - attributes { tf.entry_function = {control_outputs = "", - inputs = "input_0", - outputs = "output_0"}} { - // CHECK: tf_jitrt.fallback.execute @kernel::@compute - // CHECK: tfrt_fallback_async.executeop {{.*}} "tf.Sqrt" - // CHECK: tf_jitrt.fallback.execute @kernel::@compute - %0 = tf_executor.graph { - %outs0, %control0 = tf_executor.island wraps "tf.Relu"(%arg0) - {device = "", _class = ["loc:@Relu_0"]} - : (tensor) -> tensor - %outs1, %control1 = tf_executor.island wraps "tf.Sqrt"(%outs0) - {device = ""} : (tensor) -> tensor - %outs2, %control2 = tf_executor.island wraps "tf.Relu"(%outs1) - {device = "", _class = ["loc:@Relu_1"]} - : (tensor) -> tensor - tf_executor.fetch %outs2: tensor - } - func.return %0 : tensor - } -} - -// CHECK: module @kernel attributes { -// CHECK-SAME: tfrt.compiled -// CHECK-SAME: "tfrt.max-arg-size" = 1 : i64 -// CHECK-SAME: } -// CHECK: func @compute( -// CHECK-SAME: %[[ARG0:.*]]: tensor -// CHECK-SAME: ) -> tensor { -// CHECK: %[[RELU:.*]] = "tf.Relu"(%[[ARG0]]) -// CHECK: return %[[RELU]] -// CHECK: } - -// ----- -// Constants sunk into the outlined compiled functions. - -module attributes {tf.versions = {producer = 462 : i32}} { - // CHECK: func @_tfrt_fallback_init - // CHECK: tf_jitrt.fallback.compile @kernel::@compute - - // CHECK: func @call - func.func @call(%arg0: tensor) -> (tensor) - attributes { tf.entry_function = {control_outputs = "", - inputs = "input_0", - outputs = "output_0"}} { - // CHECK: tf_jitrt.fallback.execute @kernel::@compute - %0 = tf_executor.graph { - %perm, %perm_ctl = tf_executor.island wraps "tf.Const"() - {device = "", value = dense<[1, 0]> : tensor<2xi32>} - : () -> tensor<2xi32> - %out, %out_ctl = tf_executor.island wraps "tf.Transpose"(%arg0, %perm) - {device = ""} - : (tensor, tensor<2xi32>) -> tensor - tf_executor.fetch %out: tensor - } - func.return %0 : tensor - } -} - -// CHECK: module @kernel attributes { -// CHECK-SAME: tfrt.compiled -// CHECK-SAME: "tfrt.max-arg-size" = 1 : i64 -// CHECK-SAME: } -// CHECK: func @compute( -// CHECK-SAME: %[[ARG0:.*]]: tensor -// CHECK-SAME: ) -> tensor { -// CHECK: %[[PERM:.*]] = "tf.Const"() {{.*}} dense<[1, 0]> -// CHECK: %[[RET:.*]] = "tf.Transpose"(%[[ARG0]], %[[PERM]]) -// CHECK: return %[[RET]] -// CHECK: } - -// ----- -// tf.Transpose: a non-const permutation parameter cannot be sunk into the -// compiled function. Such a transpose should, however, support clustering, -// and its permutation parameter should compile to be value-constrained. - -module attributes {tf.versions = {producer = 462 : i32}} { - // CHECK: func @_tfrt_fallback_init - // CHECK: tf_jitrt.fallback.compile @kernel::@compute - - // CHECK: func @call - func.func @call(%arg0: tensor, %arg1: tensor) -> (tensor) - attributes { tf.entry_function = {control_outputs = "", - inputs = "input_0,input_1", - outputs = "output_0"}} { - // CHECK: tf_jitrt.fallback.execute @kernel::@compute - %0 = tf_executor.graph { - %out, %out_ctl = tf_executor.island wraps "tf.Transpose"(%arg0, %arg1) - {device = ""} - : (tensor, tensor) -> tensor - tf_executor.fetch %out: tensor - } - func.return %0 : tensor - } -} - -// CHECK: module @kernel attributes { -// CHECK-SAME: tfrt.compiled -// CHECK-SAME: "tfrt.max-arg-size" = 1 : i64 -// CHECK-SAME: } -// CHECK: func @compute( -// CHECK-SAME: %[[ARG0:.*]]: tensor -// CHECK-SAME: %[[ARG1:.*]]: tensor {rt.constraint = "value"} -// CHECK-SAME: ) -> tensor { -// CHECK-NEXT: %[[RET:.*]] = "tf.Transpose"(%[[ARG0]], %[[ARG1]]) -// CHECK: return %[[RET]] -// CHECK: } - -// ----- -// Operations with unsupported data type operands/results are not clustered. - -module attributes {tf.versions = {producer = 462 : i32}} { - func.func @call(%arg0: tensor) -> (tensor) - attributes { tf.entry_function = {control_outputs = "", - inputs = "input_0", - outputs = "output_0"}} { - // CHECK-NOT: tf_jitrt.fallback.compile - // CHECK-NOT: tf_jitrt.fallback.execute - // CHECK-NOT: module @kernel - %0 = tf_executor.graph { - %perm, %perm_ctl = - tf_executor.island wraps "tf.Const"() - {device = "", value = dense<[1, 0]> : tensor<2xi32>} - : () -> tensor<2xi32> - %out, %out_ctl = - tf_executor.island wraps "tf.Transpose"(%arg0, %perm) {device = ""} - : (tensor, tensor<2xi32>) -> tensor - tf_executor.fetch %out: tensor - } - func.return %0 : tensor - } -} diff --git a/tensorflow/compiler/mlir/tfrt/transforms/tfrt_jitrt_passes.cc b/tensorflow/compiler/mlir/tfrt/transforms/tfrt_jitrt_passes.cc index 91a4c1d61fd166..83fb4e2343b1ea 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/tfrt_jitrt_passes.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/tfrt_jitrt_passes.cc @@ -378,13 +378,6 @@ std::unique_ptr CreateOutlineJitRtClustersPass() { void TfrtJitRtStubImpl::AddTfrtJitRtPasses(const TfrtPipelineOptions &options, mlir::OpPassManager &pm) { - // Outline auto-fusion clusters into tf_device.cluster_operations and then - // convert them to functions. We currently support only tfrt fallback tensors - // as operands, so we disable these passes if we can have native ops after - // lowering. - pm.addNestedPass(CreateTfJitRtClusteringPass( - options.auto_fusion_oplist, options.auto_fusion_min_cluster_size)); - // Sink small constants into the outlined clusters to reduce the number of // arguments for each of the execute operations. auto is_compilable_const = [](mlir::tf_device::ClusterOp cluster, diff --git a/tensorflow/compiler/mlir/tfrt/transforms/tfrt_pipeline_options.h b/tensorflow/compiler/mlir/tfrt/transforms/tfrt_pipeline_options.h index a4c62f8bf20a2e..f85a700b9dcf04 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/tfrt_pipeline_options.h +++ b/tensorflow/compiler/mlir/tfrt/transforms/tfrt_pipeline_options.h @@ -153,26 +153,6 @@ struct TfrtPipelineOptions llvm::cl::desc("If true, streams with inter data depenedencies will be " "preferred to be merged for inline execution."), llvm::cl::init(false)}; - - // A set of flags to control auto-fusion: automatic clustering of Tensorflow - // operations and compiling outlined regions using MLIR based compilation - // stack. - // - // WARNING: These flags are experimental and are intended for manual testing - // of different auto-fusion strategies. They will be removed in the future. - - ListOption auto_fusion_oplist{ - *this, "auto-fusion-oplist", - llvm::cl::desc("A list of Tensorflow operations to cluster together for " - "JIT compilation. Alternatively use 'tier1', ..., 'all' " - "to allow clustering for all operations included in the " - "given clustering tier.")}; - - Option auto_fusion_min_cluster_size{ - *this, "auto-fusion-min-cluster-size", - llvm::cl::desc("Minimum size of the cluster that should be outlined for " - "compilation"), - llvm::cl::init(2)}; }; } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/translate/import_model.cc b/tensorflow/compiler/mlir/tfrt/translate/import_model.cc index 080439159e628e..55e923d3b6f400 100644 --- a/tensorflow/compiler/mlir/tfrt/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tfrt/translate/import_model.cc @@ -312,9 +312,6 @@ std::unique_ptr GetTfrtPipelineOptions( pipeline_options->func_use_fallback_tensor = true; pipeline_options->enable_while_parallel_iterations = options.enable_while_parallel_iterations; - pipeline_options->auto_fusion_oplist = options.auto_fusion_oplist; - pipeline_options->auto_fusion_min_cluster_size = - options.auto_fusion_min_cluster_size; pipeline_options->cost_threshold = options.cost_threshold; pipeline_options->upper_cost_threshold = options.upper_cost_threshold; pipeline_options->merge_inter_dependent_streams = diff --git a/tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.cc b/tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.cc index 1e4a81d0d0cc03..3bfb5d853a9a97 100644 --- a/tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.cc +++ b/tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.cc @@ -57,10 +57,6 @@ std::ostream& operator<<(std::ostream& os, const TfrtCompileOptions& options) { << ", hoist_invariant_ops = " << options.hoist_invariant_ops << ", enable_while_parallel_iterations = " << options.enable_while_parallel_iterations - << ", auto_fusion_oplist = [" - << absl::StrJoin(options.auto_fusion_oplist, ",") << "]" - << ", auto_fusion_min_cluster_size = " - << options.auto_fusion_min_cluster_size << ", cost_threshold = " << options.cost_threshold << ", upper_cost_threshold = " << options.upper_cost_threshold << ", merge_inter_dependent_streams = " diff --git a/tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h b/tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h index 23ef81be002a27..6a2d4a4a23d32a 100644 --- a/tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h +++ b/tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h @@ -125,20 +125,6 @@ struct TfrtCompileOptions { // basis. This is currently experimental. bool enable_while_parallel_iterations = false; - // A set of flags to control auto-fusion: automatic clustering of Tensorflow - // operations and compiling outlined regions using MLIR based compilation - // stack. - // - // WARNING: These flags are experimental and are intended for manual testing - // of different auto-fusion strategies. They will be removed in the future. - - // A list of Tensorflow operations that are supported by auto-fusion - // clustering and compilation (e.g. tf.Tanh). - std::vector auto_fusion_oplist; - - // Minimum size of the cluster to be compiled at runtime. - int auto_fusion_min_cluster_size = 2; - // The cost threshold to decide whether a sequence of operations is cheap, and // then whether it can be executed inline. If the cost is smaller than the // threshold, it will be considered as cheap operations. Since the cost must diff --git a/tensorflow/compiler/xla/mlir/runtime/BUILD b/tensorflow/compiler/xla/mlir/runtime/BUILD index e36272acd6e0cd..48d976de9f2d9d 100644 --- a/tensorflow/compiler/xla/mlir/runtime/BUILD +++ b/tensorflow/compiler/xla/mlir/runtime/BUILD @@ -7,7 +7,6 @@ package_group( packages = [ # copybara:uncomment_begin(google-only) # "//platforms/xla/service/cpu/...", - # "//learning/brain/experimental/tfrt/autofusion/...", # "//third_party/mlir_edge/tpgen/...", # # TODO(ezhulenev): Clean up dependencies that are leforvers from Autofusion project. # "@tf_runtime//...", From ab731c0409ee5ff5b1ce9a69dea68eb691ded45b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 12 Jul 2023 13:55:43 -0700 Subject: [PATCH 217/376] Add error logging for internal error statistics. Logging invocation is a no-op in open-source, no logging is performed in open-source code. PiperOrigin-RevId: 547593050 --- tensorflow/compiler/mlir/tf2xla/api/v1/BUILD | 1 + .../mlir/tf2xla/api/v1/legalize_tf.cc | 21 +++++++++++++- tensorflow/tsl/platform/BUILD | 12 ++++++++ tensorflow/tsl/platform/build_config.bzl | 2 ++ tensorflow/tsl/platform/default/BUILD | 20 +++++++++++++ .../tsl/platform/default/build_config.bzl | 3 ++ .../tsl/platform/default/error_logging.cc | 29 +++++++++++++++++++ .../tsl/platform/default/error_logging.h | 29 +++++++++++++++++++ tensorflow/tsl/platform/error_logging.h | 27 +++++++++++++++++ 9 files changed, 143 insertions(+), 1 deletion(-) create mode 100644 tensorflow/tsl/platform/default/error_logging.cc create mode 100644 tensorflow/tsl/platform/default/error_logging.h create mode 100644 tensorflow/tsl/platform/error_logging.h diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD b/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD index 3406b6d0e4739d..19f9bae6ca920f 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD @@ -45,6 +45,7 @@ cc_library( "//tensorflow/core/tpu/kernels:tpu_compile_op_support", "//tensorflow/core/tpu/kernels:tpu_compile_proto_cc", "//tensorflow/core/tpu/kernels:tpu_util_hdrs", + "//tensorflow/tsl/platform:error_logging", "//tensorflow/tsl/platform:status", "//tensorflow/tsl/platform:statusor", "@com_google_absl//absl/log", diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/legalize_tf.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/legalize_tf.cc index 59ebb5e913b3a0..cf21467ad7ff6d 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/legalize_tf.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/legalize_tf.cc @@ -48,6 +48,7 @@ limitations under the License. #include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h" #include "tensorflow/core/tpu/kernels/tpu_util.h" #include "tensorflow/core/tpu/tpu_compile.h" +#include "tensorflow/tsl/platform/error_logging.h" #include "tensorflow/tsl/platform/status.h" #include "tensorflow/tsl/platform/statusor.h" @@ -95,6 +96,9 @@ constexpr char kOldBridgeWithFallbackModeSuccess[] = // Old Bridge failed in fallback (was run because MLIR bridge failed first). constexpr char kOldBridgeWithFallbackModeFailure[] = "kOldBridgeWithFallbackModeFailure"; +// Name of component for error logging. This name is fixed and required to +// enable logging. +constexpr char kBridgeComponent[] = "TFXLABridge"; // Time the execution of kernels (in CPU cycles). Meant to be used as RAII. struct CompilationTimer { @@ -227,9 +231,19 @@ tsl::StatusOr LegalizeMlirToHlo( } else if (!enable_op_fallback) { // Don't fallback to the old bridge if op-by-op fallback isn't enabled. mlir_second_phase_count->GetCell(kMlirModeFailure)->IncrementBy(1); + if (!mlir_bridge_status.ok()) { + tsl::error_logging::Log(kBridgeComponent, + "TFXLA_API_V1_BRIDGE_NO_FALLBACK", + mlir_bridge_status.ToString()) + .IgnoreError(); + } return mlir_bridge_status; + } else { + tsl::error_logging::Log(kBridgeComponent, + "TFXLA_API_V1_BRIDGE_WITH_FALLBACK_FAIL", + mlir_bridge_status.ToString()) + .IgnoreError(); } - bool filtered_graph = false; if (mlir_bridge_status == CompileToHloGraphAnalysisFailedError()) { VLOG(1) << "Filtered out MLIR computation to XLA HLO using MLIR tf2xla " @@ -263,6 +277,11 @@ tsl::StatusOr LegalizeMlirToHlo( mlir_second_phase_count->GetCell(kOldBridgeWithFallbackModeFailure) ->IncrementBy(1); } + if (!old_bridge_status.ok()) { + tsl::error_logging::Log(kBridgeComponent, "TFXLA_API_V1_OLD_BRIDGE", + mlir_bridge_status.ToString()) + .IgnoreError(); + } return old_bridge_status; } diff --git a/tensorflow/tsl/platform/BUILD b/tensorflow/tsl/platform/BUILD index 3eb5ff9089abc6..dd5d7b3ace120c 100644 --- a/tensorflow/tsl/platform/BUILD +++ b/tensorflow/tsl/platform/BUILD @@ -14,6 +14,7 @@ load( load( "//tensorflow/tsl/platform:build_config.bzl", "tf_cuda_libdevice_path_deps", + "tf_error_logging_deps", "tf_fingerprint_deps", "tf_google_mobile_srcs_no_runtime", "tf_logging_deps", @@ -665,6 +666,7 @@ exports_files( "env.cc", "env.h", "env_time.h", + "error_logging.h", "file_system.cc", "file_system.h", "file_system_helper.cc", @@ -1096,6 +1098,16 @@ cc_library( deps = tf_logging_deps(), ) +cc_library( + name = "error_logging", + compatible_with = get_compatible_with_portable(), + textual_hdrs = ["error_logging.h"], + visibility = [ + "//visibility:public", + ], + deps = tf_error_logging_deps(), +) + cc_library( name = "prefetch", hdrs = ["prefetch.h"], diff --git a/tensorflow/tsl/platform/build_config.bzl b/tensorflow/tsl/platform/build_config.bzl index a257152eea8ebd..8515b784a585cc 100644 --- a/tensorflow/tsl/platform/build_config.bzl +++ b/tensorflow/tsl/platform/build_config.bzl @@ -14,6 +14,7 @@ load( _tf_additional_tensor_coding_deps = "tf_additional_tensor_coding_deps", _tf_additional_test_deps = "tf_additional_test_deps", _tf_cuda_libdevice_path_deps = "tf_cuda_libdevice_path_deps", + _tf_error_logging_deps = "tf_error_logging_deps", _tf_fingerprint_deps = "tf_fingerprint_deps", _tf_google_mobile_srcs_no_runtime = "tf_google_mobile_srcs_no_runtime", _tf_google_mobile_srcs_only_runtime = "tf_google_mobile_srcs_only_runtime", @@ -54,6 +55,7 @@ tf_additional_rpc_deps = _tf_additional_rpc_deps tf_additional_tensor_coding_deps = _tf_additional_tensor_coding_deps tf_additional_test_deps = _tf_additional_test_deps tf_cuda_libdevice_path_deps = _tf_cuda_libdevice_path_deps +tf_error_logging_deps = _tf_error_logging_deps tf_fingerprint_deps = _tf_fingerprint_deps tf_google_mobile_srcs_no_runtime = _tf_google_mobile_srcs_no_runtime tf_google_mobile_srcs_only_runtime = _tf_google_mobile_srcs_only_runtime diff --git a/tensorflow/tsl/platform/default/BUILD b/tensorflow/tsl/platform/default/BUILD index 84af5a2fbca3ab..d9cbafe7710aa7 100644 --- a/tensorflow/tsl/platform/default/BUILD +++ b/tensorflow/tsl/platform/default/BUILD @@ -206,6 +206,24 @@ cc_library( deps = ["//tensorflow/tsl/platform:types"], ) +cc_library( + name = "error_logging", + srcs = ["error_logging.cc"], + hdrs = ["//tensorflow/tsl/platform:error_logging.h"], + tags = [ + "manual", + "no_oss", + "nobuilder", + ], + textual_hdrs = ["error_logging.h"], + deps = [ + "//tensorflow/tsl/platform", + "@com_google_absl//absl/base", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + ], +) + cc_library( name = "human_readable_json", srcs = ["human_readable_json.cc"], @@ -608,6 +626,7 @@ exports_files( srcs = glob( ["*"], exclude = [ + "error_logging.h", "integral_types.h", "logging.h", "test.cc", @@ -618,6 +637,7 @@ exports_files( exports_files( srcs = [ + "error_logging.h", "integral_types.h", "logging.h", "test.cc", diff --git a/tensorflow/tsl/platform/default/build_config.bzl b/tensorflow/tsl/platform/default/build_config.bzl index 8c8f606dd644e3..7863928644e66e 100644 --- a/tensorflow/tsl/platform/default/build_config.bzl +++ b/tensorflow/tsl/platform/default/build_config.bzl @@ -842,6 +842,9 @@ def tf_platform_alias(name, platform_dir = "//tensorflow/tsl/platform/"): def tf_logging_deps(): return [clean_dep("//tensorflow/tsl/platform/default:logging")] +def tf_error_logging_deps(): + return [clean_dep("//tensorflow/tsl/platform/default:error_logging")] + def tf_resource_deps(): return [clean_dep("//tensorflow/tsl/platform/default:resource")] diff --git a/tensorflow/tsl/platform/default/error_logging.cc b/tensorflow/tsl/platform/default/error_logging.cc new file mode 100644 index 00000000000000..59efa3dc148124 --- /dev/null +++ b/tensorflow/tsl/platform/default/error_logging.cc @@ -0,0 +1,29 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/tsl/platform/default/error_logging.h" + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" + +namespace tsl::error_logging { + +absl::Status Log(absl::string_view component, absl::string_view subcomponent, + absl::string_view error_msg) { + // no-op, intentionally empty function + return absl::OkStatus(); +} + +} // namespace tsl::error_logging diff --git a/tensorflow/tsl/platform/default/error_logging.h b/tensorflow/tsl/platform/default/error_logging.h new file mode 100644 index 00000000000000..26360b7e5b72e4 --- /dev/null +++ b/tensorflow/tsl/platform/default/error_logging.h @@ -0,0 +1,29 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_TSL_PLATFORM_DEFAULT_ERROR_LOGGING_H_ +#define TENSORFLOW_TSL_PLATFORM_DEFAULT_ERROR_LOGGING_H_ + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" + +namespace tsl::error_logging { + +absl::Status Log(absl::string_view component, absl::string_view subcomponent, + absl::string_view error_msg); + +} + +#endif // TENSORFLOW_TSL_PLATFORM_DEFAULT_ERROR_LOGGING_H_ diff --git a/tensorflow/tsl/platform/error_logging.h b/tensorflow/tsl/platform/error_logging.h new file mode 100644 index 00000000000000..d27d0115f37391 --- /dev/null +++ b/tensorflow/tsl/platform/error_logging.h @@ -0,0 +1,27 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_TSL_PLATFORM_ERROR_LOGGING_H_ +#define TENSORFLOW_TSL_PLATFORM_ERROR_LOGGING_H_ + +#include "tensorflow/tsl/platform/platform.h" + +#if defined(PLATFORM_GOOGLE) +#include "tensorflow/tsl/platform/google/error_logging.h" // IWYU pragma: export +#else +#include "tensorflow/tsl/platform/default/error_logging.h" // IWYU pragma: export +#endif + +#endif // TENSORFLOW_TSL_PLATFORM_ERROR_LOGGING_H_ From 74a6afafff69dc33f247a18ced5d1c4f0c891606 Mon Sep 17 00:00:00 2001 From: Yu Feng Date: Wed, 12 Jul 2023 14:21:04 -0700 Subject: [PATCH 218/376] Increase type support for DoInplace and DoCopy to uint8, int8 and uint64 With DTensor, it is more likely that one emit TF graphs with these new dtypes. PiperOrigin-RevId: 547600278 --- tensorflow/core/kernels/inplace_ops.cc | 3 +++ .../kernels/inplace_ops_functor_gpu.cu.cc | 14 ++++++++++---- .../array_ops/inplace_ops_test.py | 19 +++++++++++++++---- 3 files changed, 28 insertions(+), 8 deletions(-) diff --git a/tensorflow/core/kernels/inplace_ops.cc b/tensorflow/core/kernels/inplace_ops.cc index 5d6561a201a25a..1ce48822ea2c20 100644 --- a/tensorflow/core/kernels/inplace_ops.cc +++ b/tensorflow/core/kernels/inplace_ops.cc @@ -457,7 +457,10 @@ REGISTER_KERNEL_BUILDER( Name("InplaceUpdate").Device(DEVICE_GPU).TypeConstraint("T"), InplaceOp); TF_CALL_GPU_NUMBER_TYPES(REGISTER); +REGISTER(int8_t); +REGISTER(uint8_t); REGISTER(int64_t); +REGISTER(uint64_t); REGISTER_EMPTY(int32, GPU); #undef REGISTER diff --git a/tensorflow/core/kernels/inplace_ops_functor_gpu.cu.cc b/tensorflow/core/kernels/inplace_ops_functor_gpu.cu.cc index f2536da06bec24..06a9a0b56aefd0 100644 --- a/tensorflow/core/kernels/inplace_ops_functor_gpu.cu.cc +++ b/tensorflow/core/kernels/inplace_ops_functor_gpu.cu.cc @@ -177,10 +177,13 @@ Status DoInplace(const Device& d, InplaceOpType op, const Tensor& i, CASE(double) CASE(Eigen::half) CASE(Eigen::bfloat16) - CASE(int64) + CASE(uint8_t) + CASE(int8_t) + CASE(int64_t) + CASE(uint64_t) #undef CASE default: - return errors::InvalidArgument("Unsupported data type: ", + return errors::InvalidArgument("Unsupported data type from DoInplace: ", DataTypeString(v.dtype())); } return OkStatus(); @@ -202,10 +205,13 @@ Status DoCopy(const Device& d, const Tensor& x, Tensor* y) { CASE(Eigen::bfloat16) CASE(complex64) CASE(complex128) - CASE(int64) + CASE(uint8_t) + CASE(int8_t) + CASE(int64_t) + CASE(uint64_t) #undef CASE default: - return errors::InvalidArgument("Unsupported dtype: ", + return errors::InvalidArgument("Unsupported dtype from DoCopy: ", DataTypeString(x.dtype())); } return OkStatus(); diff --git a/tensorflow/python/kernel_tests/array_ops/inplace_ops_test.py b/tensorflow/python/kernel_tests/array_ops/inplace_ops_test.py index dae8b333f8dfcf..1a7ec203c9467e 100644 --- a/tensorflow/python/kernel_tests/array_ops/inplace_ops_test.py +++ b/tensorflow/python/kernel_tests/array_ops/inplace_ops_test.py @@ -24,10 +24,21 @@ from tensorflow.python.platform import test as test_lib +BASIC_TYPES = [ + dtypes.float32, + dtypes.int8, + dtypes.uint8, + dtypes.int32, + dtypes.int64, + dtypes.uint64, + dtypes.bfloat16, +] + + class InplaceOpsTest(test_util.TensorFlowTestCase): def testBasicUpdate(self): - for dtype in [dtypes.float32, dtypes.int32, dtypes.int64, dtypes.bfloat16]: + for dtype in BASIC_TYPES: with test_util.use_gpu(): x = array_ops.ones([7, 3], dtype) y = np.ones([7, 3], dtype.as_numpy_dtype) @@ -61,7 +72,7 @@ def testBasicUpdateBool(self): self.assertAllClose(x, y) def testBasicAdd(self): - for dtype in [dtypes.float32, dtypes.int32, dtypes.int64, dtypes.bfloat16]: + for dtype in BASIC_TYPES: with test_util.use_gpu(): x = array_ops.ones([7, 3], dtype) y = np.ones([7, 3], dtype.as_numpy_dtype) @@ -80,7 +91,7 @@ def testBasicAdd(self): self.assertAllClose(x, y) def testBasicSub(self): - for dtype in [dtypes.float32, dtypes.int32, dtypes.int64, dtypes.bfloat16]: + for dtype in BASIC_TYPES: with test_util.use_gpu(): x = array_ops.ones([7, 3], dtype) y = np.ones([7, 3], dtype.as_numpy_dtype) @@ -196,7 +207,7 @@ def testInplaceOpOnEmptyTensors(self): inplace_ops.inplace_sub, inplace_ops.inplace_update, ] - for dtype in [dtypes.float32, dtypes.int32, dtypes.int64]: + for dtype in BASIC_TYPES: for op_fn in op_fns: with test_util.use_gpu(): x = array_ops.zeros([7, 0], dtype) From 8ed1178bc9acdf1b19e3c13607d5a906bbeed983 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Wed, 12 Jul 2023 14:34:09 -0700 Subject: [PATCH 219/376] Integrate StableHLO at openxla/stablehlo@41bad51 Manual changes: * CMakeLists.txt: keep MLIR-HLO-only customizations to the CMake build. * BUILD.bazel, stablehlo/dialect/CMakeLists.txt, stablehlo/dialect/Base.h, stablehlo/dialect/Base.cpp, stablehlo/dialect/ExperimentalOps.h, stablehlo/dialect/ExperimentalOps.cpp, stablehlo/transforms/StablehloCanonicalizeDynamism.cpp, stablehlo/transforms/StablehloRefineShapes.cpp, stablehlo/tests/stablehlo_canonicalize_dynamism.mlir, stablehlo/tests/stablehlo_refine_shapes.mlir: keep XLA-only customizations to StableHLO shape refinement. PiperOrigin-RevId: 547603768 --- third_party/stablehlo/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/stablehlo/workspace.bzl b/third_party/stablehlo/workspace.bzl index 4319edf4033bb6..c7d30c5b560974 100644 --- a/third_party/stablehlo/workspace.bzl +++ b/third_party/stablehlo/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): # LINT.IfChange - STABLEHLO_COMMIT = "4add5f0e890bc66b333e86961978f066325f8a86" - STABLEHLO_SHA256 = "4d3014703aa8d18477790b2f3040163276b50f647aa2da32396f390ea8bf6f7c" + STABLEHLO_COMMIT = "41bad512515d609ccd3896d74bf697e7d456e1d3" + STABLEHLO_SHA256 = "01d143b57efda2fcf5e3482cbd0c4beae2a51164082e0797f0093cdbd8c82b06" # LINT.ThenChange(Google-internal path) tf_http_archive( From dae72e3210b6b03ac41d548f955aad0accbec6e0 Mon Sep 17 00:00:00 2001 From: Yu Feng Date: Wed, 12 Jul 2023 14:36:02 -0700 Subject: [PATCH 220/376] Code reformatting for readibility. PiperOrigin-RevId: 547604248 --- .../dtensor/mlir/dtensor_collective_type_lowering.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorflow/dtensor/mlir/dtensor_collective_type_lowering.cc b/tensorflow/dtensor/mlir/dtensor_collective_type_lowering.cc index ad19eb415f3478..b3aa8a70d9fe4d 100644 --- a/tensorflow/dtensor/mlir/dtensor_collective_type_lowering.cc +++ b/tensorflow/dtensor/mlir/dtensor_collective_type_lowering.cc @@ -121,21 +121,21 @@ mlir::LogicalResult ConvertShortIntReduce(ReduceOpType reduce_op) { << "Received '" << reduce_op.getReduceOpAttr().getValue().str() << "'"; } - if (mlir::isa(tensor_input_type.getElementType())) { + if (auto integer_type = mlir::dyn_cast( + tensor_input_type.getElementType())) { int32_t min_width = 64; if (output_layout->mesh().is_tpu_mesh()) { min_width = 32; } - if (tensor_input_type.getElementType().getIntOrFloatBitWidth() >= - min_width) { + if (integer_type.getWidth() >= min_width) { return mlir::success(); } auto input_type = mlir::RankedTensorType::get( tensor_input_type.getShape(), builder.getIntegerType(min_width)); auto output_type = mlir::RankedTensorType::get( - tensor_output_type.getShape(), tensor_input_type.getElementType()); + tensor_output_type.getShape(), integer_type); return WrapOpWithCasts(input_type, output_type, reduce_op); } if (mlir::isa(tensor_input_type.getElementType())) { From 268090e78ac451847a2cd2586d8bfcbe0acd5bf1 Mon Sep 17 00:00:00 2001 From: Fiona Lang Date: Wed, 12 Jul 2023 14:50:44 -0700 Subject: [PATCH 221/376] Move tf debug imports from python/__init__.py to python/modules_with_exports.py. PiperOrigin-RevId: 547607939 --- tensorflow/python/BUILD | 3 +++ tensorflow/python/__init__.py | 5 ----- tensorflow/python/modules_with_exports.py | 5 +++++ tensorflow/python/ops/BUILD | 2 ++ tensorflow/python/ops/gradients_impl.py | 2 ++ 5 files changed, 12 insertions(+), 5 deletions(-) diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 2940055b78d377..18716ed6ea152f 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -352,6 +352,8 @@ py_library( "//tensorflow/python/compiler/xla", "//tensorflow/python/compiler/xla:compiler_py", "//tensorflow/python/data", + "//tensorflow/python/debug/lib:check_numerics_callback", + "//tensorflow/python/debug/lib:dumping_callback", "//tensorflow/python/distribute", "//tensorflow/python/distribute:merge_call_interim", "//tensorflow/python/distribute:multi_process_runner", @@ -381,6 +383,7 @@ py_library( "//tensorflow/python/ops:composite_tensor_ops", "//tensorflow/python/ops:cond_v2", "//tensorflow/python/ops:cudnn_rnn_ops_gen", + "//tensorflow/python/ops:debug_ops_gen", "//tensorflow/python/ops:gradient_checker_v2", "//tensorflow/python/ops:image_ops", "//tensorflow/python/ops:initializers_ns", diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py index 4e1807e79f9905..bfb0114b305461 100644 --- a/tensorflow/python/__init__.py +++ b/tensorflow/python/__init__.py @@ -72,11 +72,6 @@ from tensorflow.python.util.all_util import make_all from tensorflow.python.util.tf_export import tf_export -# TensorFlow Debugger (tfdbg). -from tensorflow.python.debug.lib import check_numerics_callback -from tensorflow.python.debug.lib import dumping_callback -from tensorflow.python.ops import gen_debug_ops - # Update dispatch decorator docstrings to contain lists of registered APIs. # (This should come after any imports that register APIs.) from tensorflow.python.util import dispatch diff --git a/tensorflow/python/modules_with_exports.py b/tensorflow/python/modules_with_exports.py index 2edd9ae88d5ba0..c676d9f5bad529 100644 --- a/tensorflow/python/modules_with_exports.py +++ b/tensorflow/python/modules_with_exports.py @@ -39,6 +39,11 @@ # Data from tensorflow.python import data +# TensorFlow Debugger (tfdbg). +from tensorflow.python.debug.lib import check_numerics_callback +from tensorflow.python.debug.lib import dumping_callback +from tensorflow.python.ops import gen_debug_ops + # Distribute from tensorflow.python import distribute diff --git a/tensorflow/python/ops/BUILD b/tensorflow/python/ops/BUILD index f6d7afeb89343b..a81d14afbbe481 100644 --- a/tensorflow/python/ops/BUILD +++ b/tensorflow/python/ops/BUILD @@ -1628,6 +1628,8 @@ py_strict_library( ":tensor_array_ops", ":unconnected_gradients", ":while_loop", + "//tensorflow/python/debug/lib:debug_gradients", + "//tensorflow/python/debug/lib:dumping_callback", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", "//tensorflow/python/ops/linalg/sparse:sparse_csr_matrix_grad", diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py index d1ccbbf581fc60..673fcf46fc393c 100644 --- a/tensorflow/python/ops/gradients_impl.py +++ b/tensorflow/python/ops/gradients_impl.py @@ -14,6 +14,8 @@ # ============================================================================== """Implements the graph generation for computation of gradients.""" +from tensorflow.python.debug.lib import debug_gradients # pylint: disable=unused-import +from tensorflow.python.debug.lib import dumping_callback # pylint: disable=unused-import from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_grad # pylint: disable=unused-import From cb73612a496950fa5ace9ee9967757edb2e5a493 Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Wed, 12 Jul 2023 15:07:19 -0700 Subject: [PATCH 222/376] [XLA:GPU] Disable leaky-relu fusion. This isn't working, even on Ampere. Some kLeakyRelu convolutions have 0 available algorithms, so autotuning just fails. PiperOrigin-RevId: 547612301 --- .../xla/service/gpu/cudnn_fused_conv_rewriter.cc | 9 +++++++++ .../xla/service/gpu/cudnn_fused_conv_rewriter_test.cc | 8 ++++++-- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc index 3bb00f58de0b07..220b53c46011a0 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc @@ -674,6 +674,15 @@ StatusOr FuseRelu6(HloComputation* comp, se::CudaComputeCapability cc) { StatusOr FuseLeakyRelu(HloComputation* comp, se::CudaComputeCapability cc) { + // TODO(jlebar): Disabled due to bugs in cudnn 8.9.0. In particular, the + // following convolution gets 0 algorithms available, so it fails to run. + // + // (f16[2,256,768,16]{3,2,1,0}, u8[0]{0}) + // custom-call(f16[2,256,768,3]{3,2,1,0} %a, f16[16,3,3,3]{3,2,1,0} %b, + // f16[16]{0} %c), window={size=3x3 pad=1_1x1_1}, + // dim_labels=b01f_o01i->b01f, operand_precision={highest,highest} + return false; + if (!ShouldUseCudnnRuntimeFusion(comp->parent()->config().debug_options(), cc)) { return false; diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc index 42326417821860..0701cc5e19e0a1 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc @@ -285,7 +285,9 @@ TEST_F(CudnnFusedConvRewriterTest, TestRelu6) { })"); } -TEST_F(CudnnFusedConvRewriterTest, TestLeakyRelu) { +// TODO(jlebar): leaky-relu fusion is disabled because some convolutions have 0 +// algorithm choices. See the cc file. +TEST_F(CudnnFusedConvRewriterTest, DISABLED_TestLeakyRelu) { if (!GetCudaComputeCapability().IsAtLeast( se::CudaComputeCapability::AMPERE)) { GTEST_SKIP() @@ -1148,7 +1150,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseRelu6IfMultipleUses) { EXPECT_EQ(config.activation_mode(), se::dnn::kNone); } -TEST_F(CudnnFusedConvRewriterHloTest, FuseLeakyRelu) { +// TODO(jlebar): leaky-relu fusion is disabled because some convolutions have 0 +// algorithm choices. See the cc file. +TEST_F(CudnnFusedConvRewriterHloTest, DISABLED_FuseLeakyRelu) { const std::string module_str = R"( HloModule Test ENTRY Test { From c4d31b0f488f02fa7fe1bbccf3f75b34d651cf3a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 12 Jul 2023 15:19:52 -0700 Subject: [PATCH 223/376] Adds check for Optional Tensors before increasing the reference count for graph outputs PiperOrigin-RevId: 547615457 --- tensorflow/lite/simple_planner.cc | 4 +++- tensorflow/lite/simple_planner_test.cc | 19 +++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/tensorflow/lite/simple_planner.cc b/tensorflow/lite/simple_planner.cc index 3cf26384966dfd..9e24ad0660c7b8 100644 --- a/tensorflow/lite/simple_planner.cc +++ b/tensorflow/lite/simple_planner.cc @@ -101,7 +101,9 @@ TfLiteStatus SimplePlanner::PlanAllocations() { // artificially adding one to their ref-counts so they are never selected // for deallocation. for (int tensor_index : graph_info_->outputs()) { - refcounts[tensor_index]++; + if (tensor_index != kTfLiteOptionalTensor) { + refcounts[tensor_index]++; + } } // Variable tensors also should be ensured to be never overwritten and need to diff --git a/tensorflow/lite/simple_planner_test.cc b/tensorflow/lite/simple_planner_test.cc index 4e3f7e06186629..0b49600f569d39 100644 --- a/tensorflow/lite/simple_planner_test.cc +++ b/tensorflow/lite/simple_planner_test.cc @@ -365,5 +365,24 @@ TEST_F(SimplePlannerTest, SimpleGraphWithPersistentResetAllocationsAfter) { EXPECT_TRUE(tensor5_ptr == (*graph.tensors())[5].data.raw); } +TEST_F(SimplePlannerTest, SimpleGraphOptionalOutput) { + TestGraph graph({0, 1}, + { + /* in, out, tmp */ + {{0, 1}, {2}, {}}, // First op + {{2, 0}, {4, 5}, {}}, // Second op + {{4, 5}, {3}, {}} // Third op + }, + {-1, 3}); + SetGraph(&graph); + Execute(0, 10); + + EXPECT_TRUE(IsAllocated(1)); + EXPECT_TRUE(IsAllocated(2)); + EXPECT_TRUE(IsAllocated(3)); + EXPECT_TRUE(IsAllocated(4)); + EXPECT_TRUE(IsAllocated(5)); +} + } // namespace } // namespace tflite From bd41e624acf748f1e4c99a36dba82d6a05290e27 Mon Sep 17 00:00:00 2001 From: Anlun Xu Date: Wed, 12 Jul 2023 15:24:51 -0700 Subject: [PATCH 224/376] [xla:gpu] Support exporting dataflow graph in DOT format for debugging PiperOrigin-RevId: 547616713 --- .../xla/mlir/backends/gpu/transforms/BUILD | 2 ++ .../gpu/transforms/add_concurrent_regions.cc | 9 ++++++ .../gpu/transforms/dataflow_analysis.cc | 28 +++++++++++++++++++ .../gpu/transforms/dataflow_analysis.h | 3 ++ 4 files changed, 42 insertions(+) diff --git a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/BUILD b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/BUILD index 4b43a94b860c4f..622293491f938e 100644 --- a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/BUILD +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/BUILD @@ -33,6 +33,7 @@ cc_library( deps = [ "//tensorflow/compiler/xla/mlir_hlo:lhlo", "//tensorflow/compiler/xla/mlir_hlo:lhlo_gpu", + "@com_google_absl//absl/strings", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:GPUDialect", "@llvm-project//mlir:IR", @@ -73,6 +74,7 @@ cc_library( "//tensorflow/compiler/xla/service/gpu:nccl_collective_thunks", "//tensorflow/compiler/xla/stream_executor:blas", "//tensorflow/compiler/xla/translate/mhlo_to_hlo:location_exporter", + "//tensorflow/tsl/platform:env", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", diff --git a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/add_concurrent_regions.cc b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/add_concurrent_regions.cc index c524043ace6cc9..60a2f1ff1fba52 100644 --- a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/add_concurrent_regions.cc +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/add_concurrent_regions.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include +#include #include #include #include @@ -31,6 +32,7 @@ limitations under the License. #include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/xla/mlir/backends/gpu/transforms/dataflow_analysis.h" #include "tensorflow/compiler/xla/mlir/runtime/utils/custom_calls.h" +#include "tensorflow/tsl/platform/env.h" namespace xla { namespace gpu { @@ -93,6 +95,13 @@ llvm::SmallVector GetRegionInfos( DataflowAnalysis::DataflowGraph dataflow_graph = dataflow_analysis.GetDataflowGraph(capture_func); + // If verbose logging is enabled print the dataflow graph as a DOT graph. + if (VLOG_IS_ON(100)) { + std::cout << "Dependency graph for graph capture function " + << capture_func.getName().str() << ":\n" + << dataflow_analysis.ToDot(dataflow_graph); + } + llvm::SmallVector region; auto store_region_and_start_new_region = [&]() { diff --git a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/dataflow_analysis.cc b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/dataflow_analysis.cc index acca9c7b11e5e0..02873b88258588 100644 --- a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/dataflow_analysis.cc +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/dataflow_analysis.cc @@ -16,8 +16,10 @@ limitations under the License. #include "tensorflow/compiler/xla/mlir/backends/gpu/transforms/dataflow_analysis.h" #include +#include #include +#include "absl/strings/str_cat.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/GPU/IR/GPUDialect.h" // from @llvm-project #include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project @@ -195,5 +197,31 @@ DataflowAnalysis::DataflowGraph DataflowAnalysis::GetDataflowGraph( return graph; } +std::string DataflowAnalysis::ToDot(const DataflowGraph& graph) { + std::string pad; + std::string res; + auto indent = [&] { pad.append(2, ' '); }; + auto outdent = [&] { pad.resize(pad.size() - 2); }; + auto addline = [&](auto&&... args) { + absl::StrAppend(&res, pad, args..., "\n"); + }; + auto get_name = [](const Node& node) -> std::string { + return absl::StrCat("\"", node.operation->getName().getStringRef().str(), + "_", node.index, "\""); + }; + + addline("digraph {"); + indent(); + for (const Node& node : graph) { + for (size_t child_index : node.children) { + Node child = graph[child_index]; + addline(get_name(node), " -> ", get_name(child)); + } + } + outdent(); + addline("}"); + return res; +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/dataflow_analysis.h b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/dataflow_analysis.h index 72deb6eed23c81..f301e4297dc62b 100644 --- a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/dataflow_analysis.h +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/dataflow_analysis.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_MLIR_BACKENDS_GPU_TRANSFORMS_DATAFLOW_ANALYSIS_H_ #define TENSORFLOW_COMPILER_XLA_MLIR_BACKENDS_GPU_TRANSFORMS_DATAFLOW_ANALYSIS_H_ +#include #include #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project @@ -45,6 +46,8 @@ class DataflowAnalysis { // have write-conflicts. // (3) We have information about read-only and read-write buffer arguments. DataflowGraph GetDataflowGraph(mlir::func::FuncOp graph_capture_function); + + std::string ToDot(const DataflowGraph& graph); }; } // namespace gpu From 1ef254071118150c3de7cf7746ba8f64262e26f6 Mon Sep 17 00:00:00 2001 From: Faizan Muhammad Date: Wed, 12 Jul 2023 15:26:43 -0700 Subject: [PATCH 225/376] Add more context to the incorrect tensor num error PiperOrigin-RevId: 547617201 --- .../python/eager/polymorphic_function/atomic_function.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tensorflow/python/eager/polymorphic_function/atomic_function.py b/tensorflow/python/eager/polymorphic_function/atomic_function.py index 3c4425f215012f..ae9ec964659a9f 100644 --- a/tensorflow/python/eager/polymorphic_function/atomic_function.py +++ b/tensorflow/python/eager/polymorphic_function/atomic_function.py @@ -234,6 +234,9 @@ def __call__(self, *args: core.Tensor) -> Sequence[core.Tensor]: if len(args) != expected_len: raise ValueError( f"Signature specifies {expected_len} arguments, got: {len(args)}." + f" Expected inputs: {self.cached_definition.signature.input_arg}." + f" Received inputs: {args}." + f" Function Type: {self.function_type!r}" ) with InterpolateRuntimeError(self): From d6aba8d5db1737934ea8670cc753fbef4b253b27 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 12 Jul 2023 15:50:05 -0700 Subject: [PATCH 226/376] Allow RegisterClientFactory() and GetClient() in xla::ifrt::test_util to work with custom (type-erased) deleters. PiperOrigin-RevId: 547622900 --- .../python/ifrt/ir/tests/executable_impl_test_base.h | 2 +- tensorflow/compiler/xla/python/ifrt/test_util.cc | 10 +++++----- tensorflow/compiler/xla/python/ifrt/test_util.h | 4 ++-- .../xla/python/pjrt_ifrt/tfrt_cpu_client_test_lib.cc | 4 ++-- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/tensorflow/compiler/xla/python/ifrt/ir/tests/executable_impl_test_base.h b/tensorflow/compiler/xla/python/ifrt/ir/tests/executable_impl_test_base.h index e903b57713a5f7..788a8a7d8c5ab1 100644 --- a/tensorflow/compiler/xla/python/ifrt/ir/tests/executable_impl_test_base.h +++ b/tensorflow/compiler/xla/python/ifrt/ir/tests/executable_impl_test_base.h @@ -58,7 +58,7 @@ class IfrtIrExecutableImplTestBase : public testing::Test { absl::StatusOr PickDevices(int count); mlir::MLIRContext mlir_context_; - std::unique_ptr client_; + std::shared_ptr client_; }; } // namespace test_util diff --git a/tensorflow/compiler/xla/python/ifrt/test_util.cc b/tensorflow/compiler/xla/python/ifrt/test_util.cc index 47ecc393db8269..ac0e3b53bf9bdc 100644 --- a/tensorflow/compiler/xla/python/ifrt/test_util.cc +++ b/tensorflow/compiler/xla/python/ifrt/test_util.cc @@ -32,20 +32,20 @@ namespace { class ClientFactory { public: - void Register(std::function>()> factory) { + void Register(std::function>()> factory) { absl::MutexLock lock(&mu_); CHECK(!factory_) << "Client factory has been already registered."; factory_ = std::move(factory); } - std::function>()> Get() const { + std::function>()> Get() const { absl::MutexLock lock(&mu_); return factory_; } private: mutable absl::Mutex mu_; - std::function>()> factory_ + std::function>()> factory_ ABSL_GUARDED_BY(mu_); }; @@ -57,11 +57,11 @@ ClientFactory& GetGlobalClientFactory() { } // namespace void RegisterClientFactory( - std::function>()> factory) { + std::function>()> factory) { GetGlobalClientFactory().Register(std::move(factory)); } -StatusOr> GetClient() { +StatusOr> GetClient() { auto factory = GetGlobalClientFactory().Get(); CHECK(factory) << "Client factory has not been registered."; return factory(); diff --git a/tensorflow/compiler/xla/python/ifrt/test_util.h b/tensorflow/compiler/xla/python/ifrt/test_util.h index 18bb418431f757..d25a00cef294d3 100644 --- a/tensorflow/compiler/xla/python/ifrt/test_util.h +++ b/tensorflow/compiler/xla/python/ifrt/test_util.h @@ -39,13 +39,13 @@ namespace test_util { // Registers an IFRT client factory function. Must be called only once. void RegisterClientFactory( - std::function>()> factory); + std::function>()> factory); // Returns true iff an IFRT client factory function has been registered. bool IsClientFactoryRegistered(); // Gets a new IFRT client using the registered client factory. -StatusOr> GetClient(); +StatusOr> GetClient(); // Set a default test filter if user doesn't provide one using --gtest_filter. void SetTestFilterIfNotUserSpecified(absl::string_view custom_filter); diff --git a/tensorflow/compiler/xla/python/pjrt_ifrt/tfrt_cpu_client_test_lib.cc b/tensorflow/compiler/xla/python/pjrt_ifrt/tfrt_cpu_client_test_lib.cc index 181f3be1ff8224..8e399194e9e3ed 100644 --- a/tensorflow/compiler/xla/python/pjrt_ifrt/tfrt_cpu_client_test_lib.cc +++ b/tensorflow/compiler/xla/python/pjrt_ifrt/tfrt_cpu_client_test_lib.cc @@ -26,11 +26,11 @@ namespace { const bool kUnused = (test_util::RegisterClientFactory( - []() -> StatusOr> { + []() -> StatusOr> { TF_ASSIGN_OR_RETURN(auto pjrt_client, xla::GetTfrtCpuClient(/*asynchronous=*/true, /*cpu_device_count=*/2)); - return StatusOr>( + return std::shared_ptr( PjRtClient::Create(std::move(pjrt_client))); }), true); From 27a93890000041682337dcce77729cd440548a3f Mon Sep 17 00:00:00 2001 From: Jieying Luo Date: Wed, 12 Jul 2023 15:59:12 -0700 Subject: [PATCH 227/376] [PJRT C API] Add a C API to query plugin attributes. One example attribute is the minimum supported StableHLO version. How to standardize these attributes will be discussed as a follow up. It will likely be a set of attributes that the plugin should return, and the plugin can return plugin-specific attributes. PiperOrigin-RevId: 547625105 --- tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h | 76 ++++++++++++------- .../xla/pjrt/c/pjrt_c_api_wrapper_impl.cc | 7 ++ .../xla/pjrt/c/pjrt_c_api_wrapper_impl.h | 4 + 3 files changed, 58 insertions(+), 29 deletions(-) diff --git a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h index 9f7926e0d403f5..f9b4c3e67e11fe 100644 --- a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h +++ b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h @@ -53,7 +53,7 @@ extern "C" { // Changes include: // * Adding a new field to the PJRT_Api or argument structs // * Renaming a method or argument (doesn't affect ABI) -#define PJRT_API_MINOR 8 +#define PJRT_API_MINOR 9 // The plugin should set the major_version and minor_version of // PJRT_Api.pjrt_api_version to be the `PJRT_API_MAJOR` and `PJRT_API_MINOR` in @@ -138,6 +138,50 @@ typedef PJRT_Error* (*PJRT_CallbackError)(PJRT_Error_Code code, const char* message, size_t message_size); +// ---------------------------- Named Values ----------------------------------- + +typedef enum { + PJRT_NamedValue_kString = 0, + PJRT_NamedValue_kInt64, + PJRT_NamedValue_kInt64List, + PJRT_NamedValue_kFloat, +} PJRT_NamedValue_Type; + +// Named value for key-value pairs. +struct PJRT_NamedValue { + size_t struct_size; + void* priv; + const char* name; + size_t name_size; + PJRT_NamedValue_Type type; + union { + const char* string_value; + int64_t int64_value; + const int64_t* int64_array_value; + float float_value; + }; + // `value_size` is the number of elements for array/string and 1 for scalar + // values. + size_t value_size; +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_NamedValue, value_size); + +// ---------------------------------- Plugin ----------------------------------- + +struct PJRT_Plugin_Attributes_Args { + size_t struct_size; + void* priv; + // Returned attributes have the lifetime of the process. + PJRT_NamedValue* attributes; // out + size_t num_attributes; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Plugin_Attributes_Args, attributes); + +// Returns an array of plugin attributes which are key-value pairs. One example +// attribute is the minimum supported StableHLO version. +// TODO(b/280349977): standardize the list of attributes. +typedef PJRT_Error* PJRT_Plugin_Attributes(PJRT_Plugin_Attributes_Args* args); + // ---------------------------------- Events ----------------------------------- // Represents a notifying event that is returned by PJRT APIs that enqueue @@ -221,34 +265,6 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Event_OnReady_Args, user_arg); // error status and a pointer to an object of the caller's choice as arguments. typedef PJRT_Error* PJRT_Event_OnReady(PJRT_Event_OnReady_Args* args); -// ------------------------ Other Common Data Types ---------------------------- - -typedef enum { - PJRT_NamedValue_kString = 0, - PJRT_NamedValue_kInt64, - PJRT_NamedValue_kInt64List, - PJRT_NamedValue_kFloat, -} PJRT_NamedValue_Type; - -// Named value for key-value pairs. -struct PJRT_NamedValue { - size_t struct_size; - void* priv; - const char* name; - size_t name_size; - PJRT_NamedValue_Type type; - union { - const char* string_value; - int64_t int64_value; - const int64_t* int64_array_value; - float float_value; - }; - // `value_size` is the number of elements for array/string and 1 for scalar - // values. - size_t value_size; -}; -PJRT_DEFINE_STRUCT_TRAITS(PJRT_NamedValue, value_size); - // ---------------------------------- Client ----------------------------------- typedef struct PJRT_Client PJRT_Client; @@ -1664,6 +1680,8 @@ typedef struct { _PJRT_API_STRUCT_FIELD(PJRT_Error_Message); _PJRT_API_STRUCT_FIELD(PJRT_Error_GetCode); + _PJRT_API_STRUCT_FIELD(PJRT_Plugin_Attributes); + _PJRT_API_STRUCT_FIELD(PJRT_Event_Destroy); _PJRT_API_STRUCT_FIELD(PJRT_Event_IsReady); _PJRT_API_STRUCT_FIELD(PJRT_Event_Error); diff --git a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc index 4bf5c5cc796c9d..976c9b19bd4257 100644 --- a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc +++ b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc @@ -210,6 +210,13 @@ PJRT_Error* PJRT_Error_GetCode(PJRT_Error_GetCode_Args* args) { return nullptr; } +// ---------------------------------- Plugin ----------------------------------- + +PJRT_Error* PJRT_Plugin_Attributes(PJRT_Plugin_Attributes_Args* args) { + args->num_attributes = 0; + return nullptr; +} + // ---------------------------------- Client ----------------------------------- PJRT_Error* PJRT_Client_Destroy(PJRT_Client_Destroy_Args* args) { diff --git a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h index 63f1e46809eda1..d1b89621b90fca 100644 --- a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h +++ b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h @@ -141,6 +141,8 @@ void PJRT_Error_Destroy(PJRT_Error_Destroy_Args* args); void PJRT_Error_Message(PJRT_Error_Message_Args* args); PJRT_Error* PJRT_Error_GetCode(PJRT_Error_GetCode_Args* args); +PJRT_Error* PJRT_Plugin_Attributes(PJRT_Plugin_Attributes_Args* args); + PJRT_Error* PJRT_Event_Destroy(PJRT_Event_Destroy_Args* args); PJRT_Error* PJRT_Event_IsReady(PJRT_Event_IsReady_Args* args); PJRT_Error* PJRT_Event_Error(PJRT_Event_Error_Args* args); @@ -320,6 +322,8 @@ constexpr PJRT_Api CreatePjrtApi( /*PJRT_Error_Message=*/pjrt::PJRT_Error_Message, /*PJRT_Error_GetCode=*/pjrt::PJRT_Error_GetCode, + /*PJRT_Plugin_Attributes=*/pjrt::PJRT_Plugin_Attributes, + /*PJRT_Event_Destroy=*/pjrt::PJRT_Event_Destroy, /*PJRT_Event_IsReady=*/pjrt::PJRT_Event_IsReady, /*PJRT_Event_Error=*/pjrt::PJRT_Event_Error, From c0b6aa832fa0cf49878ed37a703fd1b0264c79b3 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 12 Jul 2023 16:02:13 -0700 Subject: [PATCH 228/376] Update Tensor dunder methods and add numpy methods to WeakTensor. PiperOrigin-RevId: 547625852 --- .../python/ops/numpy_ops/np_math_ops.py | 47 ++++++++-------- tensorflow/python/ops/weak_tensor_ops.py | 14 +++-- tensorflow/python/ops/weak_tensor_ops_test.py | 54 +++++++++++++++++++ 3 files changed, 90 insertions(+), 25 deletions(-) diff --git a/tensorflow/python/ops/numpy_ops/np_math_ops.py b/tensorflow/python/ops/numpy_ops/np_math_ops.py index f93cc1c708185b..7aad67c44c0304 100644 --- a/tensorflow/python/ops/numpy_ops/np_math_ops.py +++ b/tensorflow/python/ops/numpy_ops/np_math_ops.py @@ -1412,39 +1412,44 @@ def _tensor_size(self): def _tensor_tolist(self): - if isinstance(self, ops.EagerTensor): - return self._numpy().tolist() # pylint: disable=protected-access + if ops.is_symbolic_tensor(self): + raise ValueError('Symbolic Tensors do not support the tolist API.') - raise ValueError('Symbolic Tensors do not support the tolist API.') + return self._numpy().tolist() # pylint: disable=protected-access -def enable_numpy_methods_on_tensor(): - """Adds additional NumPy methods on tf.Tensor class.""" +def _enable_numpy_methods(tensor_class): + """A helper method for adding additional NumPy methods.""" t = property(_tensor_t) - setattr(tensor.Tensor, 'T', t) + setattr(tensor_class, 'T', t) ndim = property(_tensor_ndim) - setattr(tensor.Tensor, 'ndim', ndim) + setattr(tensor_class, 'ndim', ndim) size = property(_tensor_size) - setattr(tensor.Tensor, 'size', size) + setattr(tensor_class, 'size', size) - setattr(tensor.Tensor, '__pos__', _tensor_pos) - setattr(tensor.Tensor, 'tolist', _tensor_tolist) + setattr(tensor_class, '__pos__', _tensor_pos) + setattr(tensor_class, 'tolist', _tensor_tolist) # TODO(b/178540516): Make a custom `setattr` that changes the method's # docstring to the TF one. - setattr(tensor.Tensor, 'transpose', np_array_ops.transpose) - setattr(tensor.Tensor, 'flatten', np_array_ops.flatten) - setattr(tensor.Tensor, 'reshape', np_array_ops._reshape_method_wrapper) # pylint: disable=protected-access - setattr(tensor.Tensor, 'ravel', np_array_ops.ravel) - setattr(tensor.Tensor, 'clip', clip) - setattr(tensor.Tensor, 'astype', math_ops.cast) - setattr(tensor.Tensor, '__round__', np_array_ops.around) - setattr(tensor.Tensor, 'max', np_array_ops.amax) - setattr(tensor.Tensor, 'mean', np_array_ops.mean) - setattr(tensor.Tensor, 'min', np_array_ops.amin) + setattr(tensor_class, 'transpose', np_array_ops.transpose) + setattr(tensor_class, 'flatten', np_array_ops.flatten) + setattr(tensor_class, 'reshape', np_array_ops._reshape_method_wrapper) # pylint: disable=protected-access + setattr(tensor_class, 'ravel', np_array_ops.ravel) + setattr(tensor_class, 'clip', clip) + setattr(tensor_class, 'astype', math_ops.cast) + setattr(tensor_class, '__round__', np_array_ops.around) + setattr(tensor_class, 'max', np_array_ops.amax) + setattr(tensor_class, 'mean', np_array_ops.mean) + setattr(tensor_class, 'min', np_array_ops.amin) # TODO(wangpeng): Remove `data` when all uses of it are removed data = property(lambda self: self) - setattr(tensor.Tensor, 'data', data) + setattr(tensor_class, 'data', data) + + +def enable_numpy_methods_on_tensor(): + """Adds additional NumPy methods on tf.Tensor class.""" + _enable_numpy_methods(tensor.Tensor) diff --git a/tensorflow/python/ops/weak_tensor_ops.py b/tensorflow/python/ops/weak_tensor_ops.py index 083dcc8ad580d0..c82fe80396a115 100644 --- a/tensorflow/python/ops/weak_tensor_ops.py +++ b/tensorflow/python/ops/weak_tensor_ops.py @@ -88,8 +88,9 @@ def wrapper(*args, **kwargs): # unsupported input type (e.g. CompositeTensor). except NotImplementedError: logging.warning( - "The new dtype semantics do not support this input dtype. Falling" - " back to old semantics." + "The new dtype semantics do not support" + f" {op.__module__}.{op.__name__}({type(x)}). Falling back to old" + " semantics." ) return op(**bound_kwargs) bound_kwargs[x_arg_name] = _convert_or_cast(x, target_type, "x") @@ -136,8 +137,9 @@ def wrapper(*args, **kwargs): # unsupported input type (e.g. CompositeTensor). except NotImplementedError: logging.warning( - "The new dtype semantics do not support this input dtype. Falling" - " back to old semantics." + "The new dtype semantics do not support" + f" {op.__module__}.{op.__name__}({type(x)}, {type(y)}). Falling back" + " to old semantics." ) return op(**bound_kwargs) @@ -475,3 +477,7 @@ def _update_weak_tensor_patched_ops_in_dispatch_dict(patched_op): weak_tensor.WeakTensor.__mod__ = gen_math_ops.floor_mod weak_tensor.WeakTensor.__pow__ = math_ops.pow weak_tensor.WeakTensor.__matmul__ = math_ops.matmul + +# Add/Update NumPy methods in Tensor and WeakTensor. +np_math_ops.enable_numpy_methods_on_tensor() +np_math_ops._enable_numpy_methods(weak_tensor.WeakTensor) # pylint: disable=protected-access diff --git a/tensorflow/python/ops/weak_tensor_ops_test.py b/tensorflow/python/ops/weak_tensor_ops_test.py index cdcdedd31603bd..18708068b5d75f 100644 --- a/tensorflow/python/ops/weak_tensor_ops_test.py +++ b/tensorflow/python/ops/weak_tensor_ops_test.py @@ -307,6 +307,60 @@ def my_abs(x: MyTensor): with self.assertRaises(ValueError): math_ops.abs(MyTensor(constant_op.constant(1.0))) + def testWeakTensorDunderMethods(self): + x = _get_weak_tensor([1, 2, 3]) + + self.assertIsInstance(abs(x), WeakTensor) + self.assertIsInstance(~x, WeakTensor) + self.assertIsInstance(-x, WeakTensor) + + @parameterized.parameters( + ("T", WeakTensor), + ("ndim", int), + ("size", None), + ("data", WeakTensor), + ) + def testNumpyAttributesOnWeakTensor(self, np_attribute, result_type): + a = weak_tensor_test_util.get_weak_tensor(([1, 2, 3])) + b = constant_op.constant([1, 2, 3]) + + self.assertTrue(hasattr(a, np_attribute)) + wt_np_attr = getattr(a, np_attribute) + t_np_attr = getattr(b, np_attribute) + if result_type is None: + # The result type may differ depending on which machine test runs on + # (e.g. size) + self.assertEqual(type(wt_np_attr), type(t_np_attr)) + else: + self.assertIsInstance(wt_np_attr, result_type) + self.assertAllEqual(wt_np_attr, t_np_attr) + + @parameterized.parameters( + ("__pos__", WeakTensor), + ("__round__", WeakTensor, 2), + ("tolist", list), + ("flatten", WeakTensor), + ("transpose", WeakTensor), + ("reshape", WeakTensor, (3, 1)), + ("ravel", WeakTensor), + ("clip", tensor.Tensor, 1.1, 2.2), + ("astype", tensor.Tensor, dtypes.float32), + ("max", WeakTensor), + ("mean", WeakTensor), + ("min", WeakTensor), + ) + def testNumpyMethodsOnWeakTensor(self, np_method, result_type, *args): + a = weak_tensor_test_util.get_weak_tensor(([1, 2, 3])) + b = constant_op.constant([1, 2, 3]) + self.assertTrue(hasattr(a, np_method)) + + wt_np_method_call = getattr(a, np_method) + t_np_method_call = getattr(b, np_method) + wt_np_result = wt_np_method_call(*args) + t_np_result = t_np_method_call(*args) + self.assertIsInstance(wt_np_result, result_type) + self.assertAllEqual(wt_np_result, t_np_result) + # TODO(b/289333658): Add tf.constant(x) with no dtype arg as a "weak" input # after adding WeakTensor construction logic to tf.constant. From 1459cc76dc8ba93eb343f8e124da7d880dda1852 Mon Sep 17 00:00:00 2001 From: Marcello Maggioni Date: Wed, 12 Jul 2023 16:08:10 -0700 Subject: [PATCH 229/376] [XLA] Make sure that we reset the main loop of whle_loop_all_reduce_motion after every movement to avoid call graph being stale. PiperOrigin-RevId: 547627456 --- .../xla/service/while_loop_all_reduce_code_motion.cc | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tensorflow/compiler/xla/service/while_loop_all_reduce_code_motion.cc b/tensorflow/compiler/xla/service/while_loop_all_reduce_code_motion.cc index c2472c66b21535..ca543015af2296 100644 --- a/tensorflow/compiler/xla/service/while_loop_all_reduce_code_motion.cc +++ b/tensorflow/compiler/xla/service/while_loop_all_reduce_code_motion.cc @@ -1050,6 +1050,11 @@ StatusOr WhileLoopAllReduceCodeMotion::Run( TF_RETURN_IF_ERROR(computation->ReplaceInstructionWithDifferentShape( all_reduce, all_reduce->mutable_operand(0))); } + // Needs to rebuild the call graph or we could access removed + // instructions. + if (run_next_pass) { + break; + } } } VLOG(2) << "Hoisted " << count_all_reduce << " all-reduce and " From 928ae0322c05cc58193e795d20dbe729e0831887 Mon Sep 17 00:00:00 2001 From: Terry Heo Date: Wed, 12 Jul 2023 17:07:29 -0700 Subject: [PATCH 230/376] Update build rule for WeightWatcher PiperOrigin-RevId: 547641748 --- tensorflow/BUILD | 9 +++++++++ tensorflow/core/common_runtime/gpu/BUILD | 25 +++++++++++++++--------- 2 files changed, 25 insertions(+), 9 deletions(-) diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 53d444b410deec..bfda3dfb6dcd74 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -571,6 +571,15 @@ config_setting( visibility = ["//visibility:public"], ) +# This condition takes precedence over :linux_x86_64 +# TODO(b/290533709): Remove this with PJRT build rule cleanup. +config_setting( + name = "linux_x86_64_with_weightwatcher", + define_values = {"tensorflow_weightwatcher": "true"}, + values = {"cpu": "k8"}, + visibility = ["//visibility:public"], +) + config_setting( name = "linux_ppc64le", values = {"cpu": "ppc"}, diff --git a/tensorflow/core/common_runtime/gpu/BUILD b/tensorflow/core/common_runtime/gpu/BUILD index e41b383f7c58e8..d343d9a9ad9eb9 100644 --- a/tensorflow/core/common_runtime/gpu/BUILD +++ b/tensorflow/core/common_runtime/gpu/BUILD @@ -1,5 +1,6 @@ load( "//tensorflow:tensorflow.bzl", + "clean_dep", "if_cuda_or_rocm", "if_google", "if_linux_x86_64", @@ -192,15 +193,21 @@ tf_cuda_library( ] + if_google( # TODO(b/282068262): PJRT pulls in TFRT components that are incompatible with ARM platform. # Clean up so that PJRT can run on ARM. - if_linux_x86_64([ - "//tensorflow/compiler/tf2xla:layout_util", - "//tensorflow/compiler/jit:flags", - "//tensorflow/compiler/jit:pjrt_device_context", - "//tensorflow/compiler/xla/pjrt/gpu:gpu_helpers", - "//tensorflow/compiler/xla/pjrt/gpu:se_gpu_pjrt_client", - "//tensorflow/compiler/xla/stream_executor:tf_allocator_adapter", - "//tensorflow/core/tfrt/common:pjrt_util", - ]) + if_cuda_or_rocm([ + # Also it won't build with WeightWatcher which tracks OSS build binaries. + # TODO(b/290533709): Clean up this build rule. + select({ + clean_dep("//tensorflow:linux_x86_64_with_weightwatcher"): [], + clean_dep("//tensorflow:linux_x86_64"): [ + "//tensorflow/compiler/tf2xla:layout_util", + "//tensorflow/compiler/jit:flags", + "//tensorflow/compiler/jit:pjrt_device_context", + "//tensorflow/compiler/xla/pjrt/gpu:gpu_helpers", + "//tensorflow/compiler/xla/pjrt/gpu:se_gpu_pjrt_client", + "//tensorflow/compiler/xla/stream_executor:tf_allocator_adapter", + "//tensorflow/core/tfrt/common:pjrt_util", + ], + "//conditions:default": [], + }) + if_cuda_or_rocm([ "//tensorflow/compiler/xla/service:gpu_plugin_impl", # for registering cuda compiler. ]), ), From 21829ff5c713238a5d725cb1715f12d28a1ed18b Mon Sep 17 00:00:00 2001 From: Hyeontaek Lim Date: Wed, 12 Jul 2023 17:25:28 -0700 Subject: [PATCH 231/376] [IFRT] Add serialization/deserialization for shardings This change adds serialization/deserialization for the following IFRT sharding types: * `SingleDeviceSharding` * `OpaqueSharding` * `ConcreteSharding` * `ConcreteEvenSharding` * `HloSharding` `ShardingParamSharding` serialization/deserialization is not yet supported. PiperOrigin-RevId: 547645529 --- tensorflow/compiler/xla/python/ifrt/BUILD | 28 ++ tensorflow/compiler/xla/python/ifrt/device.cc | 24 ++ tensorflow/compiler/xla/python/ifrt/device.h | 11 + tensorflow/compiler/xla/python/ifrt/shape.cc | 25 ++ tensorflow/compiler/xla/python/ifrt/shape.h | 8 + .../compiler/xla/python/ifrt/sharding.cc | 26 +- .../compiler/xla/python/ifrt/sharding.h | 15 +- .../compiler/xla/python/ifrt/sharding.proto | 46 ++++ .../xla/python/ifrt/sharding_serdes.cc | 240 ++++++++++++++++++ .../xla/python/ifrt/sharding_serdes.h | 48 ++++ .../xla/python/ifrt/sharding_serdes_test.cc | 157 ++++++++++++ .../compiler/xla/python/ifrt/types.proto | 31 +++ .../compiler/xla/python/pjrt_ifrt/BUILD | 24 ++ .../xla/python/pjrt_ifrt/xla_sharding.proto | 27 ++ .../python/pjrt_ifrt/xla_sharding_serdes.cc | 79 ++++++ .../pjrt_ifrt/xla_sharding_serdes_test.cc | 95 +++++++ 16 files changed, 865 insertions(+), 19 deletions(-) create mode 100644 tensorflow/compiler/xla/python/ifrt/sharding.proto create mode 100644 tensorflow/compiler/xla/python/ifrt/sharding_serdes.cc create mode 100644 tensorflow/compiler/xla/python/ifrt/sharding_serdes.h create mode 100644 tensorflow/compiler/xla/python/ifrt/sharding_serdes_test.cc create mode 100644 tensorflow/compiler/xla/python/ifrt/types.proto create mode 100644 tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding.proto create mode 100644 tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_serdes.cc create mode 100644 tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_serdes_test.cc diff --git a/tensorflow/compiler/xla/python/ifrt/BUILD b/tensorflow/compiler/xla/python/ifrt/BUILD index ec8774cd5ef187..ba797b6b6b00eb 100644 --- a/tensorflow/compiler/xla/python/ifrt/BUILD +++ b/tensorflow/compiler/xla/python/ifrt/BUILD @@ -45,6 +45,7 @@ cc_library( "index_domain.cc", "shape.cc", "sharding.cc", + "sharding_serdes.cc", "tuple.cc", "value.cc", ], @@ -61,17 +62,21 @@ cc_library( "index_domain.h", "shape.h", "sharding.h", + "sharding_serdes.h", "tuple.h", "value.h", ], deps = [ ":serdes", + ":sharding_proto_cc", + ":types_proto_cc", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/pjrt:pjrt_client", "//tensorflow/compiler/xla/python/ifrt/ir", "//tensorflow/tsl/platform:logging", + "//tensorflow/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", @@ -309,3 +314,26 @@ tf_proto_library( name = "serdes_proto", srcs = ["serdes.proto"], ) + +xla_cc_test( + name = "sharding_serdes_test", + srcs = ["sharding_serdes_test.cc"], + deps = [ + ":ifrt", + ":mock", + ":serdes", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_googletest//:gtest_main", + ], +) + +tf_proto_library( + name = "types_proto", + srcs = ["types.proto"], +) + +tf_proto_library( + name = "sharding_proto", + srcs = ["sharding.proto"], + protodeps = [":types_proto"], +) diff --git a/tensorflow/compiler/xla/python/ifrt/device.cc b/tensorflow/compiler/xla/python/ifrt/device.cc index 0f02149ae48a64..a549a811de6de6 100644 --- a/tensorflow/compiler/xla/python/ifrt/device.cc +++ b/tensorflow/compiler/xla/python/ifrt/device.cc @@ -15,11 +15,35 @@ limitations under the License. #include "tensorflow/compiler/xla/python/ifrt/device.h" +#include #include +#include "tensorflow/compiler/xla/python/ifrt/client.h" +#include "tensorflow/compiler/xla/python/ifrt/types.pb.h" + namespace xla { namespace ifrt { +StatusOr DeviceList::FromProto(Client* client, + const DeviceListProto& proto) { + DeviceList::Devices devices; + devices.reserve(proto.device_ids_size()); + for (int device_id : proto.device_ids()) { + TF_ASSIGN_OR_RETURN(Device * device, client->LookupDevice(device_id)); + devices.push_back(device); + } + return DeviceList(std::move(devices)); +} + +DeviceListProto DeviceList::ToProto() const { + DeviceListProto proto; + proto.mutable_device_ids()->Reserve(devices().size()); + for (Device* device : devices()) { + proto.mutable_device_ids()->AddAlreadyReserved(device->id()); + } + return proto; +} + std::vector GetDeviceIds(DeviceList device_list) { std::vector ids; ids.reserve(device_list.devices().size()); diff --git a/tensorflow/compiler/xla/python/ifrt/device.h b/tensorflow/compiler/xla/python/ifrt/device.h index a2d5f61dd35c2a..d54afa190deaa9 100644 --- a/tensorflow/compiler/xla/python/ifrt/device.h +++ b/tensorflow/compiler/xla/python/ifrt/device.h @@ -21,10 +21,13 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "tensorflow/compiler/xla/pjrt/pjrt_client.h" +#include "tensorflow/compiler/xla/python/ifrt/types.pb.h" namespace xla { namespace ifrt { +class Client; + // Short-term alias to reuse `xla::PjRtDevice` without a separate abstract type. using Device = ::xla::PjRtDevice; @@ -42,6 +45,14 @@ class DeviceList { explicit DeviceList(Devices devices) : devices_(std::move(devices)) {} + // Constructs `DeviceList` from `DeviceListProto`. Device ids in the proto + // must be consistent with the devices owned by `client'. + static StatusOr FromProto(Client* client, + const DeviceListProto& proto); + + // Returns a `DeviceListProto` representation. + DeviceListProto ToProto() const; + absl::Span devices() const { return devices_; } int size() const { return devices_.size(); } diff --git a/tensorflow/compiler/xla/python/ifrt/shape.cc b/tensorflow/compiler/xla/python/ifrt/shape.cc index bd3ff1fc8e08b6..07e8e2b81494a5 100644 --- a/tensorflow/compiler/xla/python/ifrt/shape.cc +++ b/tensorflow/compiler/xla/python/ifrt/shape.cc @@ -17,12 +17,37 @@ limitations under the License. #include #include +#include #include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/python/ifrt/types.pb.h" +#include "tensorflow/compiler/xla/util.h" namespace xla { namespace ifrt { +StatusOr Shape::FromProto(const ShapeProto& proto) { + Shape::Dimensions dims; + dims.reserve(proto.dims_size()); + for (int64_t dim : proto.dims()) { + if (dim < 0) { + return InvalidArgument( + "Shape expects non-negative dimension sizes, but got %d", dim); + } + dims.push_back(dim); + } + return Shape(std::move(dims)); +} + +ShapeProto Shape::ToProto() const { + ShapeProto proto; + proto.mutable_dims()->Reserve(dims().size()); + for (int64_t dim : dims()) { + proto.mutable_dims()->AddAlreadyReserved(dim); + } + return proto; +} + int64_t Shape::num_elements() const { int64_t count = 1; for (int64_t d : dims_) { diff --git a/tensorflow/compiler/xla/python/ifrt/shape.h b/tensorflow/compiler/xla/python/ifrt/shape.h index 3558e3518ed84d..f3ce028789d5ef 100644 --- a/tensorflow/compiler/xla/python/ifrt/shape.h +++ b/tensorflow/compiler/xla/python/ifrt/shape.h @@ -22,6 +22,8 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "absl/types/span.h" +#include "tensorflow/compiler/xla/python/ifrt/types.pb.h" +#include "tensorflow/compiler/xla/statusor.h" namespace xla { namespace ifrt { @@ -42,6 +44,12 @@ class Shape { Shape& operator=(const Shape&) = default; Shape& operator=(Shape&&) = default; + // Constructs `Shape` from `ShapeProto`. + static StatusOr FromProto(const ShapeProto& proto); + + // Returns a `ShapeProto` representation. + ShapeProto ToProto() const; + absl::Span dims() const { return dims_; } bool operator==(const Shape& other) const { return dims_ == other.dims_; } diff --git a/tensorflow/compiler/xla/python/ifrt/sharding.cc b/tensorflow/compiler/xla/python/ifrt/sharding.cc index f057ad53fccf83..8caaf9f12a83e4 100644 --- a/tensorflow/compiler/xla/python/ifrt/sharding.cc +++ b/tensorflow/compiler/xla/python/ifrt/sharding.cc @@ -159,8 +159,10 @@ std::ostream& operator<<(std::ostream& os, const Sharding& sharding) { return os << sharding.DebugString(); } -std::unique_ptr SingleDeviceSharding::Create(Device* device) { - return std::unique_ptr(new SingleDeviceSharding(device)); +std::unique_ptr SingleDeviceSharding::Create( + Device* device) { + return std::unique_ptr( + new SingleDeviceSharding(device)); } StatusOr>>> @@ -187,8 +189,9 @@ std::string SingleDeviceSharding::DebugString() const { devices_.front()->ToString()); } -std::unique_ptr OpaqueSharding::Create(DeviceList devices) { - return std::unique_ptr(new OpaqueSharding(std::move(devices))); +std::unique_ptr OpaqueSharding::Create(DeviceList devices) { + return std::unique_ptr( + new OpaqueSharding(std::move(devices))); } OpaqueSharding::OpaqueSharding(DeviceList devices) @@ -217,10 +220,10 @@ std::string OpaqueSharding::DebugString() const { })); } -std::unique_ptr ConcreteSharding::Create( +std::unique_ptr ConcreteSharding::Create( DeviceList devices, Shape shape, std::vector shard_shapes) { CHECK_EQ(devices.size(), shard_shapes.size()); - return std::unique_ptr(new ConcreteSharding( + return std::unique_ptr(new ConcreteSharding( std::move(devices), std::move(shape), std::move(shard_shapes))); } @@ -270,10 +273,9 @@ std::string ConcreteSharding::DebugString() const { })); } -std::unique_ptr ConcreteEvenSharding::Create(DeviceList devices, - Shape shape, - Shape shard_shape) { - return std::unique_ptr(new ConcreteEvenSharding( +std::unique_ptr ConcreteEvenSharding::Create( + DeviceList devices, Shape shape, Shape shard_shape) { + return std::unique_ptr(new ConcreteEvenSharding( std::move(devices), std::move(shape), std::move(shard_shape))); } @@ -318,7 +320,7 @@ std::string ConcreteEvenSharding::DebugString() const { shape_.DebugString(), shard_shape_.DebugString()); } -StatusOr> ShardingParamSharding::Create( +StatusOr> ShardingParamSharding::Create( ShardingParam sharding_param, DeviceList devices) { int64_t device_count = absl::c_accumulate(sharding_param.minor_to_major().axis_sizes, 1, @@ -329,7 +331,7 @@ StatusOr> ShardingParamSharding::Create( "%d", device_count, devices.size()); } - return std::unique_ptr( + return std::unique_ptr( new ShardingParamSharding(std::move(sharding_param), std::move(devices))); } diff --git a/tensorflow/compiler/xla/python/ifrt/sharding.h b/tensorflow/compiler/xla/python/ifrt/sharding.h index 375cedc16a0a68..6e3d30e99d2584 100644 --- a/tensorflow/compiler/xla/python/ifrt/sharding.h +++ b/tensorflow/compiler/xla/python/ifrt/sharding.h @@ -83,7 +83,7 @@ class SingleDeviceSharding final : public llvm::RTTIExtends { public: // Creates a single-device sharding. - static std::unique_ptr Create(Device* device); + static std::unique_ptr Create(Device* device); // Sharding implementation. @@ -110,7 +110,7 @@ class SingleDeviceSharding final class OpaqueSharding : public llvm::RTTIExtends { public: // Creates an opaque sharding. `Disassemble()` will fail. - static std::unique_ptr Create(DeviceList devices); + static std::unique_ptr Create(DeviceList devices); // Sharding implementation. @@ -138,8 +138,8 @@ class ConcreteSharding : public llvm::RTTIExtends { public: // Creates a concrete sharding that may contain non-identical shard shapes. // REQUIRES: devices.size() == shard_shapes.size() - static std::unique_ptr Create(DeviceList devices, Shape shape, - std::vector shard_shapes); + static std::unique_ptr Create( + DeviceList devices, Shape shape, std::vector shard_shapes); Shape shape() const { DCHECK(this); @@ -179,8 +179,9 @@ class ConcreteEvenSharding : public llvm::RTTIExtends { public: // Creates a concrete even sharding. - static std::unique_ptr Create(DeviceList devices, Shape shape, - Shape shard_shape); + static std::unique_ptr Create(DeviceList devices, + Shape shape, + Shape shard_shape); Shape shape() const { DCHECK(this); @@ -216,7 +217,7 @@ class ConcreteEvenSharding class ShardingParamSharding : public llvm::RTTIExtends { public: - static StatusOr> Create( + static StatusOr> Create( ShardingParam sharding_param, DeviceList devices); StatusOr>>> diff --git a/tensorflow/compiler/xla/python/ifrt/sharding.proto b/tensorflow/compiler/xla/python/ifrt/sharding.proto new file mode 100644 index 00000000000000..066bce11413998 --- /dev/null +++ b/tensorflow/compiler/xla/python/ifrt/sharding.proto @@ -0,0 +1,46 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package xla.ifrt; + +import "tensorflow/compiler/xla/python/ifrt/types.proto"; + +// Wire format for `SingleDeviceSharding`. +message SingleDeviceShardingProto { + // Serialization and deserialization are expected to ensure that device ids + // are stable across proto construction and consumption. + int32 device_id = 1; +} + +// Wire format for `OpaqueSharding`. +message OpaqueShardingProto { + DeviceListProto devices = 1; +} + +// Wire format for `ConcreteSharding`. +message ConcreteShardingProto { + DeviceListProto devices = 1; + ShapeProto shape = 2; + repeated ShapeProto shard_shapes = 3; +} + +// Wire format for `ConcreteEvenSharding`. +message ConcreteEvenShardingProto { + DeviceListProto devices = 1; + ShapeProto shape = 2; + ShapeProto shard_shape = 3; +} diff --git a/tensorflow/compiler/xla/python/ifrt/sharding_serdes.cc b/tensorflow/compiler/xla/python/ifrt/sharding_serdes.cc new file mode 100644 index 00000000000000..d9ade8d6a62b96 --- /dev/null +++ b/tensorflow/compiler/xla/python/ifrt/sharding_serdes.cc @@ -0,0 +1,240 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/python/ifrt/sharding_serdes.h" + +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/python/ifrt/client.h" +#include "tensorflow/compiler/xla/python/ifrt/device.h" +#include "tensorflow/compiler/xla/python/ifrt/serdes.h" +#include "tensorflow/compiler/xla/python/ifrt/shape.h" +#include "tensorflow/compiler/xla/python/ifrt/sharding.h" +#include "tensorflow/compiler/xla/python/ifrt/sharding.pb.h" +#include "tensorflow/tsl/platform/statusor.h" + +namespace xla { +namespace ifrt { + +char DeserializeShardingOptions::ID = 0; + +namespace { + +// Serialization/deserialization for `SingleDeviceSharding`. +class SingleDeviceShardingSerDes + : public llvm::RTTIExtends { + public: + absl::string_view type_name() const override { + return "xla::ifrt::SingleDeviceSharding"; + } + + absl::StatusOr Serialize(Serializable& serializable) override { + const SingleDeviceSharding& sharding = + llvm::cast(serializable); + SingleDeviceShardingProto proto; + proto.set_device_id(sharding.devices().front()->id()); + return proto.SerializeAsString(); + } + + absl::StatusOr> Deserialize( + const std::string& serialized, + std::unique_ptr options) override { + TF_ASSIGN_OR_RETURN(auto deserialize_sharding_options, + GetDeserializeShardingOptions(std::move(options))); + SingleDeviceShardingProto proto; + if (!proto.ParseFromString(serialized)) { + return absl::InvalidArgumentError( + "Failed to parse serialized SimpleDeviceSharding"); + } + TF_ASSIGN_OR_RETURN( + Device * device, + deserialize_sharding_options->client->LookupDevice(proto.device_id())); + return SingleDeviceSharding::Create(device); + } + + static char ID; // NOLINT +}; + +// Serialization/deserialization for `OpaqueSharding`. +class OpaqueShardingSerDes + : public llvm::RTTIExtends { + public: + absl::string_view type_name() const override { + return "xla::ifrt::OpaqueSharding"; + } + + absl::StatusOr Serialize(Serializable& serializable) override { + const OpaqueSharding& sharding = llvm::cast(serializable); + OpaqueShardingProto proto; + *proto.mutable_devices() = sharding.devices().ToProto(); + return proto.SerializeAsString(); + } + + absl::StatusOr> Deserialize( + const std::string& serialized, + std::unique_ptr options) override { + TF_ASSIGN_OR_RETURN(auto deserialize_sharding_options, + GetDeserializeShardingOptions(std::move(options))); + + OpaqueShardingProto proto; + if (!proto.ParseFromString(serialized)) { + return absl::InvalidArgumentError( + "Failed to parse serialized OpaqueSharding"); + } + TF_ASSIGN_OR_RETURN(auto devices, DeviceList::FromProto( + deserialize_sharding_options->client, + proto.devices())); + return OpaqueSharding::Create(std::move(devices)); + } + + static char ID; // NOLINT +}; + +// Serialization/deserialization for `ConcreteSharding`. +class ConcreteShardingSerDes + : public llvm::RTTIExtends { + public: + absl::string_view type_name() const override { + return "xla::ifrt::ConcreteSharding"; + } + + absl::StatusOr Serialize(Serializable& serializable) override { + const ConcreteSharding& sharding = + llvm::cast(serializable); + ConcreteShardingProto proto; + *proto.mutable_devices() = sharding.devices().ToProto(); + *proto.mutable_shape() = sharding.shape().ToProto(); + for (const Shape& shape : sharding.shard_shapes()) { + *proto.add_shard_shapes() = shape.ToProto(); + } + return proto.SerializeAsString(); + } + + absl::StatusOr> Deserialize( + const std::string& serialized, + std::unique_ptr options) override { + TF_ASSIGN_OR_RETURN(auto deserialize_sharding_options, + GetDeserializeShardingOptions(std::move(options))); + + ConcreteShardingProto proto; + if (!proto.ParseFromString(serialized)) { + return absl::InvalidArgumentError( + "Failed to parse serialized ConcreteSharding"); + } + TF_ASSIGN_OR_RETURN(auto devices, DeviceList::FromProto( + deserialize_sharding_options->client, + proto.devices())); + TF_ASSIGN_OR_RETURN(auto shape, Shape::FromProto(proto.shape())); + std::vector shard_shapes; + shard_shapes.reserve(proto.shard_shapes_size()); + for (const auto& shard_shape_proto : proto.shard_shapes()) { + TF_ASSIGN_OR_RETURN(auto shard_shape, + Shape::FromProto(shard_shape_proto)); + shard_shapes.push_back(std::move(shard_shape)); + } + return ConcreteSharding::Create(std::move(devices), std::move(shape), + std::move(shard_shapes)); + } + + static char ID; // NOLINT +}; + +// Serialization/deserialization for `ConcreteEvenSharding`. +class ConcreteEvenShardingSerDes + : public llvm::RTTIExtends { + public: + absl::string_view type_name() const override { + return "xla::ifrt::ConcreteEvenSharding"; + } + + absl::StatusOr Serialize(Serializable& serializable) override { + const ConcreteEvenSharding& sharding = + llvm::cast(serializable); + ConcreteEvenShardingProto proto; + *proto.mutable_devices() = sharding.devices().ToProto(); + *proto.mutable_shape() = sharding.shape().ToProto(); + *proto.mutable_shard_shape() = sharding.shard_shape().ToProto(); + return proto.SerializeAsString(); + } + + absl::StatusOr> Deserialize( + const std::string& serialized, + std::unique_ptr options) override { + TF_ASSIGN_OR_RETURN(auto deserialize_sharding_options, + GetDeserializeShardingOptions(std::move(options))); + + ConcreteEvenShardingProto proto; + if (!proto.ParseFromString(serialized)) { + return absl::InvalidArgumentError( + "Failed to parse serialized ConcreteEvenSharding"); + } + TF_ASSIGN_OR_RETURN(auto devices, DeviceList::FromProto( + deserialize_sharding_options->client, + proto.devices())); + TF_ASSIGN_OR_RETURN(auto shape, Shape::FromProto(proto.shape())); + TF_ASSIGN_OR_RETURN(auto shard_shape, + Shape::FromProto(proto.shard_shape())); + return ConcreteEvenSharding::Create(std::move(devices), std::move(shape), + std::move(shard_shape)); + } + + static char ID; // NOLINT +}; + +// TODO(hyeontaek): Implement `ShardingParamShardingSerDes`. + +[[maybe_unused]] char SingleDeviceShardingSerDes::ID = 0; // NOLINT +[[maybe_unused]] char OpaqueShardingSerDes::ID = 0; // NOLINT +[[maybe_unused]] char ConcreteShardingSerDes::ID = 0; // NOLINT +[[maybe_unused]] char ConcreteEvenShardingSerDes::ID = 0; // NOLINT + +// clang-format off +bool register_single_device_sharding_serdes = ([]{ + RegisterSerDes( + std::make_unique()); +}(), true); + +bool register_opaque_sharding_serdes = ([]{ + RegisterSerDes( + std::make_unique()); +}(), true); + +bool register_concrete_sharding_serdes = ([]{ + RegisterSerDes( + std::make_unique()); +}(), true); + +bool register_concrete_even_sharding_serdes = ([]{ + RegisterSerDes( + std::make_unique()); +}(), true); +// clang-format on + +} // namespace + +StatusOr> +GetDeserializeShardingOptions(std::unique_ptr options) { + if (!llvm::isa(options.get())) { + return xla::InvalidArgument("options must be DeserializeShardingOptions"); + } + return std::unique_ptr( + static_cast(options.release())); +} + +} // namespace ifrt +} // namespace xla diff --git a/tensorflow/compiler/xla/python/ifrt/sharding_serdes.h b/tensorflow/compiler/xla/python/ifrt/sharding_serdes.h new file mode 100644 index 00000000000000..965670bcbc3401 --- /dev/null +++ b/tensorflow/compiler/xla/python/ifrt/sharding_serdes.h @@ -0,0 +1,48 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_IFRT_SHARDING_SERDES_H_ +#define TENSORFLOW_COMPILER_XLA_PYTHON_IFRT_SHARDING_SERDES_H_ + +#include + +#include "llvm/Support/ExtensibleRTTI.h" +#include "tensorflow/compiler/xla/python/ifrt/serdes.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { +namespace ifrt { + +class Client; + +// Options for deserializing shardings. +struct DeserializeShardingOptions + : llvm::RTTIExtends { + explicit DeserializeShardingOptions(Client* client) : client(client) {} + + static char ID; // NOLINT + + // The client whose devices will be used by deserialized shardings. + Client* client; +}; + +// Casts `DeserializeOptions` into `DeserializeShardingOptions`. +StatusOr> +GetDeserializeShardingOptions(std::unique_ptr options); + +} // namespace ifrt +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_PYTHON_IFRT_SHARDING_SERDES_H_ diff --git a/tensorflow/compiler/xla/python/ifrt/sharding_serdes_test.cc b/tensorflow/compiler/xla/python/ifrt/sharding_serdes_test.cc new file mode 100644 index 00000000000000..90efc6d9667167 --- /dev/null +++ b/tensorflow/compiler/xla/python/ifrt/sharding_serdes_test.cc @@ -0,0 +1,157 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/python/ifrt/sharding_serdes.h" + +#include +#include +#include + +#include +#include +#include "absl/container/flat_hash_map.h" +#include "tensorflow/compiler/xla/python/ifrt/mock.h" +#include "tensorflow/compiler/xla/python/ifrt/serdes.h" +#include "tensorflow/compiler/xla/python/ifrt/sharding.h" + +namespace xla { +namespace ifrt { +namespace { + +using ::testing::ElementsAreArray; + +// Test fixture for sharding serialization and deserialization. It makes a mock +// client with a number of fake devices. Client implements `devices()` and +// `LookupDevice()`, and Device implements `id()`, with an arbitrary device ids +// assigned. +class ShardingSerDesTest : public ::testing::TestWithParam { + public: + void SetUp() override { + const int num_devices = GetParam(); + device_map_.reserve(num_devices); + devices_.reserve(num_devices); + for (int i = 0; i < num_devices; ++i) { + auto device = std::make_unique(); + ON_CALL(*device, id).WillByDefault([i]() { return i + 10; }); + devices_.push_back(device.get()); + device_map_.insert({i + 10, std::move(device)}); + } + client_ = std::make_unique(); + ON_CALL(*client_, devices) + .WillByDefault( + [this]() -> absl::Span { return devices_; }); + ON_CALL(*client_, LookupDevice) + .WillByDefault([this](int device_id) -> StatusOr { + auto it = device_map_.find(device_id); + if (it == device_map_.end()) { + return InvalidArgument("Unexpected device id: %d", device_id); + } + return it->second.get(); + }); + } + Client* client() { return client_.get(); } + + private: + std::unique_ptr client_; + absl::flat_hash_map> device_map_; + std::vector devices_; +}; + +TEST_P(ShardingSerDesTest, SingleDeviceShardingRoundTrip) { + auto sharding = SingleDeviceSharding::Create(client()->devices().front()); + + TF_ASSERT_OK_AND_ASSIGN(auto serialized, Serialize(*sharding)); + + auto deserialized_options = + std::make_unique(client()); + TF_ASSERT_OK_AND_ASSIGN( + auto deserialized, + Deserialize(serialized, std::move(deserialized_options))); + + const auto* out_sharding = + llvm::dyn_cast(deserialized.get()); + ASSERT_NE(out_sharding, nullptr); + EXPECT_THAT(out_sharding->devices(), ElementsAreArray(sharding->devices())); +} + +TEST_P(ShardingSerDesTest, OpaqueShardingRoundTrip) { + auto sharding = OpaqueSharding::Create(DeviceList(DeviceList::Devices( + client()->devices().begin(), client()->devices().end()))); + + TF_ASSERT_OK_AND_ASSIGN(auto serialized, Serialize(*sharding)); + + auto deserialized_options = + std::make_unique(client()); + TF_ASSERT_OK_AND_ASSIGN( + auto deserialized, + Deserialize(serialized, std::move(deserialized_options))); + + const auto* out_sharding = llvm::dyn_cast(deserialized.get()); + ASSERT_NE(out_sharding, nullptr); + EXPECT_THAT(out_sharding->devices(), ElementsAreArray(sharding->devices())); +} + +TEST_P(ShardingSerDesTest, ConcreteShardingRoundTrip) { + auto sharding = ConcreteSharding::Create( + DeviceList(DeviceList::Devices(client()->devices().begin(), + client()->devices().end())), + /*shape=*/Shape({10, 20}), + /*shard_shapes=*/{Shape({3, 20}), Shape({7, 20})}); + + TF_ASSERT_OK_AND_ASSIGN(auto serialized, Serialize(*sharding)); + + auto deserialized_options = + std::make_unique(client()); + TF_ASSERT_OK_AND_ASSIGN( + auto deserialized, + Deserialize(serialized, std::move(deserialized_options))); + + const auto* out_sharding = + llvm::dyn_cast(deserialized.get()); + ASSERT_NE(out_sharding, nullptr); + EXPECT_THAT(out_sharding->devices(), ElementsAreArray(sharding->devices())); + EXPECT_THAT(out_sharding->shape(), sharding->shape()); + EXPECT_THAT(out_sharding->shard_shapes(), + ElementsAreArray(sharding->shard_shapes())); +} + +TEST_P(ShardingSerDesTest, ConcreteEvenShardingRoundTrip) { + auto sharding = ConcreteEvenSharding::Create( + DeviceList(DeviceList::Devices(client()->devices().begin(), + client()->devices().end())), + /*shape=*/Shape({10, 20}), + /*shard_shape=*/Shape({5, 20})); + + TF_ASSERT_OK_AND_ASSIGN(auto serialized, Serialize(*sharding)); + + auto deserialized_options = + std::make_unique(client()); + TF_ASSERT_OK_AND_ASSIGN( + auto deserialized, + Deserialize(serialized, std::move(deserialized_options))); + + const auto* out_sharding = + llvm::dyn_cast(deserialized.get()); + ASSERT_NE(out_sharding, nullptr); + EXPECT_THAT(out_sharding->devices(), ElementsAreArray(sharding->devices())); + EXPECT_THAT(out_sharding->shape(), sharding->shape()); + EXPECT_THAT(out_sharding->shard_shape(), sharding->shard_shape()); +} + +INSTANTIATE_TEST_SUITE_P(NumDevices, ShardingSerDesTest, testing::Values(2)); + +} // namespace +} // namespace ifrt +} // namespace xla diff --git a/tensorflow/compiler/xla/python/ifrt/types.proto b/tensorflow/compiler/xla/python/ifrt/types.proto new file mode 100644 index 00000000000000..e9c799bcc1ed6c --- /dev/null +++ b/tensorflow/compiler/xla/python/ifrt/types.proto @@ -0,0 +1,31 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package xla.ifrt; + +// Wire format for `DeviceList`. +message DeviceListProto { + // Serialization and deserialization are expected to ensure that device ids + // are stable across proto construction and consumption. + repeated int32 device_ids = 1; +} + +// Wire format for `Shape`. Currently support static shapes with all dimension +// sizes greater than or equal to 0. +message ShapeProto { + repeated int64 dims = 1; +} diff --git a/tensorflow/compiler/xla/python/pjrt_ifrt/BUILD b/tensorflow/compiler/xla/python/pjrt_ifrt/BUILD index e1c1a36bb4ef62..ba7f217ec28e3f 100644 --- a/tensorflow/compiler/xla/python/pjrt_ifrt/BUILD +++ b/tensorflow/compiler/xla/python/pjrt_ifrt/BUILD @@ -36,6 +36,7 @@ cc_library( srcs = [ "xla_compiler.cc", "xla_sharding.cc", + "xla_sharding_serdes.cc", ], hdrs = [ "xla_compiler.h", @@ -43,7 +44,9 @@ cc_library( ], deps = [ ":xla_compiler_proto_cc", + ":xla_sharding_proto_cc", "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/hlo/ir:hlo", "//tensorflow/compiler/xla/pjrt:pjrt_executable", "//tensorflow/compiler/xla/python/ifrt", "//tensorflow/compiler/xla/python/ifrt:serdes", @@ -110,6 +113,27 @@ xla_cc_test( ], ) +tf_proto_library( + name = "xla_sharding_proto", + srcs = ["xla_sharding.proto"], + protodeps = [ + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/python/ifrt:types_proto", + ], +) + +xla_cc_test( + name = "xla_sharding_serdes_test", + srcs = ["xla_sharding_serdes_test.cc"], + deps = [ + ":xla_ifrt", + "//tensorflow/compiler/xla/hlo/ir:hlo", + "//tensorflow/compiler/xla/python/ifrt", + "//tensorflow/compiler/xla/python/ifrt:mock", + "@com_google_googletest//:gtest_main", + ], +) + # TODO(hyeontaek): Move this target out of pjrt_ifrt. cc_library( name = "xla_executable_impl_test_lib", diff --git a/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding.proto b/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding.proto new file mode 100644 index 00000000000000..0ff8040b66233e --- /dev/null +++ b/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding.proto @@ -0,0 +1,27 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package xla.ifrt; + +import "tensorflow/compiler/xla/python/ifrt/types.proto"; +import "tensorflow/compiler/xla/xla_data.proto"; + +// Wire format for `HloSharding`. +message HloShardingProto { + DeviceListProto devices = 1; + xla.OpSharding xla_op_sharding = 2; +} diff --git a/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_serdes.cc b/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_serdes.cc new file mode 100644 index 00000000000000..c3d8d2470600b9 --- /dev/null +++ b/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_serdes.cc @@ -0,0 +1,79 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "tensorflow/compiler/xla/hlo/ir/hlo_sharding.h" +#include "tensorflow/compiler/xla/python/ifrt/serdes.h" +#include "tensorflow/compiler/xla/python/ifrt/sharding_serdes.h" +#include "tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding.h" +#include "tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding.pb.h" + +namespace xla { +namespace ifrt { + +namespace { + +// Serialization/deserialization for `HloSharding`. +class HloShardingSerDes : public llvm::RTTIExtends { + public: + absl::string_view type_name() const override { + return "xla::ifrt::HloSharding"; + } + + absl::StatusOr Serialize(Serializable& serializable) override { + const HloSharding& sharding = llvm::cast(serializable); + HloShardingProto proto; + *proto.mutable_devices() = sharding.devices().ToProto(); + *proto.mutable_xla_op_sharding() = sharding.xla_hlo_sharding().ToProto(); + return proto.SerializeAsString(); + } + + absl::StatusOr> Deserialize( + const std::string& serialized, + std::unique_ptr options) override { + TF_ASSIGN_OR_RETURN(auto deserialize_sharding_options, + GetDeserializeShardingOptions(std::move(options))); + + HloShardingProto proto; + if (!proto.ParseFromString(serialized)) { + return absl::InvalidArgumentError( + "Failed to parse serialized HloSharding"); + } + TF_ASSIGN_OR_RETURN(auto devices, DeviceList::FromProto( + deserialize_sharding_options->client, + proto.devices())); + TF_ASSIGN_OR_RETURN(auto xla_hlo_sharding, + xla::HloSharding::FromProto(proto.xla_op_sharding())); + return HloSharding::Create(std::move(devices), std::move(xla_hlo_sharding)); + } + + static char ID; // NOLINT +}; + +[[maybe_unused]] char HloShardingSerDes::ID = 0; // NOLINT + +// clang-format off +bool register_hlo_sharding_serdes = ([] { + RegisterSerDes( + std::make_unique()); +}(), true); +// clang-format on + +} // namespace +} // namespace ifrt +} // namespace xla diff --git a/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_serdes_test.cc b/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_serdes_test.cc new file mode 100644 index 00000000000000..e043fb7e575f2b --- /dev/null +++ b/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_serdes_test.cc @@ -0,0 +1,95 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include +#include +#include "tensorflow/compiler/xla/hlo/ir/hlo_sharding.h" +#include "tensorflow/compiler/xla/python/ifrt/mock.h" +#include "tensorflow/compiler/xla/python/ifrt/sharding_serdes.h" +#include "tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding.h" + +namespace xla { +namespace ifrt { +namespace { + +using ::testing::ElementsAreArray; + +// Test fixture for sharding serialization and deserialization. It makes a mock +// client with a number of fake devices. Client implements `devices()` and +// `LookupDevice()`, and Device implements `id()`, with an arbitrary device ids +// assigned. +class XlaShardingSerDesTest : public ::testing::TestWithParam { + public: + void SetUp() override { + const int num_devices = GetParam(); + device_map_.reserve(num_devices); + devices_.reserve(num_devices); + for (int i = 0; i < num_devices; ++i) { + auto device = std::make_unique(); + ON_CALL(*device, id).WillByDefault([i]() { return i + 10; }); + devices_.push_back(device.get()); + device_map_.insert({i + 10, std::move(device)}); + } + client_ = std::make_unique(); + ON_CALL(*client_, devices) + .WillByDefault( + [this]() -> absl::Span { return devices_; }); + ON_CALL(*client_, LookupDevice) + .WillByDefault([this](int device_id) -> StatusOr { + auto it = device_map_.find(device_id); + if (it == device_map_.end()) { + return InvalidArgument("Unexpected device id: %d", device_id); + } + return it->second.get(); + }); + } + Client* client() { return client_.get(); } + + private: + std::unique_ptr client_; + absl::flat_hash_map> device_map_; + std::vector devices_; +}; + +TEST_P(XlaShardingSerDesTest, HloShardingRoundTrip) { + auto xla_hlo_sharding = xla::HloSharding::Tile(xla::TileAssignment({2, 1})); + auto sharding = HloSharding::Create( + DeviceList(DeviceList::Devices(client()->devices().begin(), + client()->devices().end())), + /*xla_hlo_sharding=*/xla_hlo_sharding); + + TF_ASSERT_OK_AND_ASSIGN(auto serialized, Serialize(*sharding)); + + auto deserialized_options = + std::make_unique(client()); + TF_ASSERT_OK_AND_ASSIGN( + auto deserialized, + Deserialize(serialized, std::move(deserialized_options))); + + const auto* out_sharding = llvm::dyn_cast(deserialized.get()); + ASSERT_NE(out_sharding, nullptr); + EXPECT_THAT(out_sharding->devices(), ElementsAreArray(sharding->devices())); + EXPECT_EQ(out_sharding->xla_hlo_sharding(), sharding->xla_hlo_sharding()); +} + +INSTANTIATE_TEST_SUITE_P(NumDevices, XlaShardingSerDesTest, testing::Values(2)); + +} // namespace +} // namespace ifrt +} // namespace xla From 726d6cf760403c1a0ba31e54856915507c5f2b60 Mon Sep 17 00:00:00 2001 From: Fergus Henderson Date: Wed, 12 Jul 2023 17:47:37 -0700 Subject: [PATCH 232/376] Add TfLiteInterpreterOptionsSetOpResolverExternalWithFallback. This combines the effects of TfLiteInterpreterOptionsSetOpResolverExternal and TfLiteInterpreterOptionsSetOpResolver, allowing use of an op resolver that uses TfLiteRegistrationExternal for user-defined custom ops while still continuing to use the builtin op resolver, which uses TfLiteRegistration objects (that don't set the TfLiteRegistrationExternal field), for builtin ops. This change required modifying the order in which we consider the fields of CallbackOpResolver in c_api.cc so that we consider the _external fields first, before the other ones. Previously the order didn't matter since these fields were mutually exclusive, but now that TfLiteInterpreterOptionsSetOpResolverExternal sets both the _external and regular fields, we want to consider the _external fields first, so that the op resolver callbacks that return TfLiteRegistrationExternal take precedence over the ones that return TfLiteRegistration, as specified in the documentation for the new API function TfLiteInterpreterOptionsSetOpResolverExternalWithFallback. PiperOrigin-RevId: 547649220 --- tensorflow/lite/core/c/BUILD | 5 +- tensorflow/lite/core/c/c_api.cc | 92 ++++--- tensorflow/lite/core/c/c_api_experimental.cc | 22 ++ tensorflow/lite/core/c/c_api_experimental.h | 29 +++ .../lite/core/c/c_api_experimental_test.cc | 237 +++++++++++++++++- 5 files changed, 341 insertions(+), 44 deletions(-) diff --git a/tensorflow/lite/core/c/BUILD b/tensorflow/lite/core/c/BUILD index a0e7a3f6640600..972df7b7677582 100644 --- a/tensorflow/lite/core/c/BUILD +++ b/tensorflow/lite/core/c/BUILD @@ -416,7 +416,10 @@ cc_test( size = "small", srcs = ["c_api_experimental_test.cc"], copts = tflite_copts(), - data = ["//tensorflow/lite:testdata/add.bin"], + data = [ + "//tensorflow/lite:testdata/add.bin", + "//tensorflow/lite:testdata/custom_sinh.bin", + ], deps = [ ":c_api", ":c_api_experimental", diff --git a/tensorflow/lite/core/c/c_api.cc b/tensorflow/lite/core/c/c_api.cc index 7880445c3fccdc..fd07b58478b4fd 100644 --- a/tensorflow/lite/core/c/c_api.cc +++ b/tensorflow/lite/core/c/c_api.cc @@ -298,13 +298,6 @@ static TfLiteRegistration* RegistrationExternalToRegistration( // FindOp for builtin op query. const TfLiteRegistration* CallbackOpResolver::FindOp(tflite::BuiltinOperator op, int version) const { - // Use Registration V3 API to find op. - if (op_resolver_callbacks_.find_builtin_op) { - return op_resolver_callbacks_.find_builtin_op( - op_resolver_callbacks_.user_data, - static_cast(op), version); - } - // Check if cached Registration is available. std::lock_guard lock(mutex_); for (const auto& created_registration : temporary_builtin_registrations_) { @@ -314,50 +307,58 @@ const TfLiteRegistration* CallbackOpResolver::FindOp(tflite::BuiltinOperator op, } } + // Try using newer RegistrationExternal API. + if (op_resolver_callbacks_.find_builtin_op_external) { + // Get a RegistrationExternal object and create a Registration (V3) object. + const TfLiteRegistrationExternal* registration_external = + op_resolver_callbacks_.find_builtin_op_external( + op_resolver_callbacks_.user_data, + static_cast(op), version); + if (registration_external && (registration_external->init != nullptr || + registration_external->free != nullptr || + registration_external->invoke != nullptr || + registration_external->prepare != nullptr)) { + TfLiteRegistration* new_registration = + RegistrationExternalToRegistration(registration_external); + temporary_builtin_registrations_.push_back( + std::unique_ptr(new_registration)); + return new_registration; + } + } + + // Use Registration V4 API to find op. + if (op_resolver_callbacks_.find_builtin_op) { + return op_resolver_callbacks_.find_builtin_op( + op_resolver_callbacks_.user_data, + static_cast(op), version); + } + // Try using older Registration V3 API to find op. if (auto* registration = BuildBuiltinOpFromLegacyRegistration( op, version, op_resolver_callbacks_.find_builtin_op_v3); registration) { return registration; } + // Try using older Registration V2 API to find op. if (auto* registration = BuildBuiltinOpFromLegacyRegistration( op, version, op_resolver_callbacks_.find_builtin_op_v2); registration) { return registration; } + // Try using older Registration V1 API to find op. if (auto* registration = BuildBuiltinOpFromLegacyRegistration( op, version, op_resolver_callbacks_.find_builtin_op_v1); registration) { return registration; } - // Try using newer RegistrationExternal API. - if (op_resolver_callbacks_.find_builtin_op_external) { - // Get a RegistrationExternal object and create a Registration (V3) object. - const TfLiteRegistrationExternal* registration_external = - op_resolver_callbacks_.find_builtin_op_external( - op_resolver_callbacks_.user_data, - static_cast(op), version); - if (registration_external) { - TfLiteRegistration* new_registration = - RegistrationExternalToRegistration(registration_external); - temporary_builtin_registrations_.push_back( - std::unique_ptr(new_registration)); - return new_registration; - } - } return nullptr; } // FindOp for custom op query. const TfLiteRegistration* CallbackOpResolver::FindOp(const char* op, int version) const { - // Use TfLiteRegistration API to find op. - if (op_resolver_callbacks_.find_custom_op) { - return op_resolver_callbacks_.find_custom_op( - op_resolver_callbacks_.user_data, op, version); - } // Check if cached Registration is available. std::lock_guard lock(mutex_); for (const auto& created_registration : temporary_custom_registrations_) { @@ -367,37 +368,48 @@ const TfLiteRegistration* CallbackOpResolver::FindOp(const char* op, } } + if (op_resolver_callbacks_.find_custom_op_external) { + // Get a RegistrationExternal object and create a Registration (V3) object. + const TfLiteRegistrationExternal* registration_external = + op_resolver_callbacks_.find_custom_op_external( + op_resolver_callbacks_.user_data, op, version); + if (registration_external && (registration_external->init != nullptr || + registration_external->free != nullptr || + registration_external->invoke != nullptr || + registration_external->prepare != nullptr)) { + TfLiteRegistration* new_registration = + RegistrationExternalToRegistration(registration_external); + temporary_builtin_registrations_.push_back( + std::unique_ptr(new_registration)); + return new_registration; + } + } + // Use TfLiteRegistration V4 API to find op. + if (op_resolver_callbacks_.find_custom_op) { + return op_resolver_callbacks_.find_custom_op( + op_resolver_callbacks_.user_data, op, version); + } + // Use older TfLiteRegistration V3 API to find op. if (auto* registration = BuildCustomOpFromLegacyRegistration( op, version, op_resolver_callbacks_.find_custom_op_v3); registration) { return registration; } + // Use older TfLiteRegistration V2 API to find op. if (auto* registration = BuildCustomOpFromLegacyRegistration( op, version, op_resolver_callbacks_.find_custom_op_v2); registration) { return registration; } + // Use even older TfLiteRegistration V1 API to find op. if (auto* registration = BuildCustomOpFromLegacyRegistration( op, version, op_resolver_callbacks_.find_custom_op_v1); registration) { return registration; } - if (op_resolver_callbacks_.find_custom_op_external) { - // Get a RegistrationExternal object and create a Registration (V2) object. - const TfLiteRegistrationExternal* registration_external = - op_resolver_callbacks_.find_custom_op_external( - op_resolver_callbacks_.user_data, op, version); - if (registration_external) { - TfLiteRegistration* new_registration = - RegistrationExternalToRegistration(registration_external); - temporary_builtin_registrations_.push_back( - std::unique_ptr(new_registration)); - return new_registration; - } - } return nullptr; } diff --git a/tensorflow/lite/core/c/c_api_experimental.cc b/tensorflow/lite/core/c/c_api_experimental.cc index 45a8d0f99241a1..a6c5baa6c9a066 100644 --- a/tensorflow/lite/core/c/c_api_experimental.cc +++ b/tensorflow/lite/core/c/c_api_experimental.cc @@ -75,6 +75,28 @@ void TfLiteInterpreterOptionsSetOpResolverExternal( options->op_resolver_callbacks.user_data = op_resolver_user_data; } +void TfLiteInterpreterOptionsSetOpResolverExternalWithFallback( + TfLiteInterpreterOptions* options, + const TfLiteRegistrationExternal* (*find_builtin_op_external)( + void* user_data, int op, int version), + const TfLiteRegistrationExternal* (*find_custom_op_external)( + void* user_data, const char* custom_op, int version), + const TfLiteRegistration* (*find_builtin_op)(void* user_data, + TfLiteBuiltinOperator op, + int version), + const TfLiteRegistration* (*find_custom_op)(void* user_data, const char* op, + int version), + void* op_resolver_user_data) { + options->op_resolver_callbacks = {}; // Sets all fields to null. + options->op_resolver_callbacks.find_builtin_op_external = + find_builtin_op_external; + options->op_resolver_callbacks.find_custom_op_external = + find_custom_op_external; + options->op_resolver_callbacks.find_builtin_op = find_builtin_op; + options->op_resolver_callbacks.find_custom_op = find_custom_op; + options->op_resolver_callbacks.user_data = op_resolver_user_data; +} + void TfLiteInterpreterOptionsSetOpResolver( TfLiteInterpreterOptions* options, const TfLiteRegistration* (*find_builtin_op)(void* user_data, diff --git a/tensorflow/lite/core/c/c_api_experimental.h b/tensorflow/lite/core/c/c_api_experimental.h index f87e0226baf62a..f766931931bfc2 100644 --- a/tensorflow/lite/core/c/c_api_experimental.h +++ b/tensorflow/lite/core/c/c_api_experimental.h @@ -101,6 +101,7 @@ TFL_CAPI_EXPORT void TfLiteInterpreterOptionsAddCustomOp( /// The `TfLiteInterpreterOptionsSetOpResolverExternal` function provides an /// alternative method for registering builtin ops and/or custom ops, by /// providing operator resolver callbacks. Unlike using +/// `TfLiteInterpreterOptionsAddRegistrationExternal`, /// `TfLiteInterpreterOptionsAddBuiltinOp` and/or /// `TfLiteInterpreterOptionsAddAddCustomOp`, these let you register all the /// operators in a single call. @@ -126,6 +127,34 @@ void TfLiteInterpreterOptionsSetOpResolverExternal( int version), void* op_resolver_user_data); +/// \private +/// Registers callbacks for resolving builtin or custom operators. +/// +/// This combines the effects of TfLiteInterpreterOptionsSetOpResolverExternal +/// and TfLiteInterpreterOptionsSetOpResolver. The callbacks that return +/// TfLiteRegistrationExternal will be called first, but if they return a +/// TfLiteRegistrationExternal object that has no methods set, then +/// the callbacks that return a TfLiteRegistration will be called to get +/// the methods. +/// +/// WARNING: This function is experimental and subject to change. +/// +/// WARNING: This function is not an official part of the API, +/// and should not be used by apps. It is intended for use only from +/// TF Lite itself. +void TfLiteInterpreterOptionsSetOpResolverExternalWithFallback( + TfLiteInterpreterOptions* options, + const TfLiteRegistrationExternal* (*find_builtin_op_external)( + void* user_data, int op, int version), + const TfLiteRegistrationExternal* (*find_custom_op_external)( + void* user_data, const char* custom_op, int version), + const TfLiteRegistration* (*find_builtin_op)(void* user_data, + TfLiteBuiltinOperator op, + int version), + const TfLiteRegistration* (*find_custom_op)(void* user_data, const char* op, + int version), + void* op_resolver_user_data); + /// Registers callbacks for resolving builtin or custom operators. /// /// The `TfLiteInterpreterOptionsSetOpResolver` function provides an alternative diff --git a/tensorflow/lite/core/c/c_api_experimental_test.cc b/tensorflow/lite/core/c/c_api_experimental_test.cc index f1d045d737c748..9b05252e1ed139 100644 --- a/tensorflow/lite/core/c/c_api_experimental_test.cc +++ b/tensorflow/lite/core/c/c_api_experimental_test.cc @@ -15,9 +15,9 @@ limitations under the License. #include "tensorflow/lite/core/c/c_api_experimental.h" -#include - #include +#include +#include #include #include @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/core/platform/resource_loader.h" #include "tensorflow/lite/builtin_ops.h" #include "tensorflow/lite/core/c/c_api.h" +#include "tensorflow/lite/core/c/c_api_opaque.h" #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/delegates/delegate_test_util.h" #include "tensorflow/lite/testing/util.h" @@ -206,7 +207,72 @@ const TfLiteRegistrationExternal* MyFindCustomOpExternal(void*, return nullptr; } -// Test using TfLiteInterpreterCreateWithSelectedOps. +TfLiteStatus SinhPrepareOpaque(TfLiteOpaqueContext*, TfLiteOpaqueNode*) { + return kTfLiteOk; +} + +TfLiteStatus SinhEvalOpaque(TfLiteOpaqueContext* context, + TfLiteOpaqueNode* node) { + EXPECT_EQ(1, TfLiteOpaqueNodeNumberOfInputs(node)); + const TfLiteOpaqueTensor* input = TfLiteOpaqueNodeGetInput(context, node, 0); + size_t input_bytes = TfLiteOpaqueTensorByteSize(input); + const void* data_ptr = TfLiteOpaqueTensorData(input); + float input_value; + std::memcpy(&input_value, data_ptr, input_bytes); + + EXPECT_EQ(1, TfLiteOpaqueNodeNumberOfOutputs(node)); + TfLiteOpaqueTensor* output = TfLiteOpaqueNodeGetOutput(context, node, 0); + float output_value = std::sinh(input_value); + TfLiteOpaqueTensorCopyFromBuffer(output, &output_value, sizeof(output_value)); + return kTfLiteOk; +} + +TfLiteStatus SinhPrepare(TfLiteContext*, TfLiteNode*) { return kTfLiteOk; } + +TfLiteStatus SinhEval(TfLiteContext* context, TfLiteNode* node) { + EXPECT_EQ(1, node->inputs->size); + const TfLiteTensor* input = &context->tensors[node->inputs->data[0]]; + size_t input_bytes = TfLiteTensorByteSize(input); + const void* data_ptr = TfLiteTensorData(input); + float input_value; + std::memcpy(&input_value, data_ptr, input_bytes); + + EXPECT_EQ(1, node->outputs->size); + TfLiteTensor* output = &context->tensors[node->outputs->data[0]]; + float output_value = std::sinh(input_value); + TfLiteTensorCopyFromBuffer(output, &output_value, sizeof(output_value)); + return kTfLiteOk; +} + +const TfLiteRegistrationExternal* SinhFindCustomOpExternal( + void*, const char* custom_op, int version) { + if (absl::string_view(custom_op) == "Sinh" && version == 1) { + static TfLiteRegistrationExternal* registration = []() { + TfLiteRegistrationExternal* reg = + TfLiteRegistrationExternalCreate(kTfLiteBuiltinCustom, "Sinh", 1); + TfLiteRegistrationExternalSetPrepare(reg, &SinhPrepareOpaque); + TfLiteRegistrationExternalSetInvoke(reg, &SinhEvalOpaque); + return reg; + }(); + return registration; + } + return nullptr; +} + +const TfLiteRegistration* SinhFindCustomOp(void*, const char* custom_op, + int version) { + if (absl::string_view(custom_op) == "Sinh" && version == 1) { + static const TfLiteRegistration registration{/*init=*/nullptr, + /*free=*/nullptr, + /*prepare=*/SinhPrepare, + /*invoke=*/SinhEval}; + return ®istration; + } + return nullptr; +} + +// Test using TfLiteInterpreterOptionsSetOpResolverExternal and +// TfLiteInterpreterCreateWithSelectedOps. TEST(CApiExperimentalTest, SetOpResolverExternal) { TfLiteModel* model = TfLiteModelCreateFromFile( tensorflow::GetDataDependencyFilepath("tensorflow/lite/testdata/add.bin") @@ -233,6 +299,171 @@ TEST(CApiExperimentalTest, SetOpResolverExternal) { TfLiteModelDelete(model); } +// Test using TfLiteInterpreterOptionsSetOpResolverExternalWithFallback and +// TfLiteInterpreterCreateWithSelectedOps, for a builtin op, for the normal +// case where the op is found with the primary op resolver callback that returns +// a TfLiteRegistrationExternal pointer. +TEST(CApiExperimentalTest, + SetOpResolverExternalWithFallback_BuiltinOp_NormalCase) { + TfLiteModel* model = TfLiteModelCreateFromFile( + tensorflow::GetDataDependencyFilepath("tensorflow/lite/testdata/add.bin") + .c_str()); + ASSERT_NE(model, nullptr); + + TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate(); + + OpResolverData my_data; + TfLiteInterpreterOptionsSetOpResolverExternalWithFallback( + options, MyFindBuiltinOpExternal, MyFindCustomOpExternal, + [](void* user_data, TfLiteBuiltinOperator op, + int version) -> const TfLiteRegistration* { return nullptr; }, + [](void* user_data, const char* custom_op, + int version) -> const TfLiteRegistration* { return nullptr; }, + &my_data); + EXPECT_FALSE(my_data.called_for_add); + + TfLiteInterpreter* interpreter = + TfLiteInterpreterCreateWithSelectedOps(model, options); + ASSERT_NE(interpreter, nullptr); + ASSERT_EQ(TfLiteInterpreterAllocateTensors(interpreter), kTfLiteOk); + EXPECT_EQ(TfLiteInterpreterResetVariableTensors(interpreter), kTfLiteOk); + EXPECT_EQ(TfLiteInterpreterInvoke(interpreter), kTfLiteOk); + EXPECT_TRUE(my_data.called_for_add); + + TfLiteInterpreterDelete(interpreter); + TfLiteInterpreterOptionsDelete(options); + TfLiteModelDelete(model); +} + +// Test using TfLiteInterpreterOptionsSetOpResolverExternalWithFallback and +// TfLiteInterpreterCreateWithSelectedOps, for a builtin op, for the fallback +// case where the op is found with the secondary op resolver callback that +// returns a TfLiteRegistration pointer. +TEST(CApiExperimentalTest, + SetOpResolverExternalWithFallback_BuiltinOp_FallbackCase) { + TfLiteModel* model = TfLiteModelCreateFromFile( + tensorflow::GetDataDependencyFilepath("tensorflow/lite/testdata/add.bin") + .c_str()); + ASSERT_NE(model, nullptr); + + TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate(); + + OpResolverData my_data; + TfLiteInterpreterOptionsSetOpResolverExternalWithFallback( + options, + [](void* user_data, int op, + int version) -> const TfLiteRegistrationExternal* { return nullptr; }, + [](void* user_data, const char* custom_op, + int version) -> const TfLiteRegistrationExternal* { return nullptr; }, + MyFindBuiltinOp, MyFindCustomOp, &my_data); + EXPECT_FALSE(my_data.called_for_add); + + TfLiteInterpreter* interpreter = + TfLiteInterpreterCreateWithSelectedOps(model, options); + ASSERT_NE(interpreter, nullptr); + ASSERT_EQ(TfLiteInterpreterAllocateTensors(interpreter), kTfLiteOk); + EXPECT_EQ(TfLiteInterpreterResetVariableTensors(interpreter), kTfLiteOk); + EXPECT_EQ(TfLiteInterpreterInvoke(interpreter), kTfLiteOk); + EXPECT_TRUE(my_data.called_for_add); + + TfLiteInterpreterDelete(interpreter); + TfLiteInterpreterOptionsDelete(options); + TfLiteModelDelete(model); +} + +// Test using TfLiteInterpreterOptionsSetOpResolverExternalWithFallback and +// TfLiteInterpreterCreateWithSelectedOps, for a custom op, for the normal +// case where the op is found with the primary op resolver callback that returns +// a TfLiteRegistrationExternal pointer. +TEST(CApiExperimentalTest, + SetOpResolverExternalWithFallback_CustomOp_NormalCase) { + TfLiteModel* model = + TfLiteModelCreateFromFile(tensorflow::GetDataDependencyFilepath( + "tensorflow/lite/testdata/custom_sinh.bin") + .c_str()); + ASSERT_NE(model, nullptr); + + TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate(); + + OpResolverData my_data; + TfLiteInterpreterOptionsSetOpResolverExternalWithFallback( + options, MyFindBuiltinOpExternal, SinhFindCustomOpExternal, + [](void* user_data, TfLiteBuiltinOperator op, + int version) -> const TfLiteRegistration* { return nullptr; }, + [](void* user_data, const char* custom_op, + int version) -> const TfLiteRegistration* { return nullptr; }, + &my_data); + EXPECT_FALSE(my_data.called_for_add); + + TfLiteInterpreter* interpreter = + TfLiteInterpreterCreateWithSelectedOps(model, options); + ASSERT_NE(interpreter, nullptr); + ASSERT_EQ(TfLiteInterpreterAllocateTensors(interpreter), kTfLiteOk); + + TfLiteTensor* input_tensor = TfLiteInterpreterGetInputTensor(interpreter, 0); + const float input_value = 1.0f; + TfLiteTensorCopyFromBuffer(input_tensor, &input_value, sizeof(float)); + + EXPECT_EQ(TfLiteInterpreterInvoke(interpreter), kTfLiteOk); + + const TfLiteTensor* output_tensor = + TfLiteInterpreterGetOutputTensor(interpreter, 0); + float output_value; + TfLiteTensorCopyToBuffer(output_tensor, &output_value, sizeof(float)); + EXPECT_EQ(output_value, std::sinh(input_value)); + + TfLiteInterpreterDelete(interpreter); + TfLiteInterpreterOptionsDelete(options); + TfLiteModelDelete(model); +} + +// Test using TfLiteInterpreterOptionsSetOpResolverExternalWithFallback and +// TfLiteInterpreterCreateWithSelectedOps, for a custom op, for the fallback +// case where the op is found with the secondary op resolver callback that +// returns a TfLiteRegistration pointer. +TEST(CApiExperimentalTest, + SetOpResolverExternalWithFallback_CustomOp_FallbackCase) { + TfLiteModel* model = + TfLiteModelCreateFromFile(tensorflow::GetDataDependencyFilepath( + "tensorflow/lite/testdata/custom_sinh.bin") + .c_str()); + ASSERT_NE(model, nullptr); + + TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate(); + + OpResolverData my_data; + TfLiteInterpreterOptionsSetOpResolverExternalWithFallback( + options, + [](void* user_data, int op, + int version) -> const TfLiteRegistrationExternal* { return nullptr; }, + [](void* user_data, const char* custom_op, + int version) -> const TfLiteRegistrationExternal* { return nullptr; }, + MyFindBuiltinOp, SinhFindCustomOp, &my_data); + EXPECT_FALSE(my_data.called_for_add); + + TfLiteInterpreter* interpreter = + TfLiteInterpreterCreateWithSelectedOps(model, options); + ASSERT_NE(interpreter, nullptr); + ASSERT_EQ(TfLiteInterpreterAllocateTensors(interpreter), kTfLiteOk); + + TfLiteTensor* input_tensor = TfLiteInterpreterGetInputTensor(interpreter, 0); + const float input_value = 1.0f; + TfLiteTensorCopyFromBuffer(input_tensor, &input_value, sizeof(float)); + + EXPECT_EQ(TfLiteInterpreterInvoke(interpreter), kTfLiteOk); + EXPECT_FALSE(my_data.called_for_add); + + const TfLiteTensor* output_tensor = + TfLiteInterpreterGetOutputTensor(interpreter, 0); + float output_value; + TfLiteTensorCopyToBuffer(output_tensor, &output_value, sizeof(float)); + EXPECT_EQ(output_value, std::sinh(input_value)); + + TfLiteInterpreterDelete(interpreter); + TfLiteInterpreterOptionsDelete(options); + TfLiteModelDelete(model); +} + void AllocateAndSetInputs(TfLiteInterpreter* interpreter) { std::array input_dims = {2}; ASSERT_EQ(TfLiteInterpreterResizeInputTensor( From 2d938b329f49a48190e42ba42c54c51794dac52a Mon Sep 17 00:00:00 2001 From: Mahmoud Abuzaina Date: Wed, 12 Jul 2023 17:54:24 -0700 Subject: [PATCH 233/376] Addressed review comments --- tensorflow/core/kernels/mkl/mkl_conv_ops.cc | 23 ++++++++++----------- 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/tensorflow/core/kernels/mkl/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl/mkl_conv_ops.cc index f13ecbeb6613ed..cd6e42976aa5ff 100644 --- a/tensorflow/core/kernels/mkl/mkl_conv_ops.cc +++ b/tensorflow/core/kernels/mkl/mkl_conv_ops.cc @@ -532,11 +532,10 @@ class MklConvFwdPrimitive : public MklPrimitive { #ifdef ENABLE_ONEDNN_V3 if (is_scale_set["src"] && is_scale_set["wei"] && is_scale_set["dst"]) { net_args.insert( - {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, *context_.src_scale_mem}); - net_args.insert( - {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, *context_.wei_scale_mem}); - net_args.insert( - {DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, *context_.dst_scale_mem}); + {{DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, *context_.src_scale_mem}, + {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, *context_.wei_scale_mem}, + { DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, + *context_.dst_scale_mem }}); } #endif // ENABLE_ONEDNN_V3 } else if (!convFwdDims.fuse_bn_dims.empty()) { @@ -569,11 +568,10 @@ class MklConvFwdPrimitive : public MklPrimitive { #ifdef ENABLE_ONEDNN_V3 if (is_scale_set["src"] && is_scale_set["wei"] && is_scale_set["dst"]) { net_args.insert( - {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, *context_.src_scale_mem}); - net_args.insert( - {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, *context_.wei_scale_mem}); - net_args.insert( - {DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, *context_.dst_scale_mem}); + {{DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, *context_.src_scale_mem}, + {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, *context_.wei_scale_mem}, + { DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, + *context_.dst_scale_mem }}); } #endif // ENABLE_ONEDNN_V3 } @@ -2528,10 +2526,10 @@ class MklQuantizedConvOp const float* max_filter = max_filter_vector.flat().data(); const float int_const_scale_limit = (std::is_same::value) ? 255.0 * 127.0 : 127.0 * 127.0; + // Re-scale bias if either of following 2 conditions are met: // 1. Bias is not const; // 2. Bias is const, bias has not been cached (first iteration). - size_t depth = min_filter_vector.NumElements(); bool scales_are_valid = (depth == scales_.size()); scales_.resize(depth); @@ -2564,9 +2562,10 @@ class MklQuantizedConvOp input_bias_->set_data_handle(bias_buf); } - if (!scaled_bias_buf_) + if (!scaled_bias_buf_) { AllocTmpBuffer(context, &scaled_bias_tensor_, conv_fwd_pd->bias_desc(), &scaled_bias_buf_); + } if (!scaled_bias_) { scaled_bias_ = new memory(conv_fwd_pd->bias_desc(), this->cpu_engine_, scaled_bias_buf_); From 177aa6482409ab588b7cded97759c224e68d3721 Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Wed, 12 Jul 2023 17:59:22 -0700 Subject: [PATCH 234/376] Delete `tsl_gpu_cc_test` and replace all users with `xla_cc_test`. `tsl_gpu_cc_test` is intended to be used when the test can run under CPU, GPU, and 2 GPU configurations, but all users only wanted to run the test under the single GPU config. So, there were no users of the unique functionality, so this should be an NFC. PiperOrigin-RevId: 547651026 --- .../compiler/xla/backends/profiler/gpu/BUILD | 10 +- .../compiler/xla/stream_executor/cuda/BUILD | 31 +++--- tensorflow/compiler/xla/xla.bzl | 1 + tensorflow/tensorflow.bzl | 2 +- tensorflow/tsl/tsl.default.bzl | 100 ------------------ 5 files changed, 22 insertions(+), 122 deletions(-) diff --git a/tensorflow/compiler/xla/backends/profiler/gpu/BUILD b/tensorflow/compiler/xla/backends/profiler/gpu/BUILD index a44c5735b5b8f9..d166b63e3fd9c1 100644 --- a/tensorflow/compiler/xla/backends/profiler/gpu/BUILD +++ b/tensorflow/compiler/xla/backends/profiler/gpu/BUILD @@ -5,7 +5,6 @@ load( "tsl_copts", "tsl_gpu_library", ) -load("//tensorflow/tsl:tsl.default.bzl", "tsl_gpu_cc_test") load( "//tensorflow/tsl/platform:build_config.bzl", "tf_additional_device_tracer_srcs", @@ -23,6 +22,10 @@ load( "//tensorflow/tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured", ) +load( + "//tensorflow/compiler/xla:xla.bzl", + "xla_cc_test", +) package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -107,10 +110,11 @@ tsl_gpu_library( ], ) -tsl_gpu_cc_test( +xla_cc_test( name = "cupti_error_manager_test", size = "small", srcs = ["cupti_error_manager_test.cc"], + copts = tf_profiler_copts() + tsl_copts(), tags = tf_cuda_tests_tags() + [ "gpu_cupti", "nomac", @@ -125,9 +129,7 @@ tsl_gpu_cc_test( ":cupti_wrapper", ":mock_cupti", "@com_google_absl//absl/memory", - "//tensorflow/tsl/platform:env_impl", "//tensorflow/tsl/profiler/utils:time_utils", - "//tensorflow/tsl/profiler/backends/cpu:annotation_stack_impl", ]), ) diff --git a/tensorflow/compiler/xla/stream_executor/cuda/BUILD b/tensorflow/compiler/xla/stream_executor/cuda/BUILD index 1eba55bea9d9d7..f3ad4282a83848 100644 --- a/tensorflow/compiler/xla/stream_executor/cuda/BUILD +++ b/tensorflow/compiler/xla/stream_executor/cuda/BUILD @@ -2,7 +2,6 @@ # CUDA-platform specific StreamExecutor support code. load("//tensorflow/tsl:tsl.bzl", "if_google", "set_external_visibility", "tsl_copts") -load("//tensorflow/tsl:tsl.default.bzl", "tsl_gpu_cc_test") load( "//tensorflow/compiler/xla/stream_executor:build_defs.bzl", "stream_executor_friends", @@ -25,6 +24,10 @@ load( "//tensorflow/tsl/platform:rules_cc.bzl", "cc_library", ) +load( + "//tensorflow/compiler/xla:xla.bzl", + "xla_cc_test", +) package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -127,49 +130,49 @@ cc_library( ], ) -tsl_gpu_cc_test( +xla_cc_test( name = "stream_search_test", size = "small", srcs = ["stream_search_test.cc"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), tags = tf_cuda_tests_tags(), deps = [ + ":cuda_platform", "//tensorflow/compiler/xla/stream_executor", - "//tensorflow/compiler/xla/stream_executor:stream_executor_impl", "//tensorflow/compiler/xla/stream_executor/host:host_platform", - "//tensorflow/tsl/platform:env_impl", "//tensorflow/tsl/platform:test", "//tensorflow/tsl/platform:test_main", ], ) -tsl_gpu_cc_test( +xla_cc_test( name = "cuda_driver_test", srcs = ["cuda_driver_test.cc"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), tags = tf_cuda_tests_tags() + [ "no_cuda_asan", # TODO(b/171512140): re-enable. "no_rocm", ], deps = [ ":cuda_driver", - "//tensorflow/tsl/platform:env_impl", "//tensorflow/tsl/platform:test", "//tensorflow/tsl/platform:test_main", "@local_config_cuda//cuda:cuda_headers", ], ) -tsl_gpu_cc_test( +xla_cc_test( name = "memcpy_test", srcs = ["memcpy_test.cc"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), tags = tf_cuda_tests_tags() + [ "no_cuda_asan", # TODO(b/171512140): re-enable. ], deps = [ + ":cuda_platform", "//tensorflow/compiler/xla/stream_executor", "//tensorflow/compiler/xla/stream_executor:device_memory", "//tensorflow/compiler/xla/stream_executor:multi_platform_manager", - "//tensorflow/compiler/xla/stream_executor:stream_executor_impl", - "//tensorflow/tsl/platform:env_impl", "//tensorflow/tsl/platform:test", "//tensorflow/tsl/platform:test_main", ], @@ -606,30 +609,24 @@ cc_library( ), ) -tsl_gpu_cc_test( +xla_cc_test( name = "redzone_allocator_test", srcs = ["redzone_allocator_test.cc"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), tags = tf_cuda_tests_tags() + [ "no_cuda_asan", # TODO(b/171512140): re-enable. ], deps = [ ":cuda_activation", ":cuda_gpu_executor", - ":stream_executor_cuda", "//tensorflow/compiler/xla/stream_executor", "//tensorflow/compiler/xla/stream_executor:device_memory_allocator", "//tensorflow/compiler/xla/stream_executor:event", "//tensorflow/compiler/xla/stream_executor:kernel", - "//tensorflow/compiler/xla/stream_executor:stream_executor_impl", "//tensorflow/compiler/xla/stream_executor/gpu:gpu_asm_opts", "//tensorflow/compiler/xla/stream_executor/gpu:redzone_allocator", - "//tensorflow/tsl/framework:allocator", - "//tensorflow/tsl/framework:allocator_registry_impl", "//tensorflow/tsl/lib/core:status_test_util", - "//tensorflow/tsl/platform:env_impl", "//tensorflow/tsl/platform:test", "//tensorflow/tsl/platform:test_main", - "//tensorflow/tsl/profiler/backends/cpu:traceme_recorder_impl", - "//tensorflow/tsl/profiler/utils:time_utils_impl", ], ) diff --git a/tensorflow/compiler/xla/xla.bzl b/tensorflow/compiler/xla/xla.bzl index 445c7d4ed46138..603469afa22efb 100644 --- a/tensorflow/compiler/xla/xla.bzl +++ b/tensorflow/compiler/xla/xla.bzl @@ -100,6 +100,7 @@ def xla_cc_test( clean_dep("//tensorflow/tsl/profiler/utils:time_utils_impl"), clean_dep("//tensorflow/tsl/profiler/backends/cpu:annotation_stack_impl"), clean_dep("//tensorflow/tsl/profiler/backends/cpu:traceme_recorder_impl"), + clean_dep("//tensorflow/tsl/profiler/protobuf:xplane_proto_cc_impl"), clean_dep("//tensorflow/compiler/xla:autotuning_proto_cc_impl"), clean_dep("//tensorflow/tsl/protobuf:dnn_proto_cc_impl"), clean_dep("//tensorflow/tsl/protobuf:protos_all_cc_impl"), diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index af8e82b0adce41..c3885d0bd84d18 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -1578,7 +1578,7 @@ register_extension_info( label_regex_for_dep = "{extension_name}", ) -# TODO(jakeharmon): Replace with or implement in terms of tsl_gpu_cc_test, which doesn't add a +# TODO(jakeharmon): Replace with an implementation which doesn't add a # dependency on core:common_runtime def tf_gpu_cc_test( name, diff --git a/tensorflow/tsl/tsl.default.bzl b/tensorflow/tsl/tsl.default.bzl index 1d339f95d1a6c8..b4f01e6aec0efe 100644 --- a/tensorflow/tsl/tsl.default.bzl +++ b/tensorflow/tsl/tsl.default.bzl @@ -2,8 +2,6 @@ load( "//tensorflow/tsl:tsl.bzl", - "clean_dep", - "two_gpu_tags", _filegroup = "filegroup", _get_compatible_with_portable = "get_compatible_with_portable", _if_not_mobile_or_arm_or_lgpl_restricted = "if_not_mobile_or_arm_or_lgpl_restricted", @@ -11,18 +9,6 @@ load( _tsl_grpc_cc_dependencies = "tsl_grpc_cc_dependencies", _tsl_pybind_extension = "tsl_pybind_extension", ) -load( - "//tensorflow/tsl/platform:build_config.bzl", - "tsl_cc_test", -) -load( - "//tensorflow/tsl/platform:build_config_root.bzl", - "tf_gpu_tests_tags", -) -load( - "@local_config_cuda//cuda:build_defs.bzl", - "if_cuda", -) get_compatible_with_portable = _get_compatible_with_portable filegroup = _filegroup @@ -30,89 +16,3 @@ if_not_mobile_or_arm_or_lgpl_restricted = _if_not_mobile_or_arm_or_lgpl_restrict internal_hlo_deps = _internal_hlo_deps tsl_grpc_cc_dependencies = _tsl_grpc_cc_dependencies tsl_pybind_extension = _tsl_pybind_extension - -def tsl_gpu_cc_test( - name, - srcs = [], - deps = [], - tags = [], - data = [], - size = "medium", - linkstatic = 0, - args = [], - linkopts = [], - **kwargs): - """Create tests for cpu, gpu and optionally 2gpu - - Args: - name: unique name for this test target. - srcs: list of C and C++ files that are processed to create the binary target. - deps: list of other libraries to be linked in to the binary target. - tags: useful for categorizing the tests - data: files needed by this rule at runtime. - size: classification of how much time/resources the test requires. - linkstatic: link the binary in static mode. - args: command line arguments that Bazel passes to the target. - linkopts: add these flags to the C++ linker command. - **kwargs: Extra arguments to the rule. - """ - targets = [] - tsl_cc_test( - name = name + "_cpu", - size = size, - srcs = srcs, - args = args, - data = data, - copts = if_cuda(["-DNV_CUDNN_DISABLE_EXCEPTION"]), - linkopts = linkopts, - linkstatic = linkstatic, - tags = tags, - deps = deps, - **kwargs - ) - targets.append(name + "_cpu") - tsl_cc_test( - name = name + "_gpu", - size = size, - srcs = srcs, - args = args, - data = data, - copts = if_cuda(["-DNV_CUDNN_DISABLE_EXCEPTION"]), - linkopts = linkopts, - linkstatic = select({ - # TODO(allenl): Remove Mac static linking when Bazel 0.6 is out. - clean_dep("//tensorflow/tsl:macos"): 1, - "@local_config_cuda//cuda:using_nvcc": 1, - "@local_config_cuda//cuda:using_clang": 1, - "//conditions:default": 0, - }), - tags = tags + tf_gpu_tests_tags(), - deps = deps, - **kwargs - ) - targets.append(name + "_gpu") - if "multi_gpu" in tags or "multi_and_single_gpu" in tags: - cleaned_tags = tags + two_gpu_tags - if "requires-gpu-nvidia" in cleaned_tags: - cleaned_tags.remove("requires-gpu-nvidia") - tsl_cc_test( - name = name + "_2gpu", - size = size, - srcs = srcs, - args = args, - data = data, - linkopts = linkopts, - linkstatic = select({ - # TODO(allenl): Remove Mac static linking when Bazel 0.6 is out. - clean_dep("//tensorflow/tsl:macos"): 1, - "@local_config_cuda//cuda:using_nvcc": 1, - "@local_config_cuda//cuda:using_clang": 1, - "//conditions:default": 0, - }), - tags = cleaned_tags, - deps = deps, - **kwargs - ) - targets.append(name + "_2gpu") - - native.test_suite(name = name, tests = targets, tags = tags) From a310a8601214c3fc5058a85bca775c09d34e5873 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 12 Jul 2023 18:04:36 -0700 Subject: [PATCH 235/376] Remove use of jitrt, as it is no longer used. PiperOrigin-RevId: 547652015 --- tensorflow/compiler/mlir/tfrt/BUILD | 39 -- .../tests/jit/tf_jitrt_codegen_transpose.mlir | 130 ------ .../tfrt/tests/jit/tf_jitrt_pipeline.mlir | 440 ------------------ .../jit/tf_jitrt_pipeline_vectorized.mlir | 75 --- .../tf_to_corert/outline-cpurt-cluster.mlir | 57 --- .../compiler/mlir/tfrt/transforms/passes.cc | 5 +- .../mlir/tfrt/transforms/tf_to_tfrt.cc | 8 - .../mlir/tfrt/transforms/tfrt_jitrt_passes.cc | 407 ---------------- .../mlir/tfrt/transforms/tfrt_jitrt_stub.cc | 76 --- .../mlir/tfrt/transforms/tfrt_jitrt_stub.h | 71 --- tensorflow/core/runtime_fallback/BUILD | 1 - tensorflow/core/runtime_fallback/util/BUILD | 1 - .../util/fallback_test_util.cc | 2 - tensorflow/core/tfrt/graph_executor/BUILD | 1 - .../tfrt/graph_executor/graph_executor.cc | 3 - tensorflow/core/tfrt/saved_model/BUILD | 2 - 16 files changed, 1 insertion(+), 1317 deletions(-) delete mode 100644 tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_codegen_transpose.mlir delete mode 100644 tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_pipeline.mlir delete mode 100644 tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_pipeline_vectorized.mlir delete mode 100644 tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/outline-cpurt-cluster.mlir delete mode 100644 tensorflow/compiler/mlir/tfrt/transforms/tfrt_jitrt_passes.cc delete mode 100644 tensorflow/compiler/mlir/tfrt/transforms/tfrt_jitrt_stub.cc delete mode 100644 tensorflow/compiler/mlir/tfrt/transforms/tfrt_jitrt_stub.h diff --git a/tensorflow/compiler/mlir/tfrt/BUILD b/tensorflow/compiler/mlir/tfrt/BUILD index dabc0ece9e6303..db856a8e92f011 100644 --- a/tensorflow/compiler/mlir/tfrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/BUILD @@ -411,7 +411,6 @@ cc_library( ":cost_analysis", ":fallback_converter", ":tensor_array_side_effect_analysis", - ":tfrt_jitrt_stub", ":tfrt_pipeline_options", ":tpu_passes", ":transform_utils", @@ -640,8 +639,6 @@ cc_library( ], deps = [ "//tensorflow/compiler/mlir/tfrt:tf_to_tfrt", - "//tensorflow/compiler/mlir/tfrt/jit/transforms:tf_jitrt_passes", - "//tensorflow/compiler/mlir/tfrt/jit/transforms:tf_jitrt_test_passes", ], ) @@ -668,7 +665,6 @@ cc_library( ":test_tensor_array_side_effect_analysis", ":tf_jitrt_opdefs", ":tf_to_tfrt", - ":tfrt_jitrt_passes", ":transforms/gpu_passes", "//tensorflow/compiler/mlir:init_mlir", "//tensorflow/compiler/mlir:passes", @@ -750,7 +746,6 @@ tf_cc_binary( testonly = True, visibility = [":friends"], deps = [ - ":tf_jitrt_kernels_alwayslink", "@tf_runtime//:dtype", "@tf_runtime//:simple_tracing_sink", "@tf_runtime//tools:bef_executor_expensive_kernels", @@ -863,40 +858,6 @@ cc_library( ], ) -cc_library( - name = "tfrt_jitrt_passes", - srcs = ["transforms/tfrt_jitrt_passes.cc"], - deps = [ - ":fallback_converter", - ":tf_jitrt_opdefs", - ":tf_jitrt_pipeline", - ":tfrt_jitrt_stub", - "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", - "//tensorflow/compiler/mlir/tfrt/ir:tfrt_fallback_async_opdefs", - "//tensorflow/compiler/mlir/tfrt/jit/transforms:tf_jitrt_clustering", - "//tensorflow/compiler/mlir/tfrt/jit/transforms:tf_jitrt_passes", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:TransformUtils", - "@tf_runtime//:basic_kernels_opdefs", - "@tf_runtime//backends/jitrt:jitrt_opdefs", - ], - alwayslink = 1, -) - -cc_library( - name = "tfrt_jitrt_stub", - srcs = ["transforms/tfrt_jitrt_stub.cc"], - hdrs = ["transforms/tfrt_jitrt_stub.h"], - deps = [ - ":corert_converter", - ":tfrt_pipeline_options", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:TransformUtils", - ], -) - cc_library( name = "constants", hdrs = ["constants.h"], diff --git a/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_codegen_transpose.mlir b/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_codegen_transpose.mlir deleted file mode 100644 index 855b5f6e6d5489..00000000000000 --- a/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_codegen_transpose.mlir +++ /dev/null @@ -1,130 +0,0 @@ -// RUN: tf-tfrt-opt -tf-jitrt-pipeline="vectorize" -split-input-file %s | FileCheck %s - -func.func @transpose_2d(%arg0: tensor) -> tensor { - %0 = "tf.Const"() - {value = dense<[1, 0]> : tensor<2xi64>, - device = "/job:localhost/replica:0/task:0/device:CPU:0"} - : () -> tensor<2xi64> - %1 = "tf.Transpose"(%arg0, %0) - {device = "/job:localhost/replica:0/task:0/device:CPU:0"} - : (tensor, tensor<2xi64>) -> tensor - func.return %1 : tensor -} - -// CHECK-LABEL: func @transpose_2d -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index -// 8x8 tiling. -// CHECK: scf.for {{.*}} = %[[C0]] to {{.*}} step %[[C8]] { -// CHECK: scf.for {{.*}} = %[[C0]] to {{.*}} step %[[C8]] { -// Vector xfer reads: unrolled second vector dimension. -// CHECK-COUNT-8: vector.transfer_read -// AVX2 shuffle/asm sequence. -// CHECK-COUNT-12: vector.shuffle -// CHECK-COUNT-8: llvm.inline_asm -// CHECK-COUNT-8: vector.shuffle -// Vector xfer writes: unrolled second vector dimension. - -// ----- - -func.func @transpose_3d_021(%arg0: tensor) -> tensor { - %0 = "tf.Const"() { value = dense<[0, 2, 1]> : tensor<3xi64> } - : () -> tensor<3xi64> - %1 = "tf.Transpose"(%arg0, %0) - : (tensor, tensor<3xi64>) -> tensor - func.return %1 : tensor -} - -// CHECK-LABEL: func @transpose_3d -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// 1x8x8 tiling. -// CHECK: scf.for {{.*}} = %[[C0]] to {{.*}} step %[[C1]] { -// CHECK: scf.for {{.*}} = %[[C0]] to {{.*}} step %[[C8]] { -// CHECK: scf.for {{.*}} = %[[C0]] to {{.*}} step %[[C8]] { -// Vector xfer reads: unrolled second vector dimension. -// CHECK-COUNT-8: vector.transfer_read -// AVX2 shuffle/asm sequence. -// CHECK-COUNT-12: vector.shuffle -// CHECK-COUNT-8: llvm.inline_asm -// CHECK-COUNT-8: vector.shuffle -// Vector xfer writes: unrolled second vector dimension. - -// ----- - -func.func @transpose_3d_201(%arg0: tensor) -> tensor { - %0 = "tf.Const"() { value = dense<[2, 0, 1]> : tensor<3xi64> } - : () -> tensor<3xi64> - %1 = "tf.Transpose"(%arg0, %0) - : (tensor, tensor<3xi64>) -> tensor - func.return %1 : tensor -} - -// CHECK-LABEL: func @transpose_3d_201 -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// 8x1x8 tiling. -// CHECK: scf.for {{.*}} = %[[C0]] to {{.*}} step %[[C1]] { -// CHECK: scf.for {{.*}} = %[[C0]] to {{.*}} step %[[C8]] { -// CHECK: scf.for {{.*}} = %[[C0]] to {{.*}} step %[[C8]] { -// Vector xfer reads: unrolled second vector dimension. -// CHECK-COUNT-8: vector.transfer_read -// AVX2 shuffle/asm sequence. -// CHECK-COUNT-12: vector.shuffle -// CHECK-COUNT-8: llvm.inline_asm -// CHECK-COUNT-8: vector.shuffle -// Vector xfer writes: unrolled second vector dimension. - -// ----- - -func.func @transpose_3d_210(%arg0: tensor) -> tensor { - %0 = "tf.Const"() { value = dense<[2, 1, 0]> : tensor<3xi64> } - : () -> tensor<3xi64> - %1 = "tf.Transpose"(%arg0, %0) - : (tensor, tensor<3xi64>) -> tensor - func.return %1 : tensor -} - -// CHECK-LABEL: func @transpose_3d_210 -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// 8x1x8 tiling. -// CHECK: scf.for {{.*}} = %[[C0]] to {{.*}} step %[[C8]] { -// CHECK: scf.for {{.*}} = %[[C0]] to {{.*}} step %[[C1]] { -// CHECK: scf.for {{.*}} = %[[C0]] to {{.*}} step %[[C8]] { -// Vector xfer reads: unrolled second vector dimension. -// CHECK-COUNT-8: vector.transfer_read -// AVX2 shuffle/asm sequence. -// CHECK-COUNT-12: vector.shuffle -// CHECK-COUNT-8: llvm.inline_asm -// CHECK-COUNT-8: vector.shuffle -// Vector xfer writes: unrolled second vector dimension. - -// ----- - -func.func @transpose_3d_120(%arg0: tensor) -> tensor { - %0 = "tf.Const"() { value = dense<[1, 2, 0]> : tensor<3xi64> } - : () -> tensor<3xi64> - %1 = "tf.Transpose"(%arg0, %0) - : (tensor, tensor<3xi64>) -> tensor - func.return %1 : tensor -} - -// CHECK-LABEL: func @transpose_3d_120 -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// 1x8x8 tiling. -// CHECK: scf.for {{.*}} = %[[C0]] to {{.*}} step %[[C8]] { -// CHECK: scf.for {{.*}} = %[[C0]] to {{.*}} step %[[C1]] { -// CHECK: scf.for {{.*}} = %[[C0]] to {{.*}} step %[[C8]] { -// Vector xfer reads: unrolled second vector dimension. -// CHECK-COUNT-8: vector.transfer_read -// AVX2 shuffle/asm sequence. -// CHECK-COUNT-12: vector.shuffle -// CHECK-COUNT-8: llvm.inline_asm -// CHECK-COUNT-8: vector.shuffle -// Vector xfer writes: unrolled second vector dimension. diff --git a/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_pipeline.mlir b/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_pipeline.mlir deleted file mode 100644 index a01636a4d9dfe2..00000000000000 --- a/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_pipeline.mlir +++ /dev/null @@ -1,440 +0,0 @@ -// RUN: tf-tfrt-opt -split-input-file -tf-jitrt-pipeline %s | FileCheck %s - -// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)> - -// CHECK-LABEL: @tanh_lower_and_fuse -// CHECK-SAME: %[[ARG:.*]]: memref -func.func @tanh_lower_and_fuse(%arg0: tensor) -> tensor { - // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: %[[DIM:.*]] = memref.dim %[[ARG]], %[[C0]] - // CHECK: %[[MEMREF:.*]] = memref.alloc(%[[DIM]]) {{.*}} : memref - - // CHECK: linalg.generic - // CHECK-SAME: indexing_maps = [#map, #map] - // CHECK-SAME: iterator_types = ["parallel", "parallel"] - // CHECK-SAME: ins(%[[ARG]] : memref) - // CHECK-SAME: outs(%[[MEMREF]] : memref) - // CHECK: tanh - // CHECK-NEXT: tanh - - // CHECK: return %[[MEMREF]] - %0 = "tf.Tanh"(%arg0): (tensor) -> tensor - %1 = "tf.Tanh"(%0): (tensor) -> tensor - func.return %1 : tensor -} - -// ----- - -// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)> - -// CHECK-LABEL: @sigmoid_dynamic_dim -func.func @sigmoid_dynamic_dim(%arg0: tensor) -> tensor { - // CHECK: linalg.generic - // CHECK-SAME: indexing_maps = [#map, #map] - // CHECK-SAME: iterator_types = ["parallel", "parallel"] - %0 = "tf.Sigmoid"(%arg0) : (tensor) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK-DAG: #map{{[0-9]*}} = affine_map<(d0) -> ()> -// CHECK-DAG: #map{{[0-9]*}} = affine_map<(d0) -> (d0)> - -// CHECK-LABEL: @add_scalar_with_vec -func.func @add_scalar_with_vec(%arg0: tensor, - %arg1: tensor) -> tensor { - // CHECK: linalg.generic - // CHECK-NOT: linalg.generic - %0 = "tf.AddV2"(%arg0, %arg1): (tensor, tensor) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK: #map = affine_map<(d0) -> (d0)> - -// CHECK-LABEL: @add_vec_vec -func.func @add_vec_vec( - %arg0: tensor {rt.symbolic_shape = dense<-2>: tensor<1xi64>}, - %arg1: tensor {rt.symbolic_shape = dense<-2>: tensor<1xi64>} -) -> tensor { - // CHECK-NOT: memref.reinterpret_cast - // CHECK: linalg.generic - // CHECK-NOT: linalg.generic - %0 = "tf.AddV2"(%arg0, %arg1): (tensor, tensor) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK: #map = affine_map<(d0) -> (d0)> - -// CHECK-LABEL: @add_vec_vec_vec -func.func @add_vec_vec_vec( - %arg0: tensor {rt.symbolic_shape = dense<-2>: tensor<1xi64>}, - %arg1: tensor {rt.symbolic_shape = dense<-2>: tensor<1xi64>}, - %arg2: tensor {rt.symbolic_shape = dense<-2>: tensor<1xi64>} -) -> tensor { - // CHECK-NOT: memref.reinterpret_cast - // CHECK: linalg.generic - // CHECK: addf - // CHECK: addf - // CHECK-NOT: linalg.generic - %0 = "tf.AddV2"(%arg0, %arg1): (tensor, tensor) -> tensor - %1 = "tf.AddV2"(%0, %arg2): (tensor, tensor) -> tensor - func.return %1 : tensor -} - -// ----- - -// Verify that symbolic shape optimization can move all the broadcasts up, and -// progressively remove all shape constraints and replace mhlo broadcasts with -// linalg.generic operations that in the end all are fused together. - -// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, 0)> -// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d2)> -// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> - -// CHECK: compute_with_bcast -func.func @compute_with_bcast( - %arg0: tensor<1x?x1xf32> - {rt.symbolic_shape = dense<[1, -2, 1]> : tensor<3xi64>}, - %arg1: tensor<512xf32>, - %arg2: tensor<1x?x512xf32> - {rt.symbolic_shape = dense<[1, -2, 512]> : tensor<3xi64>}, - %arg3: tensor<1x?x1xf32> - {rt.symbolic_shape = dense<[1, -2, 1]> : tensor<3xi64>}, - %arg4: tensor<512xf32> -) -> tensor { - // CHECK-NOT: memref.reinterpret_cast - // CHECK: linalg.generic - // CHECK: addf - // CHECK-NEXT: math.rsqrt - // CHECK-NEXT: mulf - // CHECK-NEXT: mulf - // CHECK-NEXT: subf - // CHECK-NEXT: mulf - // CHECK-NEXT: addf - // CHECK-NEXT: linalg.yield - // CHECK-NOT: linalg.generic - %c = "tf.Const"() {value = dense<9.99999996E-13> - : tensor} : () -> tensor - %0 = "tf.AddV2"(%arg0, %c) - : (tensor<1x?x1xf32>, tensor) -> tensor - %1 = "tf.Rsqrt"(%0) - : (tensor) -> tensor - %2 = "tf.Mul"(%1, %arg1) - : (tensor, tensor<512xf32>) -> tensor - %3 = "tf.Mul"(%2, %arg2) - : (tensor, tensor<1x?x512xf32>) -> tensor - %4 = "tf.Mul"(%2, %arg3) - : (tensor, tensor<1x?x1xf32>) -> tensor - %5 = "tf.Sub"(%arg4, %4) - : (tensor<512xf32>, tensor) -> tensor - %6 = "tf.AddV2"(%3, %5) - : (tensor, tensor) -> tensor - func.return %6 : tensor -} - -// ----- - -// CHECK: add_vec_vec_vec_vec -func.func @add_vec_vec_vec_vec( - %arg0: tensor {rt.symbolic_shape = dense<-2>: tensor<1xi64>}, - %arg1: tensor {rt.symbolic_shape = dense<-2>: tensor<1xi64>}, - %arg2: tensor {rt.symbolic_shape = dense<-2>: tensor<1xi64>}, - %arg3: tensor {rt.symbolic_shape = dense<-2>: tensor<1xi64>} -) -> tensor { - // CHECK-NOT: memref.reinterpret_cast - // CHECK: linalg.generic - // CHECK: addf - // CHECK: addf - // CHECK: addf - // CHECK-NOT: linalg.generic - %0 = "tf.AddV2"(%arg0, %arg1): (tensor, tensor) -> tensor - %1 = "tf.AddV2"(%0, %arg2): (tensor, tensor) -> tensor - %2 = "tf.AddV2"(%1, %arg3): (tensor, tensor) -> tensor - func.return %2 : tensor -} - -// ----- - -// CHECK: add_vec_tensor_tensor -func.func @add_vec_tensor_tensor( - %arg0: tensor<512xf32>, - %arg1: tensor<1x?x512xf32> - {rt.symbolic_shape = dense<[1, -2, 512]> : tensor<3xi64>}, - %arg2: tensor<1x?x512xf32> - {rt.symbolic_shape = dense<[1, -2, 512]> : tensor<3xi64>} -) -> tensor<1x?x512xf32> { - // CHECK-NOT: memref.reinterpret_cast - // CHECK: linalg.generic - // CHECK: addf - // CHECK: addf - // CHECK-NOT: linalg.generic - %0 = "tf.AddV2"(%arg0, %arg1) - : (tensor<512xf32>, tensor<1x?x512xf32>) -> tensor<1x?x512xf32> - %1 = "tf.AddV2"(%arg2, %0) - : (tensor<1x?x512xf32>, tensor<1x?x512xf32>) -> tensor<1x?x512xf32> - func.return %1 : tensor<1x?x512xf32> -} - -// ----- - -// CHECK-LABEL: @tf_binary_with_bcast -func.func @tf_binary_with_bcast(%arg0: tensor, - %arg1: tensor) -> tensor { - // CHECK-NOT: shape. - // CHECK: %[[LHS:.*]] = memref.reinterpret_cast - // CHECK: %[[RHS:.*]] = memref.reinterpret_cast - // CHECK: linalg.generic {{.*}} ins(%[[LHS]], %[[RHS]] : - // CHECK: mulf - %0 = "tf.Mul"(%arg0, %arg1) - : (tensor, tensor) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: @tf_binary_with_bcast_and_fusion -// CHECK-SAME: %[[ARG0:.*]]: memref, -// CHECK-SAME: %[[ARG1:.*]]: memref<4xf32>, -// CHECK-SAME: %[[ARG2:.*]]: memref<4xf32> -func.func @tf_binary_with_bcast_and_fusion(%arg0: tensor, - %arg1: tensor<4xf32>, - %arg2: tensor<4xf32>) -> tensor { - // CHECK: linalg.generic - // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG2]] : {{.*}}) - // CHECK: math.log1p - // CHECK-NEXT: subf - // CHECK-NEXT: mulf - // CHECK-NEXT: linalg.yield - // CHECK-NOT: linalg.generic - %0 = "tf.Log1p"(%arg0) - : (tensor) -> tensor - %1 = "tf.Sub"(%0, %arg1) - : (tensor, tensor<4xf32>) -> tensor - %2 = "tf.Mul"(%1, %arg2) - : (tensor, tensor<4xf32>) -> tensor - func.return %2 : tensor -} - -// ----- - -// CHECK: #[[MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> - -// CHECK: tf_binary_with_bcast_symbolic_shapes -func.func @tf_binary_with_bcast_symbolic_shapes( - %arg0: tensor {rt.symbolic_shape = dense<[ -3]>: tensor<1xi64>}, - %arg1: tensor {rt.symbolic_shape = dense<[-2,-3]>: tensor<2xi64>}, - %arg2: tensor {rt.symbolic_shape = dense<[-2,-3]>: tensor<2xi64>}, - %arg3: tensor {rt.symbolic_shape = dense<[-2,-3]>: tensor<2xi64>} -) -> tensor { - // CHECK-NOT: memref.reinterpret_cast - // CHECK: linalg.generic - // CHECK: log1p - // CHECK: addf - // CHECK: addf - // CHECK: addf - // CHECK-NOT: linalg.generic - %0 = "tf.Log1p"(%arg0) - : (tensor) -> tensor - %1 = "tf.AddV2"(%0, %arg1) - : (tensor, tensor) -> tensor - %2 = "tf.AddV2"(%1, %arg2) - : (tensor, tensor) -> tensor - %3 = "tf.AddV2"(%2, %arg3) - : (tensor, tensor) -> tensor - func.return %3 : tensor -} - -// ----- - -// CHECK-LABEL: @cast_sub -func.func @cast_sub(%arg0: tensor, %arg1: tensor) - -> tensor { - // CHECK: linalg.generic - // CHECK-SAME: outs(%[[RESULT_BUF:.*]] : memref) - // CHECK-SAME: { - // CHECK: ^bb0(%[[LHS:.*]]: f16, %[[RHS:.*]]: i16, %{{.*}}: f16): - // CHECK: %[[RHS_CASTED:.*]] = arith.sitofp %[[RHS]] : i16 to f16 - // CHECK: %[[RESULT:.*]] = arith.subf %[[LHS]], %[[RHS_CASTED]] : f16 - // CHECK: linalg.yield %[[RESULT]] : f16 - // CHECK: } - // CHECK: return %[[RESULT_BUF]] : memref - %0 = "tf.Cast"(%arg0) : (tensor) -> tensor - %1 = "tf.Sub"(%arg1, %0) : (tensor, tensor) - -> tensor - func.return %1 : tensor -} - -// ----- - -// CHECK-DAG: #map{{[0-9]*}} = affine_map<(d0, d1) -> (d1, d0)> -// CHECK-DAG: #map{{[0-9]*}} = affine_map<(d0, d1) -> (d0, d1)> - -// CHECK-LABEL: @tf_transpose_const_perm -func.func @tf_transpose_const_perm(%arg0: tensor<2x3xf32>) -> tensor<3x2xf32> { - // CHECK: %[[OUT:.*]] = memref.alloc() {{.*}} : memref<3x2xf32> - // CHECK: linalg.generic {indexing_maps = [#map{{[0-9]*}}, #map{{[0-9]*}}] - // CHECK-SAME: ins(%arg0 : memref<2x3xf32>) - // CHECK-SAME: outs(%[[OUT]] : memref<3x2xf32>) - %0 = "tf.Const"() { value = dense<[1, 0]> : tensor<2xi32> } - : () -> tensor<2xi32> - %1 = "tf.Transpose"(%arg0, %0) - : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32> - func.return %1 : tensor<3x2xf32> -} - -// ----- - -// CHECK-DAG: #map{{[0-9]*}} = affine_map<(d0, d1, d2) -> (d2, d0, d1)> -// CHECK-DAG: #map{{[0-9]*}} = affine_map<(d0, d1, d2) -> (d0, d1, d2)> - -// CHECK-LABEL: @tf_transpose_after_transpose -func.func @tf_transpose_after_transpose(%arg0: tensor) - -> tensor { - // CHECK: %[[OUT:.*]] = memref.alloc - // CHECK: linalg.generic {indexing_maps = [#map{{[0-9]*}}, #map{{[0-9]*}}] - // CHECK-SAME: ins(%arg0 : memref) - // CHECK-SAME: outs(%[[OUT]] : memref) - // CHECK-NOT: linalg.generic - %0 = "tf.Const"() { value = dense<[0, 2, 1]> : tensor<3xi32> } - : () -> tensor<3xi32> - %1 = "tf.Const"() { value = dense<[2, 1, 0]> : tensor<3xi32> } - : () -> tensor<3xi32> - %2 = "tf.Transpose"(%arg0, %0) - : (tensor, tensor<3xi32>) -> tensor - %3 = "tf.Transpose"(%2, %1) - : (tensor, tensor<3xi32>) -> tensor - func.return %3 : tensor -} - -// ----- - -// CHECK-LABEL: @bias_add_and_relu -// CHECK-SAME: %[[ARG0:.*]]: memref -// CHECK-SAME: %[[ARG1:.*]]: memref<32xf32> -func.func @bias_add_and_relu(%arg0: tensor, - %arg1: tensor<32xf32>) -> tensor { - // CHECK: linalg.generic - // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.*}}) - // CHECK: addf - // CHECK: maxf - // CHECK-NEXT: linalg.yield - // CHECK-NOT: linalg.generic - %0 = "tf.BiasAdd"(%arg0, %arg1) - : (tensor, tensor<32xf32>) -> tensor - %1 = "tf.Relu"(%0): (tensor) -> tensor - func.return %1 : tensor -} - -// ----- - -// CHECK-LABEL: @sub_sub -func.func @sub_sub(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { - // CHECK: linalg.generic - // CHECK-SAME: outs(%[[RESULT_BUF:.*]] : memref) - // CHECK: ^bb0(%[[A:.*]]: f16, %[[B:.*]]: f16, %[[C:.*]]: f16, %{{.*}}: f16): - // CHECK: %[[TMP:.*]] = arith.subf %[[B]], %[[C]] - // CHECK: %[[RESULT:.*]] = arith.subf %[[A]], %[[TMP]] - // CHECK: linalg.yield %[[RESULT]] - // CHECK: return %[[RESULT_BUF]] : memref - %0 = "tf.Sub"(%arg0, %arg1) : (tensor, tensor) -> tensor - %1 = "tf.Sub"(%arg2, %0) : (tensor, tensor) -> tensor - func.return %1 : tensor -} - -// ----- - -// CHECK-LABEL: @strided_slice_1d_to_0d -func.func @strided_slice_1d_to_0d(%arg0: tensor<3xi32>) -> tensor { - %cst_0 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - %cst_1 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> - // CHECK: %[[SUBVIEW:.*]] = memref.subview %arg0[0] [1] [1] - // CHECK-SAME: : memref<3xi32> to memref<1xi32, strided<[1]>> - // CHECK: %[[RET:.*]] = memref.collapse_shape %[[SUBVIEW]] - // CHECK: return %[[RET]] - %0 = "tf.StridedSlice"(%arg0, %cst_1, %cst_0, %cst_0) - { - begin_mask = 0 : i64, - ellipsis_mask = 0 : i64, - end_mask = 0 : i64, - new_axis_mask = 0 : i64, - shrink_axis_mask = 1 : i64 - } : (tensor<3xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) - -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK: memref.global "private" constant @__constant_2xi32 : memref<2xi32> = dense<[0, 1]> -// CHECK-SAME: {alignment = 64 : i64} -// CHECK-LABEL: @constant_folding -func.func @constant_folding() -> tensor<2xi32> { - %0 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - // CHECK: %[[CONST:.*]] = memref.get_global @__constant_2xi32 : memref<2xi32> - // CHECK: return %[[CONST]] - %2 = "tf.Pack"(%0, %1) {axis = 0 : i64} - : (tensor, tensor) -> tensor<2xi32> - func.return %2 : tensor<2xi32> -} - -// ----- - -// CHECK-LABEL: @add_floormod_add -func.func @add_floormod_add(%arg0: tensor) -> tensor { - // CHECK: linalg.generic - // CHECK-NOT: linalg.generic - %0 = "tf.AddV2"(%arg0, %arg0) - : (tensor, tensor) -> tensor - %1 = "tf.FloorMod"(%0, %arg0) - : (tensor, tensor) -> tensor - %2 = "tf.AddV2"(%1, %arg0) - : (tensor, tensor) -> tensor - func.return %2 : tensor -} - -// ----- - -// CHECK-LABEL: @min_clip_by_value -func.func @min_clip_by_value(%V__0: tensor) -> tensor { - %dims0 = "tf.Const"() { value = dense<[1, 2]> : tensor<2xi32> }: () -> tensor<2xi32> - %0 = "tf.Min"(%V__0, %dims0) {keep_dims = true} : (tensor, tensor<2xi32>) -> tensor - %1 = "tf.ClipByValue"(%V__0, %0, %V__0) : (tensor, tensor, tensor) -> tensor - func.return %1 : tensor -} - -// ----- - -// CHECK-LABEL: @rint_sq_sub -func.func @rint_sq_sub(%arg0: tensor) -> tensor { - // CHECK: linalg.generic - // CHECK-NOT: linalg.generic - %0 = "tf.Rint"(%arg0) : (tensor) -> tensor - %1 = "tf.Square"(%arg0) : (tensor) -> tensor - %2 = "tf.Sub"(%0, %1) : (tensor, tensor) -> tensor - func.return %2 : tensor -} - -// ----- - -// CHECK-LABEL: @do_not_fuse_if_multiple_uses -func.func @do_not_fuse_if_multiple_uses(%arg0: tensor) - -> (tensor, tensor) { - // CHECK: linalg.generic - // CHECK: math.rsqrt - // CHECK-NEXT: math.rsqrt - // CHECK-NEXT: linalg.yield - %0 = "tf.Rsqrt"(%arg0) : (tensor) -> tensor - %1 = "tf.Rsqrt"(%0) : (tensor) -> tensor - // CHECK: linalg.generic - // CHECK: math.rsqrt - // CHECK-NEXT: linalg.yield - %2 = "tf.Rsqrt"(%1) : (tensor) -> tensor - // CHECK-NOT: linalg.generic - func.return %1, %2 : tensor, tensor -} diff --git a/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_pipeline_vectorized.mlir b/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_pipeline_vectorized.mlir deleted file mode 100644 index 07d2d6a3f08434..00000000000000 --- a/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_pipeline_vectorized.mlir +++ /dev/null @@ -1,75 +0,0 @@ -// RUN: tf-tfrt-opt -tf-jitrt-pipeline="vectorize" \ -// RUN: %s -split-input-file | FileCheck %s - -// CHECK-LABEL: @reduce_row_sum_2d_dynamic -func.func @reduce_row_sum_2d_dynamic(%input: tensor) -> tensor { - %dim_to_reduce = "tf.Const"() {value = dense<[1]> : tensor<1xi32>} - : () -> tensor<1xi32> - %0 = "tf.Sum"(%input, %dim_to_reduce) {keep_dims = false} - : (tensor, tensor<1xi32>) -> tensor - func.return %0 : tensor -} -// CHECK: scf.for -// CHECK: scf.for -// CHECK-COUNT-4: arith.addf %{{.*}}, %{{.*}} : vector<4xf32> - -// ----- - -// CHECK-LABEL: @reduce_column_sum_2d_dynamic -func.func @reduce_column_sum_2d_dynamic(%input: tensor) -> tensor { - %dim_to_reduce = "tf.Const"() {value = dense<[0]> : tensor<1xi32>} - : () -> tensor<1xi32> - %0 = "tf.Sum"(%input, %dim_to_reduce) {keep_dims = false} - : (tensor, tensor<1xi32>) -> tensor - func.return %0 : tensor -} -// CHECK: scf.for -// CHECK: scf.for -// CHECK-COUNT-4: arith.addf %{{.*}}, %{{.*}} : vector<4xf32> - -// ----- - -// CHECK-LABEL: @reduce_row_mean_2d_dynamic -func.func @reduce_row_mean_2d_dynamic(%input: tensor) -> tensor { - %dim_to_reduce = "tf.Const"() {value = dense<[1]> : tensor<1xi32>} - : () -> tensor<1xi32> - %0 = "tf.Mean"(%input, %dim_to_reduce) {keep_dims = false} - : (tensor, tensor<1xi32>) -> tensor - func.return %0 : tensor -} -// CHECK: scf.for -// CHECK: scf.for -// CHECK-COUNT-4: arith.addf %{{.*}}, %{{.*}} : vector<4xf32> -// CHECK: scf.yield -// CHECK: arith.divf %{{.*}}, %{{.*}} : vector<4xf32> - -// ----- - -// CHECK-LABEL: @reduce_1d_dynamic -func.func @reduce_1d_dynamic(%input: tensor) -> tensor { - %dim_to_reduce = "tf.Const"() {value = dense<[0]> : tensor<1xi32>} - : () -> tensor<1xi32> - %0 = "tf.Sum"(%input, %dim_to_reduce) {keep_dims = false} - : (tensor, tensor<1xi32>) -> tensor - func.return %0 : tensor -} -// CHECK: scf.for -// CHECK: arith.addf %{{.*}}, %{{.*}} : vector<8xf32> -// CHECK: vector.reduction - -// ----- - -// CHECK-LABEL: @reduction_of_cast -func.func @reduction_of_cast(%arg0: tensor) -> tensor { - %cst = "tf.Const"() - {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> - %0 = "tf.Cast"(%arg0) {Truncate = false} - : (tensor) -> tensor - %1 = "tf.Prod"(%0, %cst) {keep_dims = false} - : (tensor, tensor<1xi32>) -> tensor - func.return %1 : tensor -} -// CHECK: scf.for -// CHECK: arith.trunci -// CHECK: scf.for -// CHECK: arith.muli diff --git a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/outline-cpurt-cluster.mlir b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/outline-cpurt-cluster.mlir deleted file mode 100644 index 9ade6a0d6f0243..00000000000000 --- a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/outline-cpurt-cluster.mlir +++ /dev/null @@ -1,57 +0,0 @@ -// RUN: tf-tfrt-opt -split-input-file -tf-outline-jitrt-cluster %s \ -// RUN: | FileCheck %s - -// ----- -// Outline a simple cluster with a single operation. - -// CHECK-LABEL: func @simple_cluster -func.func @simple_cluster(%arg0: tensor) -> tensor { - // CHECK: %[[RES:.*]] = jitrt.call(%arg0) - // CHECK-SAME: {callee = @kernel::@compute} - // CHECK-SAME: (tensor) -> tensor - %0 = "tf_device.cluster"() ({ - %1 = "tf.Rsqrt"(%arg0) : (tensor) -> tensor - tf_device.return %1 : tensor - }) { policy = "tfrt.auto-fusion" } : () -> tensor - func.return %0 : tensor -} - -// CHECK: module @kernel attributes { -// CHECK-SAME: tfrt.compiled -// CHECK-SAME: "tfrt.max-arg-size" = 1 : i64 -// CHECK-SAME: } -// CHECK: func @compute( -// CHECK-SAME: %arg0: tensor -// CHECK-SAME: ) -> tensor { -// CHECK: %[[RET:.*]] = "tf.Rsqrt"(%arg0) -// CHECK: return %[[RET]] -// CHECK: } - -// ----- -// Check that tf.Transpose constraint propagated to the function argument. - -// CHECK-LABEL: func @cluster_with_transpose -func.func @cluster_with_transpose(%arg0: tensor, - %arg1: tensor<2xi32>) -> tensor { - // CHECK: %[[RES:.*]] = jitrt.call(%arg0, %arg1) - // CHECK-SAME: {callee = @kernel::@compute} - // CHECK-SAME: (tensor, tensor<2xi32>) -> tensor - %0 = "tf_device.cluster"() ({ - %1 = "tf.Transpose"(%arg0, %arg1) - : (tensor, tensor<2xi32>) -> tensor - tf_device.return %1 : tensor - }) { policy = "tfrt.auto-fusion" } : () -> tensor - func.return %0 : tensor -} - -// CHECK: module @kernel attributes { -// CHECK-SAME: tfrt.compiled -// CHECK-SAME: "tfrt.max-arg-size" = 2 : i64 -// CHECK-SAME: } -// CHECK: func @compute( -// CHECK-SAME: %arg0: tensor -// CHECK-SAME: %arg1: tensor<2xi32> {rt.constraint = "value"} -// CHECK-SAME: ) -> tensor { -// CHECK: %[[RET:.*]] = "tf.Transpose"(%arg0, %arg1) -// CHECK: return %[[RET]] -// CHECK: } diff --git a/tensorflow/compiler/mlir/tfrt/transforms/passes.cc b/tensorflow/compiler/mlir/tfrt/transforms/passes.cc index bf27bac6ffb11c..4375b78fc2497f 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/passes.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/passes.cc @@ -27,7 +27,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_asset_sinking_pass.h" #include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h" #include "tensorflow/compiler/mlir/tfrt/transforms/set_shape_invariant_in_while_ops.h" -#include "tensorflow/compiler/mlir/tfrt/transforms/tfrt_jitrt_stub.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/util/device_name_utils.h" @@ -128,7 +127,7 @@ void CreateTFExecutorToTFPreInvariantOptimizationPipelineHelper( // flow, which is converted back after the optimization passes are performed. pm.addPass(mlir::TF::CreateTFFunctionalControlFlowToRegions()); pm.addPass(mlir::createInlinerPass()); - pm.addNestedPass( + pm.addNestedPass( mlir::TF::CreateRemoveUnusedWhileResultsPass()); pm.addPass(mlir::TF::CreateTFRegionControlFlowToFunctional()); @@ -173,8 +172,6 @@ void CreateTFExecutorToTFPreInvariantOptimizationPipelineHelper( pm.addNestedPass( mlir::TF::CreateTensorDeviceCopyConversionPass()); - AddTfrtJitRtPasses(options, pm); - // Rewriter operation sequences to device specific fusions. DeviceNameUtils::ParsedName parsed_name; diff --git a/tensorflow/compiler/mlir/tfrt/transforms/tf_to_tfrt.cc b/tensorflow/compiler/mlir/tfrt/transforms/tf_to_tfrt.cc index 5cde65d2c65508..f6b0dbc3767c9a 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/tf_to_tfrt.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/tf_to_tfrt.cc @@ -55,7 +55,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/tfrt/transforms/fallback_converter.h" #include "tensorflow/compiler/mlir/tfrt/transforms/gpu_passes.h" #include "tensorflow/compiler/mlir/tfrt/transforms/passes.h" -#include "tensorflow/compiler/mlir/tfrt/transforms/tfrt_jitrt_stub.h" #include "tensorflow/compiler/mlir/tfrt/transforms/utils.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.h" @@ -85,7 +84,6 @@ void getDependentConversionDialects(mlir::DialectRegistry ®istry) { tfrt::fallback_async::FallbackAsyncDialect, tfrt::compiler::TFRTDialect>(); mlir::func::registerAllExtensions(registry); - RegisterJitRtDialects(registry); } mlir::Value GetFunctionInputChain(mlir::Operation *op) { @@ -1564,9 +1562,6 @@ class TfToTfrtConversionPass SetUpTFToTFRTConversionLegality(&target, func_type_converter, corert_converter.chain_type()); - PopulateJitRtConversionPatterns(&target, &context, &patterns, - &corert_converter); - PopulateTFToTFRTConversionPatterns( &context, &patterns, &corert_converter, &fallback_converter, &symbol_table, &cost_analysis, &tensor_array_side_effect_analysis, @@ -1689,9 +1684,6 @@ class TfToTfrtConversionPass chain_value = create_op; } - chain_value = - CreateJitRtFallbackCompileKernel(builder, module, chain_value); - builder.create(func_op.getLoc(), chain_value); } diff --git a/tensorflow/compiler/mlir/tfrt/transforms/tfrt_jitrt_passes.cc b/tensorflow/compiler/mlir/tfrt/transforms/tfrt_jitrt_passes.cc deleted file mode 100644 index 83fb4e2343b1ea..00000000000000 --- a/tensorflow/compiler/mlir/tfrt/transforms/tfrt_jitrt_passes.cc +++ /dev/null @@ -1,407 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include - -#include "llvm/Support/FormatVariadic.h" -#include "mlir/Transforms/RegionUtils.h" // from @llvm-project -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" -#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" -#include "tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_async.h" -#include "tensorflow/compiler/mlir/tfrt/jit/opdefs/tf_jitrt_ops.h" -#include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_clustering.h" -#include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h" -#include "tensorflow/compiler/mlir/tfrt/transforms/fallback_converter.h" -#include "tensorflow/compiler/mlir/tfrt/transforms/tfrt_jitrt_stub.h" -#include "tfrt/jitrt/opdefs/jitrt_ops.h" // from @tf_runtime -#include "tfrt/basic_kernels/opdefs/basic_kernels.h" // from @tf_runtime - -namespace tensorflow { -namespace { - -class TfrtJitRtStubImpl : public TfrtJitRtStub { - void RegisterJitRtDialects(mlir::DialectRegistry ®istry) override; - - void PopulateJitRtConversionPatterns( - mlir::ConversionTarget *target, mlir::MLIRContext *context, - mlir::RewritePatternSet *patterns, - CoreRTConverter *corert_converter) override; - - mlir::Value CreateJitRtFallbackCompileKernel( - mlir::OpBuilder &builder, mlir::ModuleOp module, - mlir::Value chain_value) override; - - void AddTfrtJitRtPasses(const TfrtPipelineOptions &options, - mlir::OpPassManager &pm) override; -}; - -void TfrtJitRtStubImpl::RegisterJitRtDialects(mlir::DialectRegistry ®istry) { - registry.insert(); -} - -// TODO(ezhulenev): tf_device.cluster operations after auto-fusion should -// have the correct device assigned based on the fused operations. We should -// use this device to convert operands and results from/to corert handles. -// For now it is safe to assume that it is "CPU" because we do not support -// any other devices and do not support distributed models. -constexpr char kJitRtDevice[] = "/job:localhost/replica:0/task:0/device:CPU:0"; - -// Convert jitrt.call operations to the tf_jitrt.fallback.execute operation. -class JitRtCallToJitRtCompileAndExecuteConversion - : public OpConversionPattern { - public: - explicit JitRtCallToJitRtCompileAndExecuteConversion(MLIRContext *context) - : OpConversionPattern(context) {} - - LogicalResult matchAndRewrite( - tfrt::jitrt::CallOp call, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // Convert operands to fallback tensors. - llvm::SmallVector fallback_operands; - if (failed(tfrt_compiler::ConvertFallbackOperands( - call, kJitRtDevice, adaptor.getOperands(), &fallback_operands, - rewriter))) - return rewriter.notifyMatchFailure(call, "failed to convert operand"); - - // tf_jitrt.fallback.execute always produces fallback tensors. - llvm::SmallVector result_types( - call->getNumResults(), - rewriter.getType()); - - // Replace jitrt.call operation with a tf_jitrt.fallback.execute operation. - rewriter.replaceOpWithNewOp( - call, result_types, call.getCallee(), fallback_operands, kJitRtDevice); - - return success(); - } -}; - -// Helper function for inserting TFRT JitRt dialect conversions. -void TfrtJitRtStubImpl::PopulateJitRtConversionPatterns( - mlir::ConversionTarget *target, MLIRContext *context, - RewritePatternSet *patterns, CoreRTConverter *corert_converter) { - target->addLegalDialect(); - target->addIllegalDialect(); - // Lower jitrt.call to the pair of compile and execute operations. - patterns->add(context); -} - -mlir::Value TfrtJitRtStubImpl::CreateJitRtFallbackCompileKernel( - mlir::OpBuilder &builder, mlir::ModuleOp module, mlir::Value chain_value) { - // Pre-compile all JIT compiled kernels found in the module. - llvm::SmallVector compiled; - - // A set SymbolRef attributes referencing compiled kernels. - llvm::DenseSet kernels; - - // Compile all kernels in parallell. - module.walk([&](tf_jitrt::FallbackExecuteOp execute) { - // Do not compiled the same kernel multiple times. - if (kernels.contains(execute.getKernel())) return; - - auto compile = builder.create( - execute.getLoc(), builder.getType(), - execute.getKernel(), execute.getDevice()); - compiled.push_back(compile.getResult()); - kernels.insert(compile.getKernel()); - }); - - // Wait for the compilation completion before returning from init function. - if (!compiled.empty()) { - // Do not forget to wait for the fallback kernels initialization. - compiled.insert(compiled.begin(), chain_value); - chain_value = builder.create( - module.getLoc(), builder.getType(), - compiled); - } - - return chain_value; -} - -// -------------------------------------------------------------------------- // -// Outline tf_device.cluster operation regions into functions in the nested -// modules and replaces all cluster operations with jitrt.call operations. -// -------------------------------------------------------------------------- // - -class OutlineJitRtClustersPass - : public PassWrapper> { - public: - llvm::StringRef getArgument() const final { - return "tf-outline-jitrt-cluster"; - } - llvm::StringRef getDescription() const final { - return "Outlines `tf_device.cluster` operations into functions and " - "replaces them with `jitrt.call` operations."; - } - - void runOnOperation() override; - - void getDependentDialects(mlir::DialectRegistry ®istry) const override { - registry.insert(); - } - - public: - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OutlineJitRtClustersPass) - - private: - struct CompiledModule { - ModuleOp module; - func::FuncOp entrypoint; - llvm::SetVector operands; - }; - - // Creates a nested module with a single function that will be compiled into - // the kernel at runtime. - CompiledModule CreateCompiledModule(tf_device::ClusterOp cluster, - int64_t max_arg_size, - SymbolTable *symbol_table); - - // Update compiled module entrypoint signature with inferred operands - // constraints. - LogicalResult SetEntrypointConstraints(CompiledModule &compiled); - - // Outlines cluster operation regions into compiled modules, and replaces - // cluster operation with a jitrt.call operation. - LogicalResult OutlineClusterOp(tf_device::ClusterOp cluster, - int64_t max_arg_size, - SymbolTable *symbol_table); - - // Mapping from the outlined module string representation to the module itself - // and an entrypoint function. Used to deduplicate identical modules during - // the `tf_device.cluster` outlining. - llvm::StringMap> outlined_; -}; - -OutlineJitRtClustersPass::CompiledModule -OutlineJitRtClustersPass::CreateCompiledModule(tf_device::ClusterOp cluster, - int64_t max_arg_size, - SymbolTable *symbol_table) { - MLIRContext *ctx = cluster->getContext(); - Location loc = cluster.getLoc(); - - // Create a module that will hold compiled function and async wrappers. - // TODO(ezhulenev): Give better names to module and function. - auto compiled_module = ModuleOp::create(loc, {"kernel"}); - compiled_module->setAttr("tfrt.compiled", UnitAttr::get(ctx)); - compiled_module->setAttr( - "tfrt.max-arg-size", - IntegerAttr::get(IntegerType::get(ctx, 64), max_arg_size)); - - SymbolTable compiled_module_symbol_table(compiled_module); - - // Find out the cluster arguments and their types. - llvm::SetVector live_ins; - getUsedValuesDefinedAbove(cluster.getBody(), cluster.getBody(), live_ins); - - llvm::SmallVector operand_types; - operand_types.reserve(live_ins.size()); - for (Value v : live_ins) operand_types.emplace_back(v.getType()); - - // Create a function in the compiled module. - auto compiled_func_type = - FunctionType::get(ctx, operand_types, cluster->getResultTypes()); - auto compiled_func = func::FuncOp::create(loc, "compute", compiled_func_type); - compiled_module_symbol_table.insert(compiled_func); - - // Replace uses of live-in values within cluster region with block arguments. - Block *compiled_func_block = compiled_func.addEntryBlock(); - for (auto p : llvm::zip(live_ins, compiled_func_block->getArguments())) - replaceAllUsesInRegionWith(std::get<0>(p), std::get<1>(p), - cluster.getBody()); - - // Move all operations in cluster into compiled_func's entry block. - auto &cluster_body = cluster.GetBody().getOperations(); - compiled_func_block->getOperations().splice( - compiled_func_block->end(), cluster_body, cluster_body.begin(), - cluster_body.end()); - - // Replace `tf_device.return` terminator with `func.return` in the function - // body. - auto device_return = - cast(compiled_func_block->getTerminator()); - OpBuilder builder(device_return.getOperation()); - builder.create(device_return.getLoc(), - device_return.getOperands()); - device_return.erase(); - - // TODO(ezhulenev): MLIR doesn't define operation equivalence upstream yet, - // replace module printing with a more principled solution when available. - // Operations in the cluster can be in different order, however define the - // identical Tensorflow programs, with current approach we'll not be able - // to detect duplicates like this. - - // Remove location attribute attached to Tensorflow operations to be able to - // deduplicate compiled clusters with the same set of operations. - // - // TODO(ezhulenev): Figure out how to propagate locations for error reporting, - // right now JitRt will ignore them anyway. - compiled_module.walk([](Operation *op) { op->removeAttr("_class"); }); - - // Serialize prepared module to string. - std::string serialized; - llvm::raw_string_ostream os(serialized); - compiled_module.print(os); - - // Try to find if identical module was already outlined. - auto it = outlined_.find(serialized); - - // Return identical module that was already outlined earlier. - if (it != outlined_.end()) { - compiled_module.erase(); // erase identical module - return {it->second.first, it->second.second, live_ins}; - } - - // Insert compiled module into the symbol table and assign it a unique name. - symbol_table->insert(compiled_module); - - // Cache unique module. - outlined_.insert({std::move(serialized), {compiled_module, compiled_func}}); - - return {compiled_module, compiled_func, live_ins}; -} - -LogicalResult OutlineJitRtClustersPass::SetEntrypointConstraints( - CompiledModule &compiled) { - func::FuncOp func = compiled.entrypoint; - - // Functions outlined from jitrt device clusters must have a single block. - assert(func.getBody().getBlocks().size() == 1 && "expected single block"); - - mlir::TFDevice::ClusteringPolicySet policies; - populateTfJitRtConstraintsPolicies(policies); - - // Infer constraints on the values defined in the entrypoint function - // (including function entry block arguments). - mlir::TFDevice::ValuesConstraintSet constraints; - if (failed(mlir::TFDevice::PropagateValuesConstraints( - func.getBody(), policies, constraints, /*resolve=*/true))) - return failure(); - - // Annotate arguments with inferred constraints. - for (unsigned i = 0; i < func.getNumArguments(); ++i) { - if (auto constraint = constraints.GetConstraint(func.getArgument(i))) { - auto constraint_name = mlir::StringAttr::get( - &getContext(), llvm::formatv("{0}", *constraint).str()); - func.setArgAttr(i, "rt.constraint", constraint_name); - } - } - - return success(); -} - -LogicalResult OutlineJitRtClustersPass::OutlineClusterOp( - tf_device::ClusterOp cluster, int64_t max_arg_size, - SymbolTable *symbol_table) { - Location loc = cluster->getLoc(); - OpBuilder builder(cluster); - - CompiledModule compiled_module = - CreateCompiledModule(cluster, max_arg_size, symbol_table); - func::FuncOp compiled_func = compiled_module.entrypoint; - - // Add constraints to the entrypoint arguments. - if (failed(SetEntrypointConstraints(compiled_module))) return failure(); - - // Replace device cluster with a jitrt.call operation. - auto module_name = *compiled_module.module.getSymName(); - auto func_name = compiled_func.getSymName(); - auto func_flat_ref = - mlir::SymbolRefAttr::get(builder.getContext(), func_name); - auto func_ref = mlir::SymbolRefAttr::get(builder.getContext(), module_name, - {func_flat_ref}); - - auto cluster_func_op = builder.create( - loc, cluster.getResultTypes(), func_ref, - compiled_module.operands.getArrayRef()); - - cluster.replaceAllUsesWith(cluster_func_op); - cluster.erase(); - - return success(); -} - -void OutlineJitRtClustersPass::runOnOperation() { - ModuleOp module = getOperation(); - SymbolTable symbol_table(module); - - // Keep track of the maximum argument size for each function with tf_device - // cluster operations in the function body. We need to pass it to the compiled - // module to correctly compute its cost later. - llvm::DenseMap max_arg_size_map; - - auto get_max_arg_size = [&](mlir::func::FuncOp func) -> int64_t { - auto it = max_arg_size_map.find(func); - if (it != max_arg_size_map.end()) return it->second; - return max_arg_size_map[func] = tf_jitrt::GetMaxArgSize(func); - }; - - OpBuilder builder(module.getContext()); - auto result = module.walk([&](tf_device::ClusterOp cluster) -> WalkResult { - // Ensure that cluster was formed for TFRT JIT compilation. - auto policy = cluster->getAttr("policy").dyn_cast_or_null(); - if (!policy || policy.getValue() != "tfrt.auto-fusion") - return WalkResult::advance(); - - // Get the maximum argument size of the parent function. - mlir::func::FuncOp parent_func = - cluster->getParentOfType(); - int64_t max_arg_size = get_max_arg_size(parent_func); - - if (failed(OutlineClusterOp(cluster, max_arg_size, &symbol_table))) - return WalkResult::interrupt(); - return WalkResult::advance(); - }); - - if (result.wasInterrupted()) { - module->emitError("Failed to outline tf_device.cluster operations"); - signalPassFailure(); - } -} - -std::unique_ptr CreateOutlineJitRtClustersPass() { - return std::make_unique(); -} - -void TfrtJitRtStubImpl::AddTfrtJitRtPasses(const TfrtPipelineOptions &options, - mlir::OpPassManager &pm) { - // Sink small constants into the outlined clusters to reduce the number of - // arguments for each of the execute operations. - auto is_compilable_const = [](mlir::tf_device::ClusterOp cluster, - mlir::ElementsAttr value) -> bool { - // Ensure that cluster was formed for TFRT JIT compilation. - auto policy = cluster->getAttr("policy").dyn_cast_or_null(); - if (!policy || policy.getValue() != "tfrt.auto-fusion") return false; - - // Check that TF->JitRt compiler supports constant compilation. - return mlir::succeeded(IsCompilableConstant(value)); - }; - - pm.addNestedPass( - mlir::TFDevice::CreateClusterConstantSinkingPass(is_compilable_const)); - - // Outline formed JIT compiled device clusters into function. - pm.addPass(CreateOutlineJitRtClustersPass()); -} - -mlir::PassRegistration tf_outline_jitrt_cluster_pass( - CreateOutlineJitRtClustersPass); - -const bool kUnused = - (RegisterTfrtJitRtStub(std::make_unique()), true); - -} // namespace -} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/transforms/tfrt_jitrt_stub.cc b/tensorflow/compiler/mlir/tfrt/transforms/tfrt_jitrt_stub.cc deleted file mode 100644 index 1bde6382c79bdc..00000000000000 --- a/tensorflow/compiler/mlir/tfrt/transforms/tfrt_jitrt_stub.cc +++ /dev/null @@ -1,76 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/mlir/tfrt/transforms/tfrt_jitrt_stub.h" - -#include -#include -#include - -namespace tensorflow { -namespace { - -class TfrtJitRtStubRegistry { - public: - TfrtJitRtStubRegistry() : stub_(std::make_unique()) {} - - void Register(std::unique_ptr stub) { - stub_ = std::move(stub); - } - - TfrtJitRtStub &Get() { return *stub_; } - - private: - std::unique_ptr stub_; -}; - -TfrtJitRtStubRegistry &GetGlobalTfrtJitRtStubRegistry() { - static auto *const stub = new TfrtJitRtStubRegistry; - return *stub; -} - -} // namespace - -void RegisterTfrtJitRtStub(std::unique_ptr stub) { - GetGlobalTfrtJitRtStubRegistry().Register(std::move(stub)); -} - -void RegisterJitRtDialects(mlir::DialectRegistry ®istry) { - GetGlobalTfrtJitRtStubRegistry().Get().RegisterJitRtDialects(registry); -} - -// Helper function for inserting TFRT JitRt dialect conversions. -void PopulateJitRtConversionPatterns(mlir::ConversionTarget *target, - mlir::MLIRContext *context, - mlir::RewritePatternSet *patterns, - CoreRTConverter *corert_converter) { - GetGlobalTfrtJitRtStubRegistry().Get().PopulateJitRtConversionPatterns( - target, context, patterns, corert_converter); -} - -mlir::Value CreateJitRtFallbackCompileKernel(mlir::OpBuilder &builder, - mlir::ModuleOp module, - mlir::Value chain_value) { - return GetGlobalTfrtJitRtStubRegistry() - .Get() - .CreateJitRtFallbackCompileKernel(builder, module, chain_value); -} - -void AddTfrtJitRtPasses(const TfrtPipelineOptions &options, - mlir::OpPassManager &pm) { - GetGlobalTfrtJitRtStubRegistry().Get().AddTfrtJitRtPasses(options, pm); -} - -} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/transforms/tfrt_jitrt_stub.h b/tensorflow/compiler/mlir/tfrt/transforms/tfrt_jitrt_stub.h deleted file mode 100644 index d9c00c4d376909..00000000000000 --- a/tensorflow/compiler/mlir/tfrt/transforms/tfrt_jitrt_stub.h +++ /dev/null @@ -1,71 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_TFRT_JITRT_STUB_H_ -#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_TFRT_JITRT_STUB_H_ - -#include - -#include "mlir/IR/Builders.h" // from @llvm-project -#include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/IR/DialectRegistry.h" // from @llvm-project -#include "mlir/IR/Value.h" // from @llvm-project -#include "mlir/Pass/PassManager.h" // from @llvm-project -#include "mlir/Transforms/DialectConversion.h" // from @llvm-project -#include "tensorflow/compiler/mlir/tfrt/transforms/corert_converter.h" -#include "tensorflow/compiler/mlir/tfrt/transforms/tfrt_pipeline_options.h" - -namespace tensorflow { - -class TfrtJitRtStub { - public: - virtual ~TfrtJitRtStub() = default; - - virtual void RegisterJitRtDialects(mlir::DialectRegistry ®istry) {} - - virtual void PopulateJitRtConversionPatterns( - mlir::ConversionTarget *target, mlir::MLIRContext *context, - mlir::RewritePatternSet *patterns, CoreRTConverter *corert_converter) {} - - virtual mlir::Value CreateJitRtFallbackCompileKernel( - mlir::OpBuilder &builder, mlir::ModuleOp module, - mlir::Value chain_value) { - return chain_value; - } - - virtual void AddTfrtJitRtPasses(const TfrtPipelineOptions &options, - mlir::OpPassManager &pm) {} -}; - -void RegisterTfrtJitRtStub(std::unique_ptr stub); - -void RegisterJitRtDialects(mlir::DialectRegistry ®istry); - -// Helper function for inserting TFRT JitRt dialect conversions. -void PopulateJitRtConversionPatterns(mlir::ConversionTarget *target, - mlir::MLIRContext *context, - mlir::RewritePatternSet *patterns, - CoreRTConverter *corert_converter); - -mlir::Value CreateJitRtFallbackCompileKernel(mlir::OpBuilder &builder, - mlir::ModuleOp module, - mlir::Value chain_value); - -void AddTfrtJitRtPasses(const TfrtPipelineOptions &options, - mlir::OpPassManager &pm); - -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_TFRT_JITRT_STUB_H_ diff --git a/tensorflow/core/runtime_fallback/BUILD b/tensorflow/core/runtime_fallback/BUILD index 435d5474147060..e04b69e250d8b0 100644 --- a/tensorflow/core/runtime_fallback/BUILD +++ b/tensorflow/core/runtime_fallback/BUILD @@ -31,7 +31,6 @@ tf_cc_binary( deps = [ ":bef_executor_lib", "@com_google_absl//absl/strings", - "//tensorflow/compiler/mlir/tfrt:tf_jitrt_kernels_alwayslink", "//tensorflow/core/platform:stream_executor", "//tensorflow/core/runtime_fallback/conversion:conversion_alwayslink", "//tensorflow/core/runtime_fallback/kernel:kernel_fallback_kernels_alwayslink", diff --git a/tensorflow/core/runtime_fallback/util/BUILD b/tensorflow/core/runtime_fallback/util/BUILD index a820dad23a3dda..ee575490dd35a6 100644 --- a/tensorflow/core/runtime_fallback/util/BUILD +++ b/tensorflow/core/runtime_fallback/util/BUILD @@ -81,7 +81,6 @@ cc_library( hdrs = ["fallback_test_util.h"], tags = ["no_oss"], deps = [ - "//tensorflow/compiler/mlir/tfrt:tf_jitrt_request_context", "//tensorflow/core:framework", "//tensorflow/core/platform:threadpool_interface", "//tensorflow/core/runtime_fallback/kernel:kernel_fallback_execute_compat", diff --git a/tensorflow/core/runtime_fallback/util/fallback_test_util.cc b/tensorflow/core/runtime_fallback/util/fallback_test_util.cc index 9d5029d747aa5c..af9bba7079fd3f 100644 --- a/tensorflow/core/runtime_fallback/util/fallback_test_util.cc +++ b/tensorflow/core/runtime_fallback/util/fallback_test_util.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include -#include "tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_request_context.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_execute_compat.h" #include "tensorflow/core/runtime_fallback/runtime/kernel_utils.h" @@ -76,7 +75,6 @@ tfrt::ExecutionContext CreateFallbackTestExecutionContext( /*cancellation_manager=*/nullptr); TF_DCHECK_OK(status); - status = SetUpTfJitRtRequestContext(&request_context_builder); TF_DCHECK_OK(status); auto request_context = std::move(request_context_builder).build(); diff --git a/tensorflow/core/tfrt/graph_executor/BUILD b/tensorflow/core/tfrt/graph_executor/BUILD index 5a3a69305074df..0eb60064cbecd1 100644 --- a/tensorflow/core/tfrt/graph_executor/BUILD +++ b/tensorflow/core/tfrt/graph_executor/BUILD @@ -64,7 +64,6 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:error_util", "//tensorflow/compiler/mlir/tensorflow:import_model", "//tensorflow/compiler/mlir/tfrt:import_model", - "//tensorflow/compiler/mlir/tfrt:tf_jitrt_request_context", "//tensorflow/compiler/mlir/tfrt:tfrt_compile_options", "//tensorflow/compiler/mlir/tfrt:transforms/update_op_cost_in_tfrt_mlir", "//tensorflow/compiler/mlir/tfrt/transforms/mlrt:import_model", diff --git a/tensorflow/core/tfrt/graph_executor/graph_executor.cc b/tensorflow/core/tfrt/graph_executor/graph_executor.cc index 0e7059be656fad..4b33b4750ef4dd 100644 --- a/tensorflow/core/tfrt/graph_executor/graph_executor.cc +++ b/tensorflow/core/tfrt/graph_executor/graph_executor.cc @@ -44,7 +44,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" -#include "tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_request_context.h" #include "tensorflow/compiler/mlir/tfrt/transforms/mlrt/import_model.h" #include "tensorflow/compiler/mlir/tfrt/transforms/update_op_cost_in_tfrt_mlir.h" #include "tensorflow/compiler/mlir/tfrt/translate/import_model.h" @@ -238,8 +237,6 @@ StatusOr> CreateRequestInfo( fallback_request_state.set_cancellation_manager( &request_info->cancellation_manager); - TF_RETURN_IF_ERROR( - tensorflow::SetUpTfJitRtRequestContext(&request_context_builder)); // Set priority in the builder. tfrt::RequestOptions request_options; request_options.priority = run_options.priority; diff --git a/tensorflow/core/tfrt/saved_model/BUILD b/tensorflow/core/tfrt/saved_model/BUILD index 699b84ce7c5812..3f08eabe5e5626 100644 --- a/tensorflow/core/tfrt/saved_model/BUILD +++ b/tensorflow/core/tfrt/saved_model/BUILD @@ -147,8 +147,6 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "//tensorflow/compiler/mlir/tfrt:tf_jitrt_kernels_alwayslink", - "//tensorflow/compiler/mlir/tfrt:tfrt_jitrt_passes", "//tensorflow/core/framework:graph_proto_cc", "//tensorflow/core/platform:thread_annotations", "//tensorflow/core/protobuf:for_core_protos_cc", From 1c0b86b0bc7a1ea95e867e1d7ccf16edb67a0f05 Mon Sep 17 00:00:00 2001 From: Junwhan Ahn Date: Wed, 12 Jul 2023 18:36:09 -0700 Subject: [PATCH 236/376] Fix segfault in `XlaCallModule` shape inference Entries to `xla_call_module_loaders_` were added eagerly with nullptr as values (to save a lookup on a miss), but this caused some entries to keep nullptr loaders if we fail to initialize the loader. This CL changes the logic so that we insert only non-nullptr loaders to the map. Confirmed that the added MLIR test crashes without the fix. PiperOrigin-RevId: 547656759 --- .../tensorflow/tests/shape_inference.mlir | 6 ++ .../tensorflow/transforms/shape_inference.cc | 78 +++++++++---------- 2 files changed, 45 insertions(+), 39 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir index 71fbf7cca9ee58..c65bd89421132c 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir @@ -1297,6 +1297,12 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr func.return %0 : tensor<*xf32> } + func.func @xla_call_module_parsing_error(%arg0: tensor) -> tensor<*xf32> { + %0 = "tf.Identity"(%arg0) : (tensor) -> tensor<*xf32> + %1 = "tf.XlaCallModule"(%arg0, %0) {Sout = [#tf_type.shape<*>], device = "", dim_args_spec = [], module = "invalid-stablehlo-module", platforms = [], version = 4 : i64} : (tensor, tensor<*xf32>) -> tensor<*xf32> + func.return %1 : tensor<*xf32> + } + // CHECK-LABEL: func @xla_host_compute_mlir_empty_module func.func @xla_host_compute_mlir_empty_module(%arg0: tensor<2xf32>) -> tensor<*xf32> { // CHECK: "tf._XlaHostComputeMlir" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc index 19c6050aaf21c8..92bbc1f5a9910b 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc @@ -1191,49 +1191,49 @@ bool ShapeInference::InferShapeForCaseRegion(CaseRegionOp op) { bool ShapeInference::InferShapeForXlaCallModule(XlaCallModuleOp op) { tensorflow::XlaCallModuleLoader* loader; - { - const auto [it, inserted] = xla_call_module_loaders_.insert({op, nullptr}); - + if (auto it = xla_call_module_loaders_.find(op); + it != xla_call_module_loaders_.end()) { + loader = it->second.get(); + } else { // Lazily parse XlaCallModule's embedded HLO module and cache the loader to // avoid repeatedly parsing the module. - if (inserted) { - std::vector dim_args_spec; - for (auto attr : op.getDimArgsSpec().getAsRange()) { - dim_args_spec.push_back(attr.getValue().str()); - } - std::vector disabled_checks; - for (auto attr : op.getDisabledChecks().getAsRange()) { - disabled_checks.push_back(attr.getValue().str()); - } - std::vector platforms; - for (auto attr : op.getPlatforms().getAsRange()) { - platforms.push_back(attr.getValue().str()); - } - // Always use the first platform. The assumption is that shape inference - // results should be the same regardless of which platform is chosen. - // Very old versions of the op have an empty platforms attribute. - std::string loading_platform = - (platforms.empty() ? "CPU" : platforms.front()); - - // It is a terrible idea to have local MLIR contexts so we need to - // register extensions here, again. - mlir::DialectRegistry registry; - registry.insert(); - mlir::func::registerAllExtensions(registry); - xla_call_module_context_.appendDialectRegistry(registry); - - auto l = tensorflow::XlaCallModuleLoader::Create( - &xla_call_module_context_, op.getVersion(), op.getModule().str(), - std::move(dim_args_spec), std::move(disabled_checks), - std::move(platforms), std::move(loading_platform)); - if (!l.ok()) { - LLVM_DEBUG(llvm::dbgs() << "Parsing error in XlaCallModule: " - << l.status().ToString() << "\n"); - return false; - } - it->second = *std::move(l); + + std::vector dim_args_spec; + for (auto attr : op.getDimArgsSpec().getAsRange()) { + dim_args_spec.push_back(attr.getValue().str()); + } + std::vector disabled_checks; + for (auto attr : op.getDisabledChecks().getAsRange()) { + disabled_checks.push_back(attr.getValue().str()); + } + std::vector platforms; + for (auto attr : op.getPlatforms().getAsRange()) { + platforms.push_back(attr.getValue().str()); + } + // Always use the first platform. The assumption is that shape inference + // results should be the same regardless of which platform is chosen. + // Very old versions of the op have an empty platforms attribute. + std::string loading_platform = + (platforms.empty() ? "CPU" : platforms.front()); + + // It is a terrible idea to have local MLIR contexts so we need to + // register extensions here, again. + mlir::DialectRegistry registry; + registry.insert(); + mlir::func::registerAllExtensions(registry); + xla_call_module_context_.appendDialectRegistry(registry); + + auto l = tensorflow::XlaCallModuleLoader::Create( + &xla_call_module_context_, op.getVersion(), op.getModule().str(), + std::move(dim_args_spec), std::move(disabled_checks), + std::move(platforms), std::move(loading_platform)); + if (!l.ok()) { + LLVM_DEBUG(llvm::dbgs() << "Parsing error in XlaCallModule: " + << l.status().ToString() << "\n"); + return false; } + it = xla_call_module_loaders_.insert({op, *std::move(l)}).first; loader = it->second.get(); } From 16d2b012e3cfc4fc697d0006eaaf237aae2f7e9f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 12 Jul 2023 22:28:06 -0700 Subject: [PATCH 237/376] Internal Code Change PiperOrigin-RevId: 547691215 --- tensorflow/python/lib/core/BUILD | 1 - 1 file changed, 1 deletion(-) diff --git a/tensorflow/python/lib/core/BUILD b/tensorflow/python/lib/core/BUILD index 5526bddbaf51ff..46152cc1402cf9 100644 --- a/tensorflow/python/lib/core/BUILD +++ b/tensorflow/python/lib/core/BUILD @@ -176,7 +176,6 @@ cc_library( ":ndarray_tensor", ":ndarray_tensor_bridge", ":py_util", - ":safe_pyobject_ptr", "//tensorflow/c:safe_ptr", "//tensorflow/c:tf_status_helper", "//tensorflow/c/eager:c_api", From 22f177bbee5941df81db68211137b31a95f5def3 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 12 Jul 2023 22:40:02 -0700 Subject: [PATCH 238/376] Internal Code Change PiperOrigin-RevId: 547692952 --- .../compiler/mlir/tensorflow/translate/export_graphdef.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc index 74cf842327062a..63ed9aac1db4e4 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc @@ -69,8 +69,6 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" namespace tensorflow { -using llvm::dyn_cast; -using llvm::isa; using mlir::BlockArgument; using mlir::Dialect; using mlir::Operation; From 7af32ab11fda194da4c1a179a7046293c3e25ade Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Wed, 12 Jul 2023 23:40:36 -0700 Subject: [PATCH 239/376] Revert: Attempt to be less restrictive in FusionCanShareBufferHint(). Initially I added a restriction that a fusion parameter should not be (transitively) used by more than one fusion output. I added this restriction because there was a test that was failing otherwise. After thinking about this again, the real bug was just that I didn't check whether any of the users of the fusion output for which we want to see whether it can share the buffer with the fusion operand has some non-elementwise user. PiperOrigin-RevId: 547702127 --- .../compiler/xla/service/gpu/gpu_compiler.cc | 25 +++++++++++-------- .../service/gpu/gpu_copy_insertion_test.cc | 7 +++--- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index f490f9b127e21a..18cf6860a0419d 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -1676,8 +1676,9 @@ std::optional GpuCompiler::FusionCanShareBufferHint( } // We need to make sure that the fusion parameter is accessed in the same - // iteration order as the fusion output. Also, there should not be any other - // fusion output that accesses it in a different iteration order. To make sure + // iteration order as the fusion output. Also, there should not be two fusion + // outputs that consume the fusion parameter, because we do not want to share + // the same fusion operand with two different fusion outputs. To make sure // that the iteration order is the same, we only allow ops on the path from // fusion parameter to fusion output which are elementwise (no copy) or // bitcast or an elementwise dynamic update slice (i.e. with the first operand @@ -1702,8 +1703,12 @@ std::optional GpuCompiler::FusionCanShareBufferHint( q.pop(); if (hlo_operand == output) { found_path_to_output = true; - // We still need to process the users of 'hlo_operand'. There can be other - // users in addition to the tuple user. + // The output should have at most 1 user: the tuple op (in case of a + // multi-output fusion) + if (hlo_operand->user_count() > 1) { + return false; + } + continue; } for (HloInstruction* hlo : hlo_operand->users()) { if (non_bitcast_root->opcode() == HloOpcode::kDynamicUpdateSlice && @@ -1730,8 +1735,10 @@ std::optional GpuCompiler::FusionCanShareBufferHint( } else if ((!hlo->IsElementwiseOnOperand( hlo->operand_index(hlo_operand)) || hlo->opcode() == HloOpcode::kCopy) && - hlo->opcode() != HloOpcode::kBitcast && - hlo->opcode() != HloOpcode::kTuple) { + hlo->opcode() != HloOpcode::kBitcast) { + // This check also catches the case that we reach a different fusion + // output, as that fusion output would have a tuple op as user, which we + // do not allow here. // Even if 'hlo' is not elementwise on the operand, it is ok if we are // coming from the second operand and 'hlo' is a DynamicUpdateSlice // which is the non_bitcast_root. This corresponds to the special case @@ -1745,11 +1752,9 @@ std::optional GpuCompiler::FusionCanShareBufferHint( return false; } } - if (visited.contains(hlo)) { - continue; + if (visited.insert(hlo).second) { + q.push(hlo); } - visited.insert(hlo); - q.push(hlo); } } return found_path_to_output; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion_test.cc index eb953cb2c395c4..9889142b22d2b1 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion_test.cc @@ -204,14 +204,13 @@ fused_computation { param_1.1 = f32[2,3]{1,0} parameter(1) neg = f32[2,3]{1,0} negate(param_1.1) mul = f32[2,3]{1,0} multiply(param_0.1, neg) - transpose = f32[3,2]{1,0} transpose(neg), dimensions={1,0} - ROOT tuple = (f32[2,3]{1,0}, f32[2,3]{1,0}, f32[3,2]{1,0}) tuple(mul, neg, transpose) + ROOT tuple = (f32[2,3]{1,0}, f32[2,3]{1,0}) tuple(mul, neg) } ENTRY main { param_0 = f32[2,3]{1,0} parameter(0) param_1 = f32[2,3]{1,0} parameter(1) - ROOT fusion = (f32[2,3]{1,0}, f32[2,3]{1,0}, f32[3,2]{1,0}) fusion(param_0, param_1), kind=kLoop, calls=fused_computation + ROOT fusion = (f32[2,3]{1,0}, f32[2,3]{1,0}) fusion(param_0, param_1), kind=kLoop, calls=fused_computation } )"; @@ -221,7 +220,7 @@ ENTRY main { ExpectOptionalTrue( GpuCompiler::FusionCanShareBufferHint(fusion, fusion->operand(0), {0})); // The second operand cannot share the buffer with the second fusion output, - // because the 'neg' op is also used by a non-elementwise op. + // because the 'neg' op is also used on the path to the first fusion output. ExpectOptionalFalse( GpuCompiler::FusionCanShareBufferHint(fusion, fusion->operand(1), {1})); // The first operand cannot share the buffer with the second fusion output, From 1268853e63a4e15271d50478c766dddbae527d09 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 12 Jul 2023 23:55:40 -0700 Subject: [PATCH 240/376] Internal Code Change PiperOrigin-RevId: 547704731 --- tensorflow/compiler/mlir/tfrt/benchmarks/BUILD | 2 -- 1 file changed, 2 deletions(-) diff --git a/tensorflow/compiler/mlir/tfrt/benchmarks/BUILD b/tensorflow/compiler/mlir/tfrt/benchmarks/BUILD index 8632e4a71855af..a5a53c3e5150a7 100644 --- a/tensorflow/compiler/mlir/tfrt/benchmarks/BUILD +++ b/tensorflow/compiler/mlir/tfrt/benchmarks/BUILD @@ -51,7 +51,6 @@ cc_library( deps = [ ":benchmark", "//tensorflow/compiler/jit:flags", - "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tfrt:host_context_util", "//tensorflow/compiler/mlir/tfrt:runtime_fallback_executor", "//tensorflow/compiler/mlir/tfrt:tf_jitrt_pipeline", @@ -61,7 +60,6 @@ cc_library( "//tensorflow/core:test", "//tensorflow/core:test_main", "@llvm-project//llvm:Support", - "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", "@llvm-project//mlir:mlir_c_runner_utils", "@tf_runtime//:basic_kernels_alwayslink", From b9df58ccd364b0eb5fbbe452da0af3c99b34a664 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 13 Jul 2023 01:18:05 -0700 Subject: [PATCH 241/376] Internal Code Change PiperOrigin-RevId: 547719212 --- tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc index eb84f8143113b4..23886f95d089db 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc @@ -282,7 +282,7 @@ Status ConvertTFExecutorToStablehloFlatbuffer( } // for now always output mlir - if (/*export_to_mlir*/ true) { + if (/*export_to_mlir*/ /* DISABLES CODE */ (true)) { llvm::raw_string_ostream os(*result); module.print(os); return statusHandler.ConsumeStatus(); From 8181a7ff54ea71de602cde3f901fe2a912d32b49 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 13 Jul 2023 01:33:47 -0700 Subject: [PATCH 242/376] Integrate LLVM at llvm/llvm-project@a69b2e3d1c1a Updates LLVM usage to match [a69b2e3d1c1a](https://github.com/llvm/llvm-project/commit/a69b2e3d1c1a) PiperOrigin-RevId: 547722023 --- third_party/llvm/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index 843f993f9edec6..b9f904631a41c3 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "b10899d869954e1426684cbc20a43d7303075d49" - LLVM_SHA256 = "62df1d4c4a10d9fa1c805b8eeddd5448e819ee98cf2ac8306b63b68d67656568" + LLVM_COMMIT = "a69b2e3d1c1a123e66df58116e5ca0e57e808307" + LLVM_SHA256 = "6a613ef7f464231b3b5c953095bd19c2a7a813e9b8df1e474b66ba542f04f87f" tf_http_archive( name = name, From a09d9aee8ca880ffaf2b8a9420a50139d62ade41 Mon Sep 17 00:00:00 2001 From: Raman Sarokin Date: Thu, 13 Jul 2023 01:40:16 -0700 Subject: [PATCH 243/376] Fixed string comparison. PiperOrigin-RevId: 547723283 --- tensorflow/lite/delegates/gpu/common/task/gpu_operation.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorflow/lite/delegates/gpu/common/task/gpu_operation.cc b/tensorflow/lite/delegates/gpu/common/task/gpu_operation.cc index 117467691754dc..2abe35c7a248f9 100644 --- a/tensorflow/lite/delegates/gpu/common/task/gpu_operation.cc +++ b/tensorflow/lite/delegates/gpu/common/task/gpu_operation.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/strings/match.h" #include "absl/strings/str_replace.h" #include "absl/strings/substitute.h" #include "tensorflow/lite/delegates/gpu/common/access_type.h" @@ -606,7 +607,7 @@ GPUOperation CreateGpuOperation(const OperationDef& definition, op.elementwise_code_ = std::move(descriptor.code); op.elementwise_ = true; if (definition.src_tensors.size() > 1 && - op.elementwise_code_.find("in2_value")) { + absl::StrContains(op.elementwise_code_, "in2_value")) { const auto second_tensor_def = definition.src_tensors[1]; if (NeedsBroadcast(second_tensor_def, second_shape)) { const std::string x_coord = second_shape.w == 1 ? "0" : "X_COORD"; From e07594edc52a96290a4ebec93b41818aed1147b6 Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Thu, 13 Jul 2023 01:40:33 -0700 Subject: [PATCH 244/376] [XLA:GPU] Roll-forward cl/543680393: Fuse more inputs into Triton GEMMs. - Let the GEMM rewriter do more complex traversals of inputs and fuse elementwise operations and broadcasts of scalar constants. - Limit the number of parameters per fusion. - Reorganize GPU compiler pipeline: bf16 float normalization is now required both before and after Triton GEMM fusion. - Remove an autotuner config that for unknown reasons fails on Volta with new fusions. One problem with the original CL was fixed in cl/544612599. Other ones are fixed in this one and are covered by the new tests GemmRewriterTritonTest.DoNotFuseIncompatibleDimOrders and DoNotFuseTooManyParameters. New fusion kinds are now under a flag. PiperOrigin-RevId: 547723336 --- .../compiler/xla/debug_options_flags.cc | 6 + tensorflow/compiler/xla/service/gpu/BUILD | 7 + .../xla/service/gpu/gemm_rewriter_triton.cc | 451 ++++++++++++------ .../xla/service/gpu/gemm_rewriter_triton.h | 49 +- .../service/gpu/gemm_rewriter_triton_test.cc | 156 +++++- .../compiler/xla/service/gpu/gpu_compiler.cc | 37 +- .../xla/service/gpu/ir_emitter_triton.cc | 2 +- .../xla/service/gpu/ir_emitter_triton_test.cc | 161 +++++++ tensorflow/compiler/xla/xla.proto | 4 +- 9 files changed, 713 insertions(+), 160 deletions(-) diff --git a/tensorflow/compiler/xla/debug_options_flags.cc b/tensorflow/compiler/xla/debug_options_flags.cc index c3f9d75942da09..5b8c7ce5d7ce04 100644 --- a/tensorflow/compiler/xla/debug_options_flags.cc +++ b/tensorflow/compiler/xla/debug_options_flags.cc @@ -138,6 +138,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_enable_cudnn_int8x32_convolution_reordering(true); opts.set_xla_gpu_triton_gemm_any(false); opts.set_xla_gpu_enable_triton_softmax_fusion(false); + opts.set_xla_gpu_triton_fusion_level(1); // Moving reduce-scatter out of while loops can increase memory footprint, so // turning it off by default. @@ -1130,6 +1131,11 @@ void MakeDebugOptionsFlags(std::vector* flag_list, "Forces any reductions during matrix multiplications to use the " "accumulator type and not the output type. The precision of the dot " "operation may not increase that much if there is output fusion.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_triton_fusion_level", + int32_setter_for(&DebugOptions::set_xla_gpu_triton_fusion_level), + debug_options->xla_gpu_triton_fusion_level(), + "Triton fusion level, higher levels mean more fused operations.")); } // NOLINT(readability/fn_size) // Allocates flag_values and flag_objects; this function must not be called more diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index a3a5b5c0d3fa78..ce10f6e407ae21 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -434,6 +434,7 @@ cc_library( "//tensorflow/tsl/platform:errors", "//tensorflow/tsl/platform:logging", "//tensorflow/tsl/platform:path", + "//tensorflow/tsl/platform:statusor", "//tensorflow/tsl/platform:tensor_float_32_hdr_lib", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", @@ -489,6 +490,8 @@ xla_test( "//tensorflow/compiler/xla:autotuning_proto_cc", "//tensorflow/compiler/xla:error_spec", "//tensorflow/compiler/xla/hlo/ir:hlo", + "//tensorflow/compiler/xla/service:pattern_matcher", + "//tensorflow/compiler/xla/service:pattern_matcher_gmock", "//tensorflow/compiler/xla/service/gpu/tests:gpu_codegen_test", "//tensorflow/compiler/xla/stream_executor:device_description", "//tensorflow/compiler/xla/stream_executor/cuda:cublas_plugin", @@ -1154,18 +1157,22 @@ cc_library( "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/hlo/ir:hlo", "//tensorflow/compiler/xla/hlo/utils:hlo_query", "//tensorflow/compiler/xla/service:hlo_creation_utils", "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/compiler/xla/service:instruction_fusion", + "//tensorflow/compiler/xla/stream_executor:device_description", "//tensorflow/tsl/platform:errors", "//tensorflow/tsl/platform:status", "//tensorflow/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], diff --git a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.cc b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.cc index 6b28352ccd61ab..f3b29b86d93a18 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.cc @@ -22,12 +22,15 @@ limitations under the License. #include #include #include +#include #include #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/autotuning.pb.h" @@ -37,6 +40,7 @@ limitations under the License. #include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_instructions.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_opcode.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_schedule.h" #include "tensorflow/compiler/xla/hlo/utils/hlo_query.h" #include "tensorflow/compiler/xla/layout.h" #include "tensorflow/compiler/xla/literal_util.h" @@ -46,9 +50,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/gpu/matmul_utils.h" #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" +#include "tensorflow/compiler/xla/service/instruction_fusion.h" #include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/stream_executor/device_description.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/tsl/platform/errors.h" @@ -57,6 +64,25 @@ limitations under the License. namespace xla { namespace gpu { + +bool TensorIterationSpec::operator==(const TensorIterationSpec& other) const { + for (int dim = 0; dim < TensorIterationSpec::kMaxDimsPerTensor; ++dim) { + if (dim_iteration_specs_[dim].size() != other[dim].size()) { + return false; + } + for (int fragment = 0; fragment < dim_iteration_specs_[dim].size(); + ++fragment) { + if (dim_iteration_specs_[dim][fragment].stride != + other[dim][fragment].stride || + dim_iteration_specs_[dim][fragment].count != + other[dim][fragment].count) { + return false; + } + } + } + return true; +} + namespace { // Batch dimensions of an operand of a dot instruction. @@ -95,10 +121,10 @@ int64_t NonContractingDimensionIndex(const HloInstruction& dot, } // Data types that are tested to work in the triton GEMM emitter. -bool IsSupportedDataType(PrimitiveType t, GpuVersion gpu_version) { +bool IsSupportedDataType(PrimitiveType type, GpuVersion gpu_version) { auto cuda_compute_capability = std::get(gpu_version); - switch (t) { + switch (type) { case PRED: case S8: case S16: @@ -114,21 +140,17 @@ bool IsSupportedDataType(PrimitiveType t, GpuVersion gpu_version) { } } -Status RequireTritonFusibleConvert(const HloInstruction* input, - GpuVersion gpu_version) { - if (!IsSupportedDataType(input->operand(0)->shape().element_type(), - gpu_version)) { - return Unimplemented("unsupported data type"); - } +FusionDecision RequireTritonFusibleConvert(const HloInstruction* input, + GpuVersion gpu_version) { // TODO(b/266862494): Can pick up almost any // convert, but if it's reducing the data volume it should rather be fused // to the output of the producer kernel. However not all operations support // output fusion - then it should be fused here anyway! if (ShapeUtil::ByteSizeOf(input->operand(0)->shape()) > ShapeUtil::ByteSizeOf(input->shape())) { - return FailedPrecondition("narrowing conversion"); + return "Narrowing conversion."; } - return OkStatus(); + return FusionDecision{}; } // Handles numbers of dimensions of a target HLO instruction @@ -142,6 +164,13 @@ class DimensionOrder { int64_t target_dim_number; int subdim_number; int64_t size; + bool operator==(const DimDescription& other) const { + return target_dim_number == other.target_dim_number && + subdim_number == other.subdim_number && size == other.size; + } + std::string ToString() const { + return absl::StrCat(target_dim_number, ":", subdim_number, ":", size); + } }; // Sequence describing all dimensions of HLO's output shape // in layout minor-to-major (physical) order. @@ -171,34 +200,32 @@ class DimensionOrder { // Transforms the DimensionOrder so that from a description of the output // of `hlo` it becomes a description of the input of `hlo`. - Status HandleInstruction(const HloInstruction* hlo) { + FusionDecision HandleInstruction(const HloInstruction* hlo) { VLOG(7) << hlo->ToString(); if (hlo->opcode() == HloOpcode::kParameter) { - return OkStatus(); + return FusionDecision{}; } else if (hlo->opcode() == HloOpcode::kTranspose || hlo->opcode() == HloOpcode::kCopy) { return HandleCopyOrTranspose(hlo); } else if (hlo->operand_count() > 0 && IsTritonSupportedElementwise( hlo->opcode(), hlo->operand(0)->shape().element_type())) { - return OkStatus(); + return FusionDecision{}; } else if (hlo->opcode() == HloOpcode::kBitcast) { return HandleBitcast(hlo); } else if (hlo->opcode() == HloOpcode::kReshape) { if (!ShapeUtil::ReshapeIsBitcast(hlo->operand(0)->shape(), hlo->shape())) { - return Unimplemented("Non-bitcast reshape."); + return "Non-bitcast reshape."; } return HandleBitcast(hlo); } else if (hlo_query::IsScalarConstant(hlo) || hlo_query::IsBroadcastOfScalarConstant(*hlo)) { // Dimension order collapses on a scalar, for simplicity leave it equal // to the output one for now. - return OkStatus(); - } else { - return Unimplemented("Instruction: %s", hlo->ToString()); + return FusionDecision{}; } - return OkStatus(); + return "Unimplemented instruction."; } // Get the raw data of the dimension order. @@ -210,20 +237,32 @@ class DimensionOrder { return splittable_dimension_index_; } + // Tells that two dimension orders describe the same tensor physical layout. + bool IsPhysicallyEquivalent(const DimensionOrder& other) const; + + std::string ToString() const { + return absl::StrJoin(dim_order_, "-", + [](std::string* out, const DimDescription& d) { + absl::StrAppend(out, d.ToString()); + }); + } + private: // See HandleInstruction() for the general description of Handle*(). - Status HandleBitcast(const HloInstruction* hlo); - Status HandleCopyOrTranspose(const HloInstruction* hlo); + FusionDecision HandleBitcast(const HloInstruction* hlo); + FusionDecision HandleCopyOrTranspose(const HloInstruction* hlo); DimOrderVector dim_order_; - int64_t splittable_dimension_index_; + const int64_t splittable_dimension_index_; }; -DotFusionAnalysis::TensorIterationSpec DimensionOrderToTensorIterationSpec( +using DimIterationSpec = TensorIterationSpec::DimIterationSpec; + +TensorIterationSpec DimensionOrderToTensorIterationSpec( const DimensionOrder& order) { const DimensionOrder::DimOrderVector& dim_order_vector = order.GetDimOrderVector(); - DotFusionAnalysis::TensorIterationSpec tensor_spec; + TensorIterationSpec tensor_spec; int64_t accumulated_stride = 1; for (int dim_order_index = 0; dim_order_index < dim_order_vector.size(); ++dim_order_index) { @@ -236,8 +275,7 @@ DotFusionAnalysis::TensorIterationSpec DimensionOrderToTensorIterationSpec( continue; } - DotFusionAnalysis::DimIterationSpec& dim_spec = - tensor_spec[dim.target_dim_number]; + DimIterationSpec& dim_spec = tensor_spec[dim.target_dim_number]; if (dim_order_index > 0 && dim_order_vector[dim_order_index - 1].target_dim_number == dim.target_dim_number) { @@ -257,7 +295,7 @@ DotFusionAnalysis::TensorIterationSpec DimensionOrderToTensorIterationSpec( accumulated_stride *= dim.size; } // Create all absent dimensions as degenerate ones to simplify later queries. - for (DotFusionAnalysis::DimIterationSpec& dim_spec : tensor_spec) { + for (DimIterationSpec& dim_spec : tensor_spec) { if (dim_spec.empty()) { dim_spec.push_back({/*stride=*/0, /*count=*/1, /*subfragments=*/{1}}); } @@ -265,6 +303,11 @@ DotFusionAnalysis::TensorIterationSpec DimensionOrderToTensorIterationSpec( return tensor_spec; } +bool DimensionOrder::IsPhysicallyEquivalent(const DimensionOrder& other) const { + return DimensionOrderToTensorIterationSpec(*this) == + DimensionOrderToTensorIterationSpec(other); +} + DimensionOrder DimensionOrder::FromDotOperand(const HloInstruction& dot, const int operand_number, const int64_t split_k) { @@ -287,7 +330,7 @@ DimensionOrder DimensionOrder::FromDotOutput(const HloInstruction& dot) { return DimensionOrder(&dot); } -Status DimensionOrder::HandleBitcast(const HloInstruction* hlo) { +FusionDecision DimensionOrder::HandleBitcast(const HloInstruction* hlo) { const Shape& operand_shape = hlo->operand(0)->shape(); DimOrderVector operand_dim_order; operand_dim_order.reserve(dim_order_.size()); @@ -301,7 +344,7 @@ Status DimensionOrder::HandleBitcast(const HloInstruction* hlo) { ++out_dim) { if (operand_remaining_size >= out_dim->size) { if (operand_remaining_size % out_dim->size) { - return Unimplemented("Unsupported bitcast: %s", hlo->ToString()); + return "Unsupported bitcast"; } // Output dimension fragment completely fits into the operand one: // just copy it as is. @@ -319,7 +362,7 @@ Status DimensionOrder::HandleBitcast(const HloInstruction* hlo) { // If there is a remaining fragment of a previous operand dimension // assign it first. if (out_remaining_size % operand_remaining_size) { - return Unimplemented("Unsupported bitcast: %s", hlo->ToString()); + return "Unsupported bitcast"; } operand_dim_order.push_back( {out_dim->target_dim_number, subdim_index, operand_remaining_size}); @@ -337,7 +380,7 @@ Status DimensionOrder::HandleBitcast(const HloInstruction* hlo) { // assign the remainder of the output and carry over the remainder // of the operand. if (operand_dim_size % out_remaining_size) { - return Unimplemented("Unsupported bitcast: %s", hlo->ToString()); + return "Unsupported bitcast"; } operand_remaining_size = operand_dim_size / out_remaining_size; new_fragment_size = out_remaining_size; @@ -358,7 +401,7 @@ Status DimensionOrder::HandleBitcast(const HloInstruction* hlo) { int subdim_index = operand_dim_order.back().subdim_number + 1; while (operand_dim_iter != operand_shape.layout().minor_to_major().cend()) { if (operand_shape.dimensions(*operand_dim_iter) != 1) { - return Unimplemented("Unsupported bitcast: %s", hlo->ToString()); + return "Unsupported bitcast"; } operand_dim_order.push_back( {operand_dim_order.back().target_dim_number, subdim_index, 1}); @@ -367,10 +410,11 @@ Status DimensionOrder::HandleBitcast(const HloInstruction* hlo) { } dim_order_ = operand_dim_order; - return OkStatus(); + return FusionDecision{}; } -Status DimensionOrder::HandleCopyOrTranspose(const HloInstruction* hlo) { +FusionDecision DimensionOrder::HandleCopyOrTranspose( + const HloInstruction* hlo) { // Every HLO dimension can correspond to a group of subdimensions in // dim_order_. For the easier handling of permutations: group dim_order_ by // dimension, apply permutations, then finally remove the grouping. @@ -419,25 +463,25 @@ Status DimensionOrder::HandleCopyOrTranspose(const HloInstruction* hlo) { dim_order_.push_back(subdim); } } - return OkStatus(); + return FusionDecision{}; } // Tells if the dimension order is supported by the triton GEMM emitter. // Only the dimension indicated by SplittableDimensionIndex() can be split // physically once by other dimensions. Other ones can be only split logically. // All subdimensions within a dimension have to be ordered. -Status RequireTritonGemmSupportedDimOrder(const DimensionOrder& order) { - std::array subdim_counters = { +FusionDecision RequireTritonGemmSupportedDimOrder(const DimensionOrder& order) { + std::array subdim_counters = { -1, -1, -1, -1}; - std::array split_counters = { + std::array split_counters = { -1, -1, -1, -1}; const DimensionOrder::DimOrderVector& dim_order_vector = order.GetDimOrderVector(); + VLOG(8) << order.ToString(); for (int i = 0; i < dim_order_vector.size(); i++) { const auto [dim_number, subdim_number, size] = dim_order_vector[i]; - VLOG(8) << dim_number << "\t" << subdim_number << "\t" << size; if (subdim_counters[dim_number] != subdim_number - 1) { - return Unimplemented("Transpose within a dimension."); + return "Transpose within a dimension."; } ++subdim_counters[dim_number]; if (size == 1) { @@ -447,31 +491,209 @@ Status RequireTritonGemmSupportedDimOrder(const DimensionOrder& order) { ++split_counters[dim_number]; if (dim_number == order.SplittableDimensionIndex()) { if (split_counters[dim_number] > 1) { - return Unimplemented("2nd split of a splittable dimension."); + return "2nd split of a splittable dimension."; } } else if (split_counters[dim_number] > 0) { - return Unimplemented("Split of a non-splittable dimension."); + return "Split of a non-splittable dimension."; } } } - return OkStatus(); + return FusionDecision{}; +} + +// Difference of input and output data volumes of an instruction. +int64_t InputMinusOutputBytes(const HloInstruction& hlo) { + CHECK(!hlo.shape().IsTuple()); + int64_t input_size = 0; + for (const HloInstruction* operand : hlo.operands()) { + CHECK(!operand->shape().IsTuple()); + input_size += ShapeUtil::ByteSizeOf(operand->shape()); + } + return input_size - ShapeUtil::ByteSizeOf(hlo.shape()); +} + +// Tells if an instruction has no input into which it could be fused. +// More cases should be added here. +bool CanNotBeFusedIntoAProducer(const HloInstruction& hlo) { + return hlo_query::AllOperandsAreParametersOrConstants(hlo); +} + +// Tells that fusing an instruction is efficient. +bool IsInputWorthFusing(const HloInstruction& hlo) { + if (hlo.user_count() > 1) { + return false; + } + // Let input and output data volumes of a fusion grow by small amounts. + constexpr int kIoToleranceBytes = 1024; + return hlo_query::AllOperandsAreParametersOrConstants(hlo) || + InputMinusOutputBytes(hlo) <= kIoToleranceBytes; } -// Transforms dim_order describing the output of `hlo` into a +// Checks if the instruction is possible and profitable to fuse. +// If so tries to transform dim_order describing output of `hlo` into a // description of its input if it is supported by the triton GEMM emitter. -Status CanFuse(const HloInstruction* hlo, DimensionOrder& dim_order, - const GpuVersion gpu_version) { - if (hlo->opcode() == HloOpcode::kConvert) { - return RequireTritonFusibleConvert(hlo, gpu_version); - } else if (hlo->IsElementwise() && hlo->opcode() != HloOpcode::kCopy) { - // Temporarily forbid fusing elementwise operations - // other than copy and convert. - return Unimplemented("Unsupported elementwise operation"); +FusionDecision CanFuse(const HloInstruction& hlo, DimensionOrder& dim_order, + const GpuVersion gpu_version) { + if (hlo.opcode() == HloOpcode::kTuple || + hlo.opcode() == HloOpcode::kGetTupleElement) { + return "Unsupported instruction."; + } + for (const HloInstruction* operand : hlo.operands()) { + if (!IsSupportedDataType(operand->shape().element_type(), gpu_version)) { + return "Unsupported input data type."; + } + } + if (!IsSupportedDataType(hlo.shape().element_type(), gpu_version)) { + return "Unsupported output data type."; + } + if (hlo.IsConstant()) { + return "Not fusing a constant."; + } + if (hlo.opcode() == HloOpcode::kBroadcast) { + return "Not fusing a broadcast."; + } + if (hlo.GetModule()->config().debug_options().xla_gpu_triton_fusion_level() < + 2) { + if (hlo.opcode() == HloOpcode::kConvert) { + if (FusionDecision decision = + RequireTritonFusibleConvert(&hlo, gpu_version); + !decision) { + return decision; + } + } else if (hlo.IsElementwise() && hlo.opcode() != HloOpcode::kCopy) { + return "Ignored elementwise operation"; + } + } else { + if (!CanNotBeFusedIntoAProducer(hlo) && !IsInputWorthFusing(hlo)) { + return "Not obviously profitable to fuse as input."; + } + } + + if (FusionDecision decision = dim_order.HandleInstruction(&hlo); !decision) { + return decision; } - TF_RETURN_IF_ERROR(dim_order.HandleInstruction(hlo)); return RequireTritonGemmSupportedDimOrder(dim_order); } +// Clone an instruction into the fusion. +void Fuse(HloInstruction& hlo, + absl::flat_hash_map& + old_to_new_mapping, + std::vector& call_operands, + HloComputation::Builder& builder) { + if (old_to_new_mapping.contains(&hlo)) { + return; + } + VLOG(3) << "Fusing " << hlo.ToString(); + auto get_or_add_parameter = [&](HloInstruction& instr) { + if (auto it = old_to_new_mapping.find(&instr); + it != old_to_new_mapping.end()) { + return it->second; + } + call_operands.push_back(&instr); + return old_to_new_mapping + .insert({&instr, + builder.AddInstruction(HloInstruction::CreateParameter( + call_operands.size() - 1, instr.shape(), + absl::StrCat("parameter_", call_operands.size() - 1)))}) + .first->second; + }; + if (hlo.opcode() == HloOpcode::kParameter || + hlo.opcode() == HloOpcode::kGetTupleElement) { + get_or_add_parameter(hlo); + } else { + std::vector hlo_new_operands; + for (HloInstruction* operand : hlo.operands()) { + hlo_new_operands.push_back(get_or_add_parameter(*operand)); + } + old_to_new_mapping[&hlo] = builder.AddInstruction( + hlo.CloneWithNewOperands(hlo.shape(), hlo_new_operands)); + } +} + +// Tells how many new parameters does a fusion gain by fusing the operation as +// an input. +int64_t NumAddedParameters(const HloInstruction& hlo) { + // Non-scalar constant is equivalent to a parameter: one input, one output. + if (hlo.opcode() == HloOpcode::kConstant && + !ShapeUtil::IsScalar(hlo.shape())) { + return 0; + } + // All other instructions add all own inputs and remove own single output. + return hlo.operand_count() - 1; +} + +// Fuse an instruction with all its fusible inputs. +// If an input is not fusible stop there and make a parameter of the new +// fusion, otherwise put it onto stack and check its own inputs first. +void FuseWithInputsRecursively( + HloInstruction* root, DimensionOrder root_dim_order, + // Dimension orders describing inputs of corresponding instructions. + absl::flat_hash_map& dim_orders, + const GpuVersion gpu_version, + absl::flat_hash_map& + old_to_new_mapping, + std::vector& call_operands, + HloComputation::Builder& builder) { + absl::flat_hash_set visited; + std::stack to_fuse; + // Instructions at the edge 'to_fuse' that can either get fused too or + // become parameters of the fusion. Used to track the number of parameters + // of the fusion. + absl::flat_hash_set inputs; + // Currently only one physically unique dim order per scope is supported. + // Let it change while the scope has one input; afterwards require all + // of them to be physically compatible. + const HloInstruction* reference_dim_order_hlo = nullptr; + if (CanFuse(*root, root_dim_order, gpu_version)) { + to_fuse.push(root); + inputs.insert(root->operands().begin(), root->operands().end()); + // root_dim_order went through output -> input transformation here. + CHECK(dim_orders.insert({root, root_dim_order}).second) << root->ToString(); + } + visited.insert(root); + while (!to_fuse.empty()) { + bool top_is_ready_to_fuse = true; + HloInstruction* hlo = to_fuse.top(); + if (reference_dim_order_hlo == nullptr && hlo->operand_count() > 1) { + reference_dim_order_hlo = hlo; + } + for (HloInstruction* operand : hlo->mutable_operands()) { + if (visited.insert(operand).second) { + // Stop adding new parameters. + if (inputs.size() >= DotFusionAnalysis::kMaxParameterPerScope && + NumAddedParameters(*operand) > 0) { + continue; + } + // Operand's output is described by its consumer's input. + DimensionOrder operand_dim_order(dim_orders.at(hlo)); + // CanFuse() makes output -> input transformation of + // operand_dim_order if succeeds. + if (CanFuse(*operand, operand_dim_order, gpu_version)) { + if (reference_dim_order_hlo != nullptr && + !operand_dim_order.IsPhysicallyEquivalent( + dim_orders.at(reference_dim_order_hlo))) { + continue; + } + to_fuse.push(operand); + if (operand->opcode() != HloOpcode::kParameter) { + inputs.erase(operand); + } + inputs.insert(operand->operands().begin(), operand->operands().end()); + // Save the dimension order description of operand's input. + CHECK(dim_orders.insert({operand, operand_dim_order}).second) + << operand->ToString(); + top_is_ready_to_fuse = false; + } + } + } + if (top_is_ready_to_fuse) { + Fuse(*hlo, old_to_new_mapping, call_operands, builder); + to_fuse.pop(); + } + } +} + // Extracts into fused computations parts of HLO graph including dot() // operations that can target the triton GEMM emitter. class GemmRewriterTritonVisitor : public DfsHloRewriteVisitor { @@ -483,8 +705,9 @@ class GemmRewriterTritonVisitor : public DfsHloRewriteVisitor { // and replaces the original dot() with a call to the computation. Status HandleDot(HloInstruction* dot) override { VLOG(5) << dot->ToString(); - - if (!CanTritonHandleGEMM(*dot, gpu_version_)) { + FusionDecision can_handle = CanTritonHandleGEMM(*dot, gpu_version_); + if (!can_handle) { + VLOG(3) << can_handle.Explain(); return OkStatus(); } @@ -503,72 +726,28 @@ class GemmRewriterTritonVisitor : public DfsHloRewriteVisitor { std::string suggested_name = absl::StrCat("triton_gemm_", dot->name()); HloComputation::Builder builder( absl::StrCat(suggested_name, "_computation")); + std::vector call_operands; // Original instruction -> fused one. absl::flat_hash_map old_to_new_mapping; - absl::flat_hash_set visited; - std::vector call_operands; - // Traverse and fuse dot() inputs bottom-up starting from direct operands. - // If an input is not fusible stop there and make it a parameter of the new - // fusion, otherwise put it onto stack and check its own inputs first. - std::stack to_fuse; - // Dimension orders describing inputs of corresponding instructions. - absl::flat_hash_map dim_orders; - to_fuse.push(dot); - while (!to_fuse.empty()) { - bool top_is_ready_to_fuse = true; - HloInstruction* hlo = to_fuse.top(); - for (HloInstruction* operand : hlo->mutable_operands()) { - if (visited.insert(operand).second) { - DimensionOrder operand_dim_order = [&] { - // Direct dot inputs are described by default dimension orders. - if (operand == dot->operand(0)) { - return DimensionOrder::FromDotOperand(*dot, 0); - } else if (operand == dot->operand(1)) { - return DimensionOrder::FromDotOperand(*dot, 1); - } - // Otherwise operand's output is described by its consumer's input. - return DimensionOrder(dim_orders.at(hlo)); - }(); - // CanFuse() makes output -> input transformation of - // operand_dim_order if succeeds. - if (CanFuse(operand, operand_dim_order, gpu_version_).ok()) { - VLOG(3) << "Fusing " << operand->ToString(); - to_fuse.push(operand); - // Save the dimension order description of operand's input. - dim_orders.insert({operand, operand_dim_order}); - top_is_ready_to_fuse = false; - } - } - } - if (top_is_ready_to_fuse) { - if (hlo->opcode() == HloOpcode::kParameter || - hlo->opcode() == HloOpcode::kGetTupleElement) { - old_to_new_mapping[hlo] = - builder.AddInstruction(HloInstruction::CreateParameter( - call_operands.size(), hlo->shape(), - absl::StrCat("parameter_", call_operands.size()))); - call_operands.push_back(hlo); - } else { - std::vector hlo_new_operands; - for (HloInstruction* operand : hlo->operands()) { - const auto iter = old_to_new_mapping.find(operand); - if (iter != old_to_new_mapping.end()) { - hlo_new_operands.push_back(iter->second); - } else { - hlo_new_operands.push_back( - builder.AddInstruction(HloInstruction::CreateParameter( - call_operands.size(), operand->shape(), - absl::StrCat("parameter_", call_operands.size())))); - call_operands.push_back(operand); - } - } - old_to_new_mapping[hlo] = builder.AddInstruction( - hlo->CloneWithNewOperands(hlo->shape(), hlo_new_operands)); - } - to_fuse.pop(); - } - } + + auto fuse_inputs = [&](int operand_number) { + absl::flat_hash_map dim_orders; + int operand_count_before = call_operands.size(); + // Direct dot inputs have well defined dimension orders. + FuseWithInputsRecursively( + dot->mutable_operand(operand_number), + DimensionOrder::FromDotOperand(*dot, operand_number), dim_orders, + gpu_version_, old_to_new_mapping, call_operands, builder); + return call_operands.size() - operand_count_before; + }; + // Separate traversal from LHS and RHS inputs of the dot: they use + // differently shaped tiles but may go through same HLO graph nodes. + TF_RET_CHECK(fuse_inputs(0) <= DotFusionAnalysis::kMaxParameterPerScope); + TF_RET_CHECK(fuse_inputs(1) <= DotFusionAnalysis::kMaxParameterPerScope); + + Fuse(*dot, old_to_new_mapping, call_operands, builder); + HloComputation* computation = dot->GetModule()->AddComputationAndUnifyNamesAndIds(builder.Build(), /*is_entry=*/false); @@ -592,7 +771,7 @@ class GemmRewriterTritonVisitor : public DfsHloRewriteVisitor { } else { TF_RETURN_IF_ERROR(ReplaceInstruction(dot, dot_fusion)); } - VLOG(5) << computation->ToString(); + XLA_VLOG_LINES(5, computation->ToString()); return OkStatus(); } @@ -643,7 +822,7 @@ StatusOr MakeSplitKOperand( for (const HloInstruction* param : analysis.ScopeParameters(scope)) { // If an operand of dot does not read any parameters its K dimension // does not need analysis for fragmentation. - const DotFusionAnalysis::DimIterationSpec* spec = + const DimIterationSpec* spec = analysis.IterSpec(scope, param, contracting_dim_idx); // Split contracting dimension is not implemented yet. CHECK_EQ(spec->size(), 1); @@ -885,8 +1064,8 @@ DotFusionAnalysis::DotFusionAnalysis(const HloComputation* dot_computation, absl::flat_hash_map dim_orders; DimensionOrder dot_operand_dim_order = DimensionOrder::FromDotOperand(*dot, operand_number, split_k); - TF_CHECK_OK(dot_operand_dim_order.HandleInstruction(dot_operand)); - TF_CHECK_OK(RequireTritonGemmSupportedDimOrder(dot_operand_dim_order)) + CHECK(dot_operand_dim_order.HandleInstruction(dot_operand)); + CHECK(RequireTritonGemmSupportedDimOrder(dot_operand_dim_order)) << dot_computation->ToString(); dim_orders.insert({dot_operand, dot_operand_dim_order}); visited.insert(dot_operand); @@ -907,14 +1086,18 @@ DotFusionAnalysis::DotFusionAnalysis(const HloComputation* dot_computation, {hlo_operand, DimensionOrder(dim_orders.at(hlo))}); CHECK(inserted); DimensionOrder& hlo_operand_dim_order = it->second; - TF_CHECK_OK(hlo_operand_dim_order.HandleInstruction(hlo_operand)); - TF_CHECK_OK(RequireTritonGemmSupportedDimOrder(hlo_operand_dim_order)) + CHECK(hlo_operand_dim_order.HandleInstruction(hlo_operand)); + CHECK(RequireTritonGemmSupportedDimOrder(hlo_operand_dim_order)) << " " << dot_computation->ToString(); to_process.push(hlo_operand); } } + // For now all parameters of one scope have to use the same tiling. for (const HloInstruction* parameter : parameters_[scope]) { + CHECK(dim_orders.at(parameter).IsPhysicallyEquivalent( + dim_orders.at(*parameters_[scope].cbegin()))) + << dot_computation->ToString(); iter_specs_[scope][parameter] = DimensionOrderToTensorIterationSpec(dim_orders.at(parameter)); } @@ -926,22 +1109,22 @@ DotFusionAnalysis::DotFusionAnalysis(const HloComputation* dot_computation, .second); } -const DotFusionAnalysis::DimIterationSpec* DotFusionAnalysis::IterSpec( +const DimIterationSpec* DotFusionAnalysis::IterSpec( const DotFusionAnalysis::Scope scope, const HloInstruction* hlo, const int dimension) const { auto ret = iter_specs_.at(scope).find(hlo); if (ret != iter_specs_.at(scope).end()) { - return &ret->second.at(dimension); + return &ret->second[dimension]; } return nullptr; } -bool CanTritonHandleGEMM(const HloInstruction& dot, - const GpuVersion gpu_version) { +FusionDecision CanTritonHandleGEMM(const HloInstruction& dot, + const GpuVersion gpu_version) { if (dot.opcode() != HloOpcode::kDot || absl::c_any_of(dot.precision_config().operand_precision(), [](int x) { return x != PrecisionConfig::DEFAULT; })) { - return false; + return "Non-default precision."; } auto supported_output_type = [&](const PrimitiveType t) { @@ -961,21 +1144,21 @@ bool CanTritonHandleGEMM(const HloInstruction& dot, // TODO(b/266862493): Support more output types. if (!supported_output_type(dot.shape().element_type())) { - return false; + return "Unsupported output data type."; } if (!IsSupportedDataType(dot.operand(0)->shape().element_type(), gpu_version) || !IsSupportedDataType(dot.operand(1)->shape().element_type(), gpu_version)) { - return false; + return "Unsupported input data type."; } const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers(); // TODO(b/269580541): support multiple batch dimensions. if (dim_numbers.lhs_batch_dimensions().size() > 1) { - return false; + return "Multiple batch dimensions."; } // Cases where lhs or rhs have no non-contracting dims are not handled. @@ -985,10 +1168,10 @@ bool CanTritonHandleGEMM(const HloInstruction& dot, dim_numbers.rhs_batch_dimensions().size() + dim_numbers.rhs_contracting_dimensions().size() == dot.operand(1)->shape().rank()) { - return false; + return "No non-contracting dimensions."; } - return true; + return FusionDecision{}; } bool ShouldTritonHandleGEMM(const HloInstruction& dot, @@ -1008,7 +1191,7 @@ bool ShouldTritonHandleGEMM(const HloInstruction& dot, while (!queue.empty()) { const HloInstruction* current = queue.front(); queue.pop(); - if (!CanFuse(current, dim_order, gpu_version).ok()) { + if (!CanFuse(*current, dim_order, gpu_version)) { continue; } // Stop as soon as a profitable operation is fused. diff --git a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.h b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.h index 715c79d9114659..0afc939b43ede2 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.h +++ b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.h @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/hlo/ir/hlo_opcode.h" #include "tensorflow/compiler/xla/service/gpu/gpu_types.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/service/instruction_fusion.h" namespace xla { namespace gpu { @@ -52,13 +53,13 @@ Status MakeDotSplitKBatch(HloInstruction* dot_fusion, const AutotuneResult::TritonGemmKey& tiling); // Filters GEMMs which can be handled using Triton. -bool CanTritonHandleGEMM(const HloInstruction&, GpuVersion gpu_version); +FusionDecision CanTritonHandleGEMM(const HloInstruction&, + GpuVersion gpu_version); // Filters GEMMs which are better to handle using Triton. bool ShouldTritonHandleGEMM(const HloInstruction&, GpuVersion gpu_version); -// Analysis of iteration of HLO shapes within a fusion around dot(). -class DotFusionAnalysis { +class TensorIterationSpec { public: // Description of basic iteration: `count` elements separated by `stride`. struct IterationSpecFragment { @@ -68,16 +69,42 @@ class DotFusionAnalysis { // of several HLO dimensions. Product of subfragments equals `count`. std::vector subfragments; }; - // Description of complex iteration over a sequence of several strides. // Describes a logically contiguous dimension of a tensor physically // separated into multiple fragments by other dimensions. using DimIterationSpec = std::vector; // At most: contracting, non-contracting, split-K, another batch. - static const int kMaxDimsPerTensor = 4; - using TensorIterationSpec = std::array; + static constexpr int kMaxDimsPerTensor = 4; + using StorageType = std::array; + + const DimIterationSpec& operator[](int dimension) const { + return dim_iteration_specs_[dimension]; + } + + DimIterationSpec& operator[](int dimension) { + return dim_iteration_specs_[dimension]; + } + + // Compares physical layouts of tensors ignoring subfragments of dimensions. + bool operator==(const TensorIterationSpec& other) const; + + StorageType::iterator begin() { return dim_iteration_specs_.begin(); } + StorageType::iterator end() { return dim_iteration_specs_.end(); } + StorageType::const_iterator cbegin() const { + return dim_iteration_specs_.cbegin(); + } + StorageType::const_iterator cend() const { + return dim_iteration_specs_.cend(); + } + + private: + StorageType dim_iteration_specs_; +}; +// Analysis of iteration of HLO shapes within a fusion around dot(). +class DotFusionAnalysis { + public: // Execute analysis of dot fusion computation. // split_k indicates whether this operation was converted to the split-K // form and tells the analysis how to interpret the batch dimensions. @@ -88,9 +115,15 @@ class DotFusionAnalysis { // defined by left operand, right operand and output. enum class Scope { LHS = 0, RHS = 1, OUTPUT = 2 }; + // Every parameter requires a separate piece of shared memory for asynchronous + // loads. Multiple parameters are approximately equivalent to multiple + // pipeline stages. + static constexpr int kMaxParameterPerScope = 4; + // Scope -> HLO -> dot dimension number -> iteration spec at the HLO's output. - const DimIterationSpec* IterSpec(Scope scope, const HloInstruction*, - int dimension) const; + const TensorIterationSpec::DimIterationSpec* IterSpec(Scope scope, + const HloInstruction*, + int dimension) const; // Parameter HLO instructions used in a scope of `dot`. const absl::flat_hash_set& ScopeParameters( const Scope scope) const { diff --git a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton_test.cc b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton_test.cc index d02faa5b3abdc9..95eaf51915d2e5 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton_test.cc @@ -94,7 +94,7 @@ ENTRY e { GmockMatch(m::Fusion(m::Parameter(), m::Parameter()))); } -TEST_F(GemmRewriterTritonTest, DoNotFuseConstant) { +TEST_F(GemmRewriterTritonTest, DoNotFuseConstants) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(R"( HloModule m @@ -102,14 +102,14 @@ HloModule m ENTRY e { p0 = s8[60,5] parameter(0) c0 = f16[60,5] convert(p0) - cst1 = f16[600] constant({...}) - r1 = f16[5,120] reshape(cst1) + cst1 = f16[] constant(1234) + r1 = f16[5,120] broadcast(cst1) ROOT d = f16[60,120] dot(c0, r1), lhs_contracting_dims={1}, rhs_contracting_dims={0} })")); EXPECT_TRUE(GemmRewriterTriton(gpu_version_).Run(module.get()).value()); EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch(m::Fusion(m::Constant(), m::Parameter()))); + GmockMatch(m::Fusion(m::Parameter(), m::Broadcast()))); } using TritonDotAnalysisTest = HloTestBase; @@ -793,6 +793,154 @@ ENTRY e { EXPECT_TRUE(GemmRewriterTriton(cc).Run(module.get()).value()); } +class GemmRewriterTritonLevel2Test : public GemmRewriterTritonTest { + public: + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); + debug_options.set_xla_gpu_triton_fusion_level(2); + return debug_options; + } +}; + +TEST_F(GemmRewriterTritonLevel2Test, DoNotFuseIncompatibleDimOrders) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +HloModule m + +ENTRY e { + p0 = f16[5,3] parameter(0) + p1 = f16[5,7] parameter(1) + p2 = f16[7,5] parameter(2) + t = f16[5,7] transpose(p2), dimensions={1,0} + a = f16[5,7] add(t, p1) + ROOT d = f16[3,7] dot(p0, a), + lhs_contracting_dims={0}, rhs_contracting_dims={0} +})")); + + EXPECT_TRUE(GemmRewriterTriton(gpu_version_).Run(module.get()).value()); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch(m::Fusion(m::Parameter(), m::Parameter(), m::Transpose()))); +} + +TEST_F(GemmRewriterTritonLevel2Test, DoNotFuseTooManyParameters) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +ENTRY e { + tmp_0 = f32[] constant(1) + tmp_1 = f32[3,49]{1,0} broadcast(tmp_0), dimensions={} + tmp_2 = f32[3,49]{1,0} parameter(6) + tmp_3 = f32[] constant(0) + tmp_4 = f32[3,49]{1,0} broadcast(tmp_3), dimensions={} + tmp_5 = pred[3,49]{1,0} compare(tmp_2, tmp_4), direction=GT + tmp_6 = f32[3,49]{1,0} convert(tmp_5) + tmp_7 = f32[3,49]{1,0} subtract(tmp_1, tmp_6) + tmp_8 = s32[] parameter(13) + tmp_9 = f32[] convert(tmp_8) + tmp_10 = f32[] maximum(tmp_9, tmp_0) + tmp_11 = f32[] divide(tmp_3, tmp_10) + tmp_12 = f32[3,49]{1,0} broadcast(tmp_11), dimensions={} + tmp_13 = pred[3,49]{1,0} parameter(7) + tmp_14 = pred[3,49]{1,0} parameter(10) + tmp_15 = pred[3,49]{1,0} and(tmp_13, tmp_14) + tmp_16 = f32[3,49]{1,0} convert(tmp_15) + tmp_17 = f32[3,49]{1,0} multiply(tmp_12, tmp_16) + tmp_18 = f32[3,49]{1,0} negate(tmp_17) + tmp_19 = f32[3,49]{1,0} multiply(tmp_7, tmp_18) + tmp_20 = f32[3,49]{1,0} parameter(19) + tmp_21 = f32[3,49]{1,0} subtract(tmp_1, tmp_20) + tmp_22 = f32[3,49]{1,0} divide(tmp_19, tmp_21) + tmp_23 = f32[3,49]{1,0} negate(tmp_22) + tmp_24 = f32[3,49]{1,0} negate(tmp_6) + tmp_25 = f32[3,49]{1,0} multiply(tmp_24, tmp_17) + tmp_26 = f32[3,49]{1,0} divide(tmp_25, tmp_20) + tmp_27 = f32[3,49]{1,0} add(tmp_23, tmp_26) + tmp_28 = f32[3,49]{1,0} parameter(18) + tmp_29 = f32[3,49]{1,0} multiply(tmp_27, tmp_28) + tmp_30 = f32[3,49]{1,0} parameter(17) + tmp_31 = f32[3,49]{1,0} multiply(tmp_29, tmp_30) + tmp_32 = f32[3,49]{1,0} parameter(16) + tmp_33 = f32[3,49]{1,0} multiply(tmp_31, tmp_32) + tmp_34 = f32[3,49]{1,0} parameter(15) + tmp_35 = f32[3,49]{1,0} add(tmp_33, tmp_34) + tmp_36 = f32[3,49]{1,0} parameter(14) + tmp_37 = f32[3,49]{1,0} add(tmp_35, tmp_36) + tmp_38 = f32[1,1]{1,0} constant({ {0} }) + tmp_39 = f32[1,1]{1,0} broadcast(tmp_38), dimensions={0,1} + tmp_40 = f32[] reshape(tmp_39) + tmp_41 = f32[3,32]{1,0} broadcast(tmp_40), dimensions={} + tmp_42 = u32[48]{0} parameter(11) + tmp_43 = u32[48]{0} parameter(5) + tmp_44 = u32[96]{0} concatenate(tmp_42, tmp_43), dimensions={0} + tmp_45 = u32[3,32]{1,0} reshape(tmp_44) + tmp_46 = u32[96]{0} reshape(tmp_45) + tmp_47 = u32[] constant(1) + tmp_48 = u32[3,32]{1,0} broadcast(tmp_47), dimensions={} + tmp_49 = u32[96]{0} reshape(tmp_48) + tmp_50 = u32[96]{0} shift-right-logical(tmp_46, tmp_49) + tmp_51 = u32[3,32]{1,0} reshape(tmp_50) + tmp_52 = u32[3,32]{1,0} or(tmp_51, tmp_48) + tmp_53 = f32[3,32]{1,0} bitcast-convert(tmp_52) + tmp_54 = f32[3,32]{1,0} broadcast(tmp_0), dimensions={} + tmp_55 = f32[3,32]{1,0} subtract(tmp_53, tmp_54) + tmp_56 = f32[1,1]{1,0} constant({ {1} }) + tmp_57 = f32[1,1]{1,0} broadcast(tmp_56), dimensions={0,1} + tmp_58 = f32[] reshape(tmp_57) + tmp_59 = f32[3,32]{1,0} broadcast(tmp_58), dimensions={} + tmp_60 = f32[3,32]{1,0} multiply(tmp_55, tmp_59) + tmp_61 = f32[3,32]{1,0} add(tmp_60, tmp_41) + tmp_62 = f32[3,32]{1,0} maximum(tmp_41, tmp_61) + tmp_63 = f32[3,32]{1,0} broadcast(tmp_3), dimensions={} + tmp_64 = pred[3,32]{1,0} compare(tmp_62, tmp_63), direction=LT + tmp_65 = f32[3,32]{1,0} convert(tmp_64) + tmp_66 = f32[3,49]{1,0} parameter(9) + tmp_67 = f32[49]{0} parameter(4) + tmp_68 = f32[3,49]{1,0} broadcast(tmp_67), dimensions={1} + tmp_69 = f32[3,49]{1,0} add(tmp_66, tmp_68) + tmp_70 = f32[1,49]{1,0} parameter(12) + tmp_71 = f32[1,49]{1,0} broadcast(tmp_0), dimensions={} + tmp_72 = f32[1,49]{1,0} divide(tmp_70, tmp_71) + tmp_73 = f32[1,49]{1,0} broadcast(tmp_72), dimensions={0,1} + tmp_74 = f32[49]{0} reshape(tmp_73) + tmp_75 = f32[3,49]{1,0} broadcast(tmp_74), dimensions={1} + tmp_76 = f32[3,49]{1,0} subtract(tmp_69, tmp_75) + tmp_77 = f32[1,49]{1,0} parameter(3) + tmp_78 = f32[1,49]{1,0} parameter(8) + tmp_79 = f32[1,49]{1,0} divide(tmp_78, tmp_71) + tmp_80 = f32[1,49]{1,0} multiply(tmp_72, tmp_72) + tmp_81 = f32[1,49]{1,0} subtract(tmp_79, tmp_80) + tmp_82 = f32[1,49]{1,0} add(tmp_81, tmp_71) + tmp_83 = f32[1,49]{1,0} rsqrt(tmp_82) + tmp_84 = f32[1,49]{1,0} multiply(tmp_77, tmp_83) + tmp_85 = f32[1,49]{1,0} broadcast(tmp_84), dimensions={0,1} + tmp_86 = f32[49]{0} reshape(tmp_85) + tmp_87 = f32[3,49]{1,0} broadcast(tmp_86), dimensions={1} + tmp_88 = f32[3,49]{1,0} multiply(tmp_76, tmp_87) + tmp_89 = f32[1,49]{1,0} parameter(2) + tmp_90 = f32[1,49]{1,0} broadcast(tmp_89), dimensions={0,1} + tmp_91 = f32[49]{0} reshape(tmp_90) + tmp_92 = f32[3,49]{1,0} broadcast(tmp_91), dimensions={1} + tmp_93 = f32[3,49]{1,0} add(tmp_88, tmp_92) + tmp_94 = f32[49,32]{1,0} parameter(1) + tmp_95 = f32[3,32]{1,0} dot(tmp_93, tmp_94), lhs_contracting_dims={1}, rhs_contracting_dims={0} + tmp_96 = f32[32]{0} parameter(0) + tmp_97 = f32[3,32]{1,0} broadcast(tmp_96), dimensions={1} + tmp_98 = f32[3,32]{1,0} add(tmp_95, tmp_97) + tmp_99 = f32[3,32]{1,0} multiply(tmp_65, tmp_98) + tmp_100 = f32[3,32]{1,0} divide(tmp_99, tmp_63) + tmp_101 = f32[3,32]{1,0} maximum(tmp_100, tmp_63) + ROOT tmp_102 = f32[49,32]{1,0} dot(tmp_37, tmp_101), lhs_contracting_dims={0}, rhs_contracting_dims={0} +})")); + + EXPECT_TRUE(GemmRewriterTriton(gpu_version_).Run(module.get()).value()); + EXPECT_EQ(module->entry_computation()->root_instruction()->opcode(), + HloOpcode::kFusion); + EXPECT_EQ(module->entry_computation()->root_instruction()->fusion_kind(), + HloInstruction::FusionKind::kCustom); + EXPECT_LE(module->entry_computation()->root_instruction()->operand_count(), + DotFusionAnalysis::kMaxParameterPerScope * 2); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index 18cf6860a0419d..9670d2dc954216 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -973,6 +973,29 @@ Status GpuCompiler::OptimizeHloPostLayoutAssignment( }); } + GpuFloatSupport bf16_support(BF16); + GpuFloatSupport f8e5m2_support(F8E5M2); + GpuFloatSupport f8e4m3fn_support(F8E4M3FN); + FloatSupport f8e4m3b11fnuz_support(F8E4M3B11FNUZ); + FloatSupport f8e5m2fnuz_support(F8E5M2FNUZ); + FloatSupport f8e4m3fnuz_support(F8E4M3FNUZ); + + auto add_float_normalization = [&](HloPassPipeline& pipeline) { + auto& sub_pipeline = + pipeline.AddPass("float_normalization"); + sub_pipeline.AddPass(&bf16_support); + sub_pipeline.AddPass(&f8e5m2_support); + sub_pipeline.AddPass(&f8e4m3fn_support); + sub_pipeline.AddPass(&f8e4m3b11fnuz_support); + sub_pipeline.AddPass(&f8e5m2fnuz_support); + sub_pipeline.AddPass(&f8e4m3fnuz_support); + // Remove `f32 -> bf16 -> f32` casts inserted by bf16 normalization. + if (debug_options.xla_gpu_simplify_all_fp_conversions()) { + sub_pipeline.AddPass(); + } + }; + add_float_normalization(pipeline); + // By default use an externally provided thread pool. tsl::thread::ThreadPool* thread_pool = options.thread_pool; std::optional overriding_thread_pool; @@ -994,18 +1017,8 @@ Status GpuCompiler::OptimizeHloPostLayoutAssignment( &pipeline, hlo_module, stream_exec, debug_options, options, gpu_target_config, autotune_results, thread_pool)); - GpuFloatSupport bf16_support(BF16); - pipeline.AddPass(&bf16_support); - GpuFloatSupport f8e5m2_support(F8E5M2); - pipeline.AddPass(&f8e5m2_support); - GpuFloatSupport f8e4m3fn_support(F8E4M3FN); - pipeline.AddPass(&f8e4m3fn_support); - FloatSupport f8e4m3b11fnuz_support(F8E4M3B11FNUZ); - pipeline.AddPass(&f8e4m3b11fnuz_support); - FloatSupport f8e5m2fnuz_support(F8E5M2FNUZ); - pipeline.AddPass(&f8e5m2fnuz_support); - FloatSupport f8e4m3fnuz_support(F8E4M3FNUZ); - pipeline.AddPass(&f8e4m3fnuz_support); + // The Triton autotuner can insert new reductions. + add_float_normalization(pipeline); // Remove `f32 -> bf16 -> f32` casts inserted by bf16 normalization. if (debug_options.xla_gpu_simplify_all_fp_conversions()) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_triton.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_triton.cc index d1b570f06054b2..73a3a378f6fa73 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_triton.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_triton.cc @@ -793,7 +793,7 @@ StatusOr MatMulImpl( if (!analysis.ScopeParameters(DotFusionAnalysis::Scope::LHS).empty()) { const HloInstruction* lhs_param0 = *analysis.ScopeParameters(DotFusionAnalysis::Scope::LHS).begin(); - const DotFusionAnalysis::DimIterationSpec* lhs_nc_iter_spec = + const TensorIterationSpec::DimIterationSpec* lhs_nc_iter_spec = analysis.IterSpec(DotFusionAnalysis::Scope::LHS, lhs_param0, lhs_noncontracting_dim_idx); lhs_nc_split = lhs_nc_iter_spec->size() > 1; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_triton_test.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_triton_test.cc index fc4bb7204c1632..e87b973b4a60c8 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_triton_test.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_triton_test.cc @@ -25,11 +25,14 @@ limitations under the License. #include "tensorflow/compiler/xla/autotuning.pb.h" #include "tensorflow/compiler/xla/error_spec.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_computation.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/gpu/gpu_device_info_for_tests.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h" #include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h" #include "tensorflow/compiler/xla/stream_executor/device_description.h" #include "tensorflow/compiler/xla/tests/verified_hlo_module.h" #include "tensorflow/tsl/lib/core/status_test_util.h" @@ -42,6 +45,8 @@ namespace xla { namespace gpu { namespace { +namespace m = ::xla::match; + class TritonGemmNoTF32Test : public GpuCodegenTest { public: void SetUp() override { @@ -715,6 +720,162 @@ ENTRY e { EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-6, /*arel=*/1e-6})); } +class TritonGemmLevel2Test : public TritonGemmTest { + public: + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); + debug_options.set_xla_gpu_triton_fusion_level(2); + return debug_options; + } +}; + +TEST_F(TritonGemmLevel2Test, BinaryOperationWithSmallInputsIsFused) { + const std::string kHloText = R"( +HloModule m + +ENTRY e { + p0 = s8[7,3] parameter(0) + p1 = f32[3,16] parameter(1) + p2 = f32[3,16] parameter(2) + e = f32[3,16] exponential(p1) + a = f32[3,16] add(e, p2) + c = f32[7,3] convert(p0) + ROOT d = f32[7,16] dot(c, a), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetOptimizedModule(kHloText)); + + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch(m::Fusion(m::Parameter(), m::Parameter(), m::Parameter()) + .WithFusionKind(HloInstruction::FusionKind::kCustom))); + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-1, /*arel=*/1e-3})); +} + +TEST_F(TritonGemmLevel2Test, BinaryOperationWithLargeInputsIsNotFused) { + const std::string kHloText = R"( +HloModule m + +ENTRY e { + p0 = f16[333,1000] parameter(0) + p1 = f32[1000,333] parameter(1) + p1n = f32[1000,333] negate(p1) + p2 = f32[1000,333] parameter(2) + p2n = f32[1000,333] negate(p2) + s = f32[1000,333] subtract(p1n, p2n) + c = f32[333,1000] convert(p0) + ROOT d = f32[1000,1000] dot(s, c), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"; + + MatchOptimizedHlo(kHloText, R"( +; CHECK: fused_computation +; CHECK: negate +; CHECK: negate +; CHECK: ROOT +; CHECK-SAME: subtract +; CHECK: ENTRY +; CHECK: kLoop +; CHECK: kCustom +)"); + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-1, /*arel=*/1e-3})); +} + +TEST_F(TritonGemmLevel2Test, BinaryOperationOnLargeParametersIsFused) { + const std::string kHloText = R"( +HloModule m + +ENTRY e { + p0 = f16[1000,111] parameter(0) + p1 = f32[111,10000] parameter(1) + p2 = f32[111,10000] parameter(2) + s = f32[111,10000] subtract(p1, p2) + c = f32[1000,111] convert(p0) + ROOT d = f32[10000,1000] dot(s, c), + lhs_contracting_dims={0}, rhs_contracting_dims={1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetOptimizedModule(kHloText)); + + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch(m::Fusion(m::Parameter(), m::Parameter(), m::Parameter()) + .WithFusionKind(HloInstruction::FusionKind::kCustom))); + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-1, /*arel=*/1e-3})); +} + +TEST_F(TritonGemmLevel2Test, LinkingLibdeviceTwiceWorks) { + const std::string kHloText = R"( +HloModule m + +ENTRY e { + p0 = s8[7,3] parameter(0) + c0 = f32[7,3] convert(p0) + e0 = f32[7,3] exponential(c0) + p1 = f32[3,16] parameter(1) + e1 = f32[3,16] exponential(p1) + d0 = f32[7,16] dot(c0, e1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + d1 = f32[7,16] dot(e0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT a = f32[7,16] add(d0, d1) +})"; + + MatchOptimizedHlo(kHloText, R"( +; CHECK: ENTRY +; CHECK-NEXT: parameter +; CHECK-NEXT: parameter +; CHECK-NEXT: kCustom +; CHECK-NEXT: kCustom +; CHECK-NEXT: ROOT +; CHECK-SAME: add +)"); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetOptimizedModule(kHloText)); + + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Add( + m::Fusion(m::Parameter(), m::Parameter()) + .WithFusionKind(HloInstruction::FusionKind::kCustom), + m::Fusion(m::Parameter(), m::Parameter()) + .WithFusionKind(HloInstruction::FusionKind::kCustom)))); + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2})); +} + +TEST_F(TritonGemmLevel2Test, BroadcastOfConstantIsNotFused) { + const std::string kHloText = R"( +HloModule m + +ENTRY e { + p0 = f16[70,30] parameter(0) + p0c = f32[70,30] convert(p0) + constant_3663 = f32[] constant(4321) + bc0 = f32[30,5] broadcast(constant_3663) + p1 = f32[30,5] parameter(1) + a = f32[30,5] add(p1, bc0) + ROOT d = f32[70,5] dot(p0c, a), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"; + + MatchOptimizedHlo(kHloText, R"( +; CHECK: ENTRY +; CHECK: constant +; CHECK: broadcast +; CHECK: fusion +; CHECK-SAME: kind=kCustom +)"); + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/2e-3, /*arel=*/2e-3})); +} + TEST_F(TritonGemmTest, Naming) { const char* hlo_text = R"( HloModule t diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index 3eb4ae20db045d..f41923aeae68b6 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -570,7 +570,9 @@ message DebugOptions { bool xla_gpu_triton_gemm_disable_reduced_precision_reduction = 226; - // Next id: 229 + int32 xla_gpu_triton_fusion_level = 229; + + // Next id: 230 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend. From dfd5e6ea4134fa1050ca56e1850d0ab54c72ea54 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 13 Jul 2023 02:02:02 -0700 Subject: [PATCH 245/376] Update GraphDef version to 1556. PiperOrigin-RevId: 547727193 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 4beda718dbecb6..e6f885c504bda6 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1555 // Updated: 2023/7/12 +#define TF_GRAPH_DEF_VERSION 1556 // Updated: 2023/7/13 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From 0afb54261291aac5b47f1958272cf4bf2a83a7b1 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 13 Jul 2023 02:02:23 -0700 Subject: [PATCH 246/376] compat: Update forward compatibility horizon to 2023-07-13 PiperOrigin-RevId: 547727263 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index c2a858c6505dee..1bc69b19356702 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 7, 12) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 7, 13) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From 0db26373a6f813ff67503eca9c50e7ca40b0830f Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Thu, 13 Jul 2023 05:11:38 -0700 Subject: [PATCH 247/376] Clear the patch file, it is not needed anymore. PiperOrigin-RevId: 547761195 --- third_party/llvm/generated.patch | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch index 5539280dba4e32..e69de29bb2d1d6 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch @@ -1,18 +0,0 @@ -Auto generated patch. Do not edit or delete it, even if empty. -diff -ruN --strip-trailing-cr a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp ---- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp -+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp -@@ -79,10 +79,10 @@ - void EmulateFloatPattern::rewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const { - Location loc = op->getLoc(); -+ TypeConverter *converter = getTypeConverter(); - SmallVector resultTypes; -- assert( -- succeeded(getTypeConverter()->convertTypes(op->getResultTypes(), resultTypes)) && -- "type conversions shouldn't fail in this pass"); -+ LogicalResult pass = converter->convertTypes(op->getResultTypes(), resultTypes); -+ (void) pass; - Operation *expandedOp = - rewriter.create(loc, op->getName().getIdentifier(), operands, resultTypes, - op->getAttrs(), op->getSuccessors(), /*regions=*/{}); From d87b3cb500683c9645c596081bf7dda86d9f5ba8 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Thu, 13 Jul 2023 05:39:58 -0700 Subject: [PATCH 248/376] Add HloVerifier to run_hlo_module. This is to guard against errors in input HLO. PiperOrigin-RevId: 547767557 --- tensorflow/compiler/xla/tools/BUILD | 1 + tensorflow/compiler/xla/tools/run_hlo_module.cc | 5 +++++ 2 files changed, 6 insertions(+) diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index bf53fbd0a5dfbc..961bd682056eff 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -418,6 +418,7 @@ cc_library( "//tensorflow/compiler/xla/hlo/ir:hlo", "//tensorflow/compiler/xla/service:hlo_proto_cc", "//tensorflow/compiler/xla/service:hlo_runner", + "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/tsl/platform:path", "//tensorflow/tsl/platform:status", diff --git a/tensorflow/compiler/xla/tools/run_hlo_module.cc b/tensorflow/compiler/xla/tools/run_hlo_module.cc index 3cb0bbdeac9fa4..09d7e316ceb6fa 100644 --- a/tensorflow/compiler/xla/tools/run_hlo_module.cc +++ b/tensorflow/compiler/xla/tools/run_hlo_module.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_comparison.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/tools/hlo_control_flow_flattening.h" #include "tensorflow/compiler/xla/tools/hlo_module_loader.h" @@ -254,6 +255,10 @@ Status RunAndCompare( options.use_buffer_assignment_from_proto ? &buffer_assignment_proto : nullptr)); + HloVerifier verifier( + HloVerifierOpts{}.WithLayoutSensitive(false).WithAllowMixedPrecision( + true)); + TF_RETURN_IF_ERROR(verifier.Run(test_module.get()).status()); if (compilation_env_modifier_hook) { TF_CHECK_OK(compilation_env_modifier_hook(options, *test_module)) << "Could not adjust the compilation environment for user provided " From 08a1c52e618ef31a87eb2c2dd46f92da180c0b10 Mon Sep 17 00:00:00 2001 From: Ashish Shenoy Date: Thu, 13 Jul 2023 05:51:38 -0700 Subject: [PATCH 249/376] Remove unused autograph converter code in `tensorflow/python/autograph/converters/list_comprehensions.py`. PiperOrigin-RevId: 547770102 --- tensorflow/python/autograph/BUILD | 1 - tensorflow/python/autograph/converters/BUILD | 32 -------- .../python/autograph/converters/__init__.py | 30 ------- .../converters/list_comprehensions.py | 78 ------------------- .../converters/list_comprehensions_test.py | 57 -------------- tensorflow/tools/pip_package/BUILD | 1 - 6 files changed, 199 deletions(-) delete mode 100644 tensorflow/python/autograph/converters/__init__.py delete mode 100644 tensorflow/python/autograph/converters/list_comprehensions.py delete mode 100644 tensorflow/python/autograph/converters/list_comprehensions_test.py diff --git a/tensorflow/python/autograph/BUILD b/tensorflow/python/autograph/BUILD index 01306a77aa1018..af9ec45ab93dfd 100644 --- a/tensorflow/python/autograph/BUILD +++ b/tensorflow/python/autograph/BUILD @@ -13,7 +13,6 @@ py_strict_library( srcs_version = "PY3", visibility = ["//visibility:public"], deps = [ - "//tensorflow/python/autograph/converters:__init__", "//tensorflow/python/autograph/core:converter", "//tensorflow/python/autograph/impl:api", "//tensorflow/python/autograph/lang:directives", diff --git a/tensorflow/python/autograph/converters/BUILD b/tensorflow/python/autograph/converters/BUILD index 70d4bab1d2be48..0b1dcdd0ca3900 100644 --- a/tensorflow/python/autograph/converters/BUILD +++ b/tensorflow/python/autograph/converters/BUILD @@ -5,15 +5,6 @@ package( licenses = ["notice"], ) -py_strict_library( - name = "__init__", - srcs = ["__init__.py"], - visibility = ["//tensorflow:__subpackages__"], - deps = [ - ":list_comprehensions", - ], -) - py_strict_library( name = "slices", srcs = ["slices.py"], @@ -40,17 +31,6 @@ py_strict_library( ], ) -py_strict_library( - name = "list_comprehensions", - srcs = ["list_comprehensions.py"], - visibility = ["//tensorflow:__subpackages__"], - deps = [ - "//tensorflow/python/autograph/core:converter", - "//tensorflow/python/autograph/pyct:templates", - "@gast_archive//:gast", - ], -) - py_strict_library( name = "logical_expressions", srcs = ["logical_expressions.py"], @@ -352,18 +332,6 @@ py_strict_test( ], ) -py_strict_test( - name = "list_comprehensions_test", - srcs = ["list_comprehensions_test.py"], - python_version = "PY3", - srcs_version = "PY3", - deps = [ - ":list_comprehensions", - "//tensorflow/python/autograph/core:test_lib", - "//tensorflow/python/platform:client_testlib", - ], -) - py_strict_test( name = "lists_test", srcs = ["lists_test.py"], diff --git a/tensorflow/python/autograph/converters/__init__.py b/tensorflow/python/autograph/converters/__init__.py deleted file mode 100644 index fc8ae684c2a2a8..00000000000000 --- a/tensorflow/python/autograph/converters/__init__.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Code converters used by Autograph.""" - -# Naming conventions: -# * each converter should specialize on a single idiom; be consistent with -# the Python reference for naming -# * all converters inherit core.converter.Base -# * module names describe the idiom that the converter covers, plural -# * the converter class is named consistent with the module, singular and -# includes the word Transformer -# -# Example: -# -# lists.py -# class ListTransformer(converter.Base) - -from tensorflow.python.autograph.converters import list_comprehensions diff --git a/tensorflow/python/autograph/converters/list_comprehensions.py b/tensorflow/python/autograph/converters/list_comprehensions.py deleted file mode 100644 index 8e8b97d03cc43d..00000000000000 --- a/tensorflow/python/autograph/converters/list_comprehensions.py +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Lowers list comprehensions into for and if statements. - -Example: - - result = [x * x for x in xs] - -becomes - - result = [] - for x in xs: - elt = x * x - result.append(elt) -""" - -import gast - -from tensorflow.python.autograph.core import converter -from tensorflow.python.autograph.pyct import templates - - -# TODO(mdan): This should covert directly to operator calls. - - -class ListCompTransformer(converter.Base): - """Lowers list comprehensions into standard control flow.""" - - def visit_Assign(self, node): - if not isinstance(node.value, gast.ListComp): - return self.generic_visit(node) - if len(node.targets) > 1: - raise NotImplementedError('multiple assignments') - - target, = node.targets - list_comp_node = node.value - - template = """ - target = [] - """ - initialization = templates.replace(template, target=target) - - template = """ - target.append(elt) - """ - body = templates.replace(template, target=target, elt=list_comp_node.elt) - - for gen in reversed(list_comp_node.generators): - for gen_if in reversed(gen.ifs): - template = """ - if test: - body - """ - body = templates.replace(template, test=gen_if, body=body) - template = """ - for target in iter_: - body - """ - body = templates.replace( - template, iter_=gen.iter, target=gen.target, body=body) - - return initialization + body - - -def transform(node, ctx): - return ListCompTransformer(ctx).visit(node) diff --git a/tensorflow/python/autograph/converters/list_comprehensions_test.py b/tensorflow/python/autograph/converters/list_comprehensions_test.py deleted file mode 100644 index 630aad030c1e0a..00000000000000 --- a/tensorflow/python/autograph/converters/list_comprehensions_test.py +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for list_comprehensions module.""" - -from tensorflow.python.autograph.converters import list_comprehensions -from tensorflow.python.autograph.core import converter_testing -from tensorflow.python.platform import test - - -class ListCompTest(converter_testing.TestCase): - - def assertTransformedEquivalent(self, f, *inputs): - tr = self.transform(f, list_comprehensions) - self.assertEqual(f(*inputs), tr(*inputs)) - - def test_basic(self): - - def f(l): - s = [e * e for e in l] - return s - - self.assertTransformedEquivalent(f, []) - self.assertTransformedEquivalent(f, [1, 2, 3]) - - def test_multiple_generators(self): - - def f(l): - s = [e * e for sublist in l for e in sublist] # pylint:disable=g-complex-comprehension - return s - - self.assertTransformedEquivalent(f, []) - self.assertTransformedEquivalent(f, [[1], [2], [3]]) - - def test_cond(self): - - def f(l): - s = [e * e for e in l if e > 1] - return s - - self.assertTransformedEquivalent(f, []) - self.assertTransformedEquivalent(f, [1, 2, 3]) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD index 08033624332047..d9d87af17935c1 100644 --- a/tensorflow/tools/pip_package/BUILD +++ b/tensorflow/tools/pip_package/BUILD @@ -115,7 +115,6 @@ COMMON_PIP_DEPS = [ "//tensorflow/lite/python:tflite_convert", "//tensorflow/lite/toco/python:toco_from_protos", "//tensorflow/lite/tools:visualize", - "//tensorflow/python/autograph/converters:list_comprehensions", "//tensorflow/python/autograph/core:test_lib", "//tensorflow/python/autograph/impl/testing:pybind_for_testing", "//tensorflow/python/autograph/pyct/testing:basic_definitions", From f31e1b5f0729e784db623ea7efbd39da8dc164d6 Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Thu, 13 Jul 2023 06:01:32 -0700 Subject: [PATCH 250/376] [XLA:GPU] Roll forward cl/543697810: Fuse outputs into Triton GEMMs. PiperOrigin-RevId: 547772221 --- tensorflow/compiler/xla/service/gpu/BUILD | 1 + .../xla/service/gpu/gemm_rewriter_triton.cc | 449 +++++++++++++----- .../service/gpu/gemm_rewriter_triton_test.cc | 110 ++++- .../xla/service/gpu/ir_emitter_triton_test.cc | 118 +++++ .../service/gpu/tests/gemm_rewrite_test.cc | 1 + tensorflow/compiler/xla/shape_util.cc | 16 + tensorflow/compiler/xla/shape_util.h | 3 + tensorflow/compiler/xla/shape_util_test.cc | 16 + 8 files changed, 580 insertions(+), 134 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index ce10f6e407ae21..50e052bb133fb7 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -1155,6 +1155,7 @@ cc_library( ":matmul_utils", "//tensorflow/compiler/xla:autotuning_proto_cc", "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:permutation_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status_macros", diff --git a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.cc b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.cc index f3b29b86d93a18..1c26b87c8ab9b2 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.cc @@ -44,6 +44,7 @@ limitations under the License. #include "tensorflow/compiler/xla/hlo/utils/hlo_query.h" #include "tensorflow/compiler/xla/layout.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/permutation_util.h" #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/gpu/cublas_padding_requirements.h" #include "tensorflow/compiler/xla/service/gpu/gpu_types.h" @@ -140,6 +141,23 @@ bool IsSupportedDataType(PrimitiveType type, GpuVersion gpu_version) { } } +// Tells if f(a+b) == f(a) + f(b). +bool IsDistributiveOverAddition(const HloInstruction& hlo) { + // The list is most likely incomplete. + // For example division can be added too but only for operand #0. + if (hlo.opcode() == HloOpcode::kMultiply || + hlo.opcode() == HloOpcode::kNegate || + hlo.opcode() == HloOpcode::kBitcast || + hlo.opcode() == HloOpcode::kReshape || hlo.opcode() == HloOpcode::kCopy || + hlo.opcode() == HloOpcode::kTranspose || + hlo.opcode() == HloOpcode::kConvert || + hlo.opcode() == HloOpcode::kBroadcast || + hlo.opcode() == HloOpcode::kSlice) { + return true; + } + return false; +} + FusionDecision RequireTritonFusibleConvert(const HloInstruction* input, GpuVersion gpu_version) { // TODO(b/266862494): Can pick up almost any @@ -198,27 +216,30 @@ class DimensionOrder { // Create dimension order describing dot's output. static DimensionOrder FromDotOutput(const HloInstruction& dot); - // Transforms the DimensionOrder so that from a description of the output - // of `hlo` it becomes a description of the input of `hlo`. - FusionDecision HandleInstruction(const HloInstruction* hlo) { + enum class TransformDirection { kInputToOutput, kOutputToInput }; + + // Transforms the DimensionOrder so that from a description one side + // of `hlo` it becomes a description of the other side of `hlo`. + FusionDecision HandleInstruction(const HloInstruction* hlo, + TransformDirection direction) { VLOG(7) << hlo->ToString(); if (hlo->opcode() == HloOpcode::kParameter) { return FusionDecision{}; } else if (hlo->opcode() == HloOpcode::kTranspose || hlo->opcode() == HloOpcode::kCopy) { - return HandleCopyOrTranspose(hlo); + return HandleCopyOrTranspose(hlo, direction); } else if (hlo->operand_count() > 0 && IsTritonSupportedElementwise( hlo->opcode(), hlo->operand(0)->shape().element_type())) { return FusionDecision{}; } else if (hlo->opcode() == HloOpcode::kBitcast) { - return HandleBitcast(hlo); + return HandleBitcast(hlo, direction); } else if (hlo->opcode() == HloOpcode::kReshape) { if (!ShapeUtil::ReshapeIsBitcast(hlo->operand(0)->shape(), hlo->shape())) { return "Non-bitcast reshape."; } - return HandleBitcast(hlo); + return HandleBitcast(hlo, direction); } else if (hlo_query::IsScalarConstant(hlo) || hlo_query::IsBroadcastOfScalarConstant(*hlo)) { // Dimension order collapses on a scalar, for simplicity leave it equal @@ -249,8 +270,9 @@ class DimensionOrder { private: // See HandleInstruction() for the general description of Handle*(). - FusionDecision HandleBitcast(const HloInstruction* hlo); - FusionDecision HandleCopyOrTranspose(const HloInstruction* hlo); + FusionDecision HandleBitcast(const HloInstruction*, TransformDirection); + FusionDecision HandleCopyOrTranspose(const HloInstruction*, + TransformDirection); DimOrderVector dim_order_; const int64_t splittable_dimension_index_; @@ -330,101 +352,108 @@ DimensionOrder DimensionOrder::FromDotOutput(const HloInstruction& dot) { return DimensionOrder(&dot); } -FusionDecision DimensionOrder::HandleBitcast(const HloInstruction* hlo) { - const Shape& operand_shape = hlo->operand(0)->shape(); - DimOrderVector operand_dim_order; - operand_dim_order.reserve(dim_order_.size()); - // Size of not yet assigned part of current operand dimension. - int64_t operand_remaining_size = 1; - // Iterate in parallel over output dimension order and operand dimensions +FusionDecision DimensionOrder::HandleBitcast(const HloInstruction* hlo, + TransformDirection direction) { + const Shape& target_shape = (direction == TransformDirection::kOutputToInput) + ? hlo->operand(0)->shape() + : hlo->shape(); + DimOrderVector target_dim_order; + target_dim_order.reserve(dim_order_.size()); + // Size of not yet assigned part of current target dimension. + int64_t target_remaining_size = 1; + // Iterate in parallel over source dimension order and target dimensions // in minor_to_major order. Find groups of dimensions of equal size - // and project the output dimension order onto the operand. - auto operand_dim_iter = operand_shape.layout().minor_to_major().cbegin(); - for (auto out_dim = dim_order_.cbegin(); out_dim != dim_order_.cend(); - ++out_dim) { - if (operand_remaining_size >= out_dim->size) { - if (operand_remaining_size % out_dim->size) { + // and project the source dimension order onto the target. + auto target_dim_iter = target_shape.layout().minor_to_major().cbegin(); + for (auto src_dim = dim_order_.cbegin(); src_dim != dim_order_.cend(); + ++src_dim) { + if (target_remaining_size >= src_dim->size) { + if (target_remaining_size % src_dim->size) { return "Unsupported bitcast"; } - // Output dimension fragment completely fits into the operand one: + // Source dimension fragment completely fits into the target one: // just copy it as is. - operand_dim_order.push_back(*out_dim); - // Update the size of the remaining part of the operand that is - // carried over to next output dimensions. - operand_remaining_size /= out_dim->size; + target_dim_order.push_back(*src_dim); + // Update the size of the remaining part of the target that is + // carried over to next source dimensions. + target_remaining_size /= src_dim->size; } else { - // Output is larger than input. Assign further operand dimensions. - // Size of the not yet assigned part of the output dimension. - int64_t out_remaining_size = out_dim->size; + // Source is larger than target. Assign further target dimensions. + // Size of the not yet assigned part of the source dimension. + int64_t src_remaining_size = src_dim->size; // Subdimension index tracking dimension splits. - int subdim_index = out_dim->subdim_number; - if (operand_remaining_size > 1) { - // If there is a remaining fragment of a previous operand dimension + int subdim_index = src_dim->subdim_number; + if (target_remaining_size > 1) { + // If there is a remaining fragment of a previous target dimension // assign it first. - if (out_remaining_size % operand_remaining_size) { + if (src_remaining_size % target_remaining_size) { return "Unsupported bitcast"; } - operand_dim_order.push_back( - {out_dim->target_dim_number, subdim_index, operand_remaining_size}); + target_dim_order.push_back( + {src_dim->target_dim_number, subdim_index, target_remaining_size}); ++subdim_index; // Update the size of the fragment remaining to assign. - out_remaining_size /= operand_remaining_size; - operand_remaining_size = 1; + src_remaining_size /= target_remaining_size; + target_remaining_size = 1; } - while (out_remaining_size > 1) { - // Assign operand dimensions until the output remainder is covered. - int64_t operand_dim_size = operand_shape.dimensions(*operand_dim_iter); - int64_t new_fragment_size = operand_dim_size; - if (operand_dim_size > out_remaining_size) { - // If adding the next operand dimension exceeds output fragment size - // assign the remainder of the output and carry over the remainder - // of the operand. - if (operand_dim_size % out_remaining_size) { + while (src_remaining_size > 1) { + // Assign target dimensions until the source remainder is covered. + int64_t target_dim_size = target_shape.dimensions(*target_dim_iter); + int64_t new_fragment_size = target_dim_size; + if (target_dim_size > src_remaining_size) { + // If adding the next target dimension exceeds source fragment size + // assign the remainder of the source and carry over the remainder + // of the target. + if (target_dim_size % src_remaining_size) { return "Unsupported bitcast"; } - operand_remaining_size = operand_dim_size / out_remaining_size; - new_fragment_size = out_remaining_size; + target_remaining_size = target_dim_size / src_remaining_size; + new_fragment_size = src_remaining_size; } - operand_dim_order.push_back( - {out_dim->target_dim_number, subdim_index, new_fragment_size}); - out_remaining_size /= new_fragment_size; - ++operand_dim_iter; + target_dim_order.push_back( + {src_dim->target_dim_number, subdim_index, new_fragment_size}); + src_remaining_size /= new_fragment_size; + ++target_dim_iter; ++subdim_index; } } } - CHECK_EQ(operand_remaining_size, 1); + CHECK_EQ(target_remaining_size, 1); - // Handle remaining major dimensions of the operand. Call all degenerate + // Handle remaining major dimensions of the target. Call all degenerate // ones subdimensions of the most-major non-degenerate one. Otherwise // give up. - int subdim_index = operand_dim_order.back().subdim_number + 1; - while (operand_dim_iter != operand_shape.layout().minor_to_major().cend()) { - if (operand_shape.dimensions(*operand_dim_iter) != 1) { + int subdim_index = target_dim_order.back().subdim_number + 1; + while (target_dim_iter != target_shape.layout().minor_to_major().cend()) { + if (target_shape.dimensions(*target_dim_iter) != 1) { return "Unsupported bitcast"; } - operand_dim_order.push_back( - {operand_dim_order.back().target_dim_number, subdim_index, 1}); + target_dim_order.push_back( + {target_dim_order.back().target_dim_number, subdim_index, 1}); ++subdim_index; - ++operand_dim_iter; + ++target_dim_iter; } - dim_order_ = operand_dim_order; + dim_order_ = target_dim_order; return FusionDecision{}; } FusionDecision DimensionOrder::HandleCopyOrTranspose( - const HloInstruction* hlo) { + const HloInstruction* hlo, TransformDirection direction) { // Every HLO dimension can correspond to a group of subdimensions in // dim_order_. For the easier handling of permutations: group dim_order_ by // dimension, apply permutations, then finally remove the grouping. // Group subdimensions by iterating over them in the same order as over // dimensions and matching by total size. - std::vector out_physical; - out_physical.reserve(hlo->shape().rank()); + const HloInstruction* src = + (direction == TransformDirection::kOutputToInput) ? hlo : hlo->operand(0); + const HloInstruction* dst = + (direction == TransformDirection::kOutputToInput) ? hlo->operand(0) : hlo; + std::vector src_physical; + src_physical.reserve(src->shape().rank()); auto dim_order_it = dim_order_.cbegin(); - for (int64_t dim_index : hlo->shape().layout().minor_to_major()) { - const int64_t dim_size = hlo->shape().dimensions(dim_index); + for (int64_t dim_index : src->shape().layout().minor_to_major()) { + const int64_t dim_size = src->shape().dimensions(dim_index); int64_t subdim_size_accumulator = 1; DimOrderVector subdim_group; do { @@ -433,33 +462,37 @@ FusionDecision DimensionOrder::HandleCopyOrTranspose( ++dim_order_it; } while (subdim_size_accumulator < dim_size); CHECK_EQ(subdim_size_accumulator, dim_size); - out_physical.push_back(subdim_group); + src_physical.push_back(subdim_group); } - // Out physical -> out logical. - std::vector out_logical; - out_logical.resize(out_physical.size()); - for (int i = 0; i < out_physical.size(); ++i) { - out_logical[hlo->shape().layout().minor_to_major(i)] = out_physical[i]; + // Source physical -> source logical. + std::vector src_logical; + src_logical.resize(src_physical.size()); + for (int i = 0; i < src_physical.size(); ++i) { + src_logical[src->shape().layout().minor_to_major(i)] = src_physical[i]; } - // Out logical -> operand logical. - std::vector operand_logical; + // Source logical -> destination logical. + std::vector dst_logical; if (hlo->opcode() == HloOpcode::kTranspose) { auto transpose = ::xla::Cast(hlo); - operand_logical.resize(out_logical.size()); - for (int i = 0; i < out_logical.size(); ++i) { - operand_logical[transpose->dimensions()[i]] = out_logical[i]; + std::vector permutation(transpose->dimensions().cbegin(), + transpose->dimensions().cend()); + if (direction == TransformDirection::kInputToOutput) { + permutation = InversePermutation(permutation); + } + dst_logical.resize(src_logical.size()); + for (int i = 0; i < src_logical.size(); ++i) { + dst_logical[permutation[i]] = src_logical[i]; } } else { // Copy preserves the logical shape, just permutes the layout. - const Shape& operand_shape = hlo->operand(0)->shape(); - CHECK(ShapeUtil::SameDimensions(hlo->shape(), operand_shape)); - operand_logical = out_logical; + CHECK(ShapeUtil::SameDimensions(src->shape(), dst->shape())); + dst_logical = src_logical; } - // Operand logical -> operand physical and ungroup subdimensions. - const Layout& operand_layout = hlo->operand(0)->shape().layout(); + // Destination logical -> destination physical and ungroup subdimensions. + const Layout& dst_layout = dst->shape().layout(); dim_order_.clear(); - for (int64_t dim_idx : operand_layout.minor_to_major()) { - for (const DimDescription& subdim : operand_logical[dim_idx]) { + for (int64_t dim_idx : dst_layout.minor_to_major()) { + for (const DimDescription& subdim : dst_logical[dim_idx]) { dim_order_.push_back(subdim); } } @@ -512,27 +545,44 @@ int64_t InputMinusOutputBytes(const HloInstruction& hlo) { return input_size - ShapeUtil::ByteSizeOf(hlo.shape()); } +// Tells if an instruction has no user into which it could be fused. +// More cases should be added here. +bool CanNotBeFusedIntoAUser(const HloInstruction& hlo) { + return hlo.IsRoot() || (hlo.user_count() == 1 && hlo.users()[0]->IsRoot() && + hlo.users()[0]->opcode() == HloOpcode::kTuple); +} + // Tells if an instruction has no input into which it could be fused. // More cases should be added here. bool CanNotBeFusedIntoAProducer(const HloInstruction& hlo) { return hlo_query::AllOperandsAreParametersOrConstants(hlo); } -// Tells that fusing an instruction is efficient. +// Let input and output data volumes of a fusion grow by small amounts. +constexpr int kIoToleranceBytes = 1024; + +// Tells that fusing an instruction as an input is efficient. bool IsInputWorthFusing(const HloInstruction& hlo) { if (hlo.user_count() > 1) { return false; } - // Let input and output data volumes of a fusion grow by small amounts. - constexpr int kIoToleranceBytes = 1024; return hlo_query::AllOperandsAreParametersOrConstants(hlo) || InputMinusOutputBytes(hlo) <= kIoToleranceBytes; } +// Tells that fusing an instruction as an output is efficient. +bool IsOutputWorthFusing(const HloInstruction& hlo) { + return CanNotBeFusedIntoAUser(hlo) || + InputMinusOutputBytes(hlo) >= -kIoToleranceBytes; +} + // Checks if the instruction is possible and profitable to fuse. -// If so tries to transform dim_order describing output of `hlo` into a -// description of its input if it is supported by the triton GEMM emitter. -FusionDecision CanFuse(const HloInstruction& hlo, DimensionOrder& dim_order, +// If so tries to transform dim_order describing one side of `hlo` into a +// description of its other side if it is supported by the triton GEMM emitter. +FusionDecision CanFuse(const HloInstruction& hlo, bool as_input, + DimensionOrder& dim_order, + absl::flat_hash_map& old_to_new_mapping, const GpuVersion gpu_version) { if (hlo.opcode() == HloOpcode::kTuple || hlo.opcode() == HloOpcode::kGetTupleElement) { @@ -552,26 +602,58 @@ FusionDecision CanFuse(const HloInstruction& hlo, DimensionOrder& dim_order, if (hlo.opcode() == HloOpcode::kBroadcast) { return "Not fusing a broadcast."; } - if (hlo.GetModule()->config().debug_options().xla_gpu_triton_fusion_level() < - 2) { - if (hlo.opcode() == HloOpcode::kConvert) { - if (FusionDecision decision = - RequireTritonFusibleConvert(&hlo, gpu_version); - !decision) { - return decision; + if (as_input) { + if (hlo.GetModule() + ->config() + .debug_options() + .xla_gpu_triton_fusion_level() < 2) { + if (hlo.opcode() == HloOpcode::kConvert) { + if (FusionDecision decision = + RequireTritonFusibleConvert(&hlo, gpu_version); + !decision) { + return decision; + } + } else if (hlo.IsElementwise() && hlo.opcode() != HloOpcode::kCopy) { + return "Ignored elementwise operation"; + } + } else { + if (!CanNotBeFusedIntoAProducer(hlo) && !IsInputWorthFusing(hlo)) { + return "Not obviously profitable to fuse as input."; } - } else if (hlo.IsElementwise() && hlo.opcode() != HloOpcode::kCopy) { - return "Ignored elementwise operation"; } } else { - if (!CanNotBeFusedIntoAProducer(hlo) && !IsInputWorthFusing(hlo)) { - return "Not obviously profitable to fuse as input."; + if (hlo.GetModule() + ->config() + .debug_options() + .xla_gpu_triton_fusion_level() < 2) { + return "Skipping fusing outputs at low fusion levels."; + } + for (const HloInstruction* operand : hlo.operands()) { + // Skip already fused operands. + if (old_to_new_mapping.contains(operand)) { + continue; + } + // Currently only broadcasts of scalar constants or parameters + // are accepted as other inputs of non-unary operations + // in the output fusion. + if (hlo_query::IsBroadcastOfScalarConstant(*operand) || + operand->opcode() == HloOpcode::kParameter) { + continue; + } + return "Has multiple inputs - not properly analyzed yet."; + } + if (!IsOutputWorthFusing(hlo)) { + return "Not obviously profitable to fuse as output."; } } - if (FusionDecision decision = dim_order.HandleInstruction(&hlo); !decision) { + if (FusionDecision decision = dim_order.HandleInstruction( + &hlo, as_input ? DimensionOrder::TransformDirection::kOutputToInput + : DimensionOrder::TransformDirection::kInputToOutput); + !decision) { return decision; } + return RequireTritonGemmSupportedDimOrder(dim_order); } @@ -645,7 +727,8 @@ void FuseWithInputsRecursively( // Let it change while the scope has one input; afterwards require all // of them to be physically compatible. const HloInstruction* reference_dim_order_hlo = nullptr; - if (CanFuse(*root, root_dim_order, gpu_version)) { + if (CanFuse(*root, /*as_input=*/true, root_dim_order, old_to_new_mapping, + gpu_version)) { to_fuse.push(root); inputs.insert(root->operands().begin(), root->operands().end()); // root_dim_order went through output -> input transformation here. @@ -669,7 +752,8 @@ void FuseWithInputsRecursively( DimensionOrder operand_dim_order(dim_orders.at(hlo)); // CanFuse() makes output -> input transformation of // operand_dim_order if succeeds. - if (CanFuse(*operand, operand_dim_order, gpu_version)) { + if (CanFuse(*operand, /*as_input=*/true, operand_dim_order, + old_to_new_mapping, gpu_version)) { if (reference_dim_order_hlo != nullptr && !operand_dim_order.IsPhysicallyEquivalent( dim_orders.at(reference_dim_order_hlo))) { @@ -721,8 +805,6 @@ class GemmRewriterTritonVisitor : public DfsHloRewriteVisitor { return OkStatus(); } - // TODO(b/266857789): also fuse convert(dot()) at output if present: - // seen on s8xf32->bf16 std::string suggested_name = absl::StrCat("triton_gemm_", dot->name()); HloComputation::Builder builder( absl::StrCat(suggested_name, "_computation")); @@ -748,6 +830,43 @@ class GemmRewriterTritonVisitor : public DfsHloRewriteVisitor { Fuse(*dot, old_to_new_mapping, call_operands, builder); + // Fusion at dot's output. + + // These describe _outputs_ of corresponding HLOs. + absl::flat_hash_map out_dim_orders; + out_dim_orders.insert({dot, DimensionOrder::FromDotOutput(*dot)}); + HloInstruction* fusion_output = dot; + bool output_changed = true; + while (output_changed) { + output_changed = false; + if (fusion_output->user_count() != 1) { + break; + } + HloInstruction* user = fusion_output->users()[0]; + if (!IsDistributiveOverAddition(*user)) { + break; + } + // Describes the output of `current_output` = input of `user`. + DimensionOrder dim_order(out_dim_orders.at(fusion_output)); + if (CanFuse(*user, /*as_input=*/false, dim_order, old_to_new_mapping, + gpu_version_)) { + // Now it describes the output of the user. + CHECK(out_dim_orders.insert({user, dim_order}).second); + for (HloInstruction* operand : user->operands()) { + if (!old_to_new_mapping.contains(operand)) { + // Here we need again a dim order describing inputs of the user. + FuseWithInputsRecursively( + operand, DimensionOrder(out_dim_orders.at(fusion_output)), + out_dim_orders, gpu_version_, old_to_new_mapping, call_operands, + builder); + } + } + Fuse(*user, old_to_new_mapping, call_operands, builder); + fusion_output = user; + output_changed = true; + } + } + HloComputation* computation = dot->GetModule()->AddComputationAndUnifyNamesAndIds(builder.Build(), /*is_entry=*/false); @@ -763,13 +882,14 @@ class GemmRewriterTritonVisitor : public DfsHloRewriteVisitor { backend_config.set_kind(std::string(kTritonGemmFusionKind)); TF_RETURN_IF_ERROR(dot_fusion->set_backend_config(backend_config)); - if (dot->IsRoot()) { - dot->parent()->set_root_instruction(dot_fusion); + if (fusion_output->IsRoot()) { + fusion_output->parent()->set_root_instruction(dot_fusion); TF_RETURN_IF_ERROR( - dot->parent()->RemoveInstructionAndUnusedOperands(dot)); + fusion_output->parent()->RemoveInstructionAndUnusedOperands( + fusion_output)); MarkAsChanged(); } else { - TF_RETURN_IF_ERROR(ReplaceInstruction(dot, dot_fusion)); + TF_RETURN_IF_ERROR(ReplaceInstruction(fusion_output, dot_fusion)); } XLA_VLOG_LINES(5, computation->ToString()); return OkStatus(); @@ -883,9 +1003,6 @@ Status MakeDotComputationSplitKBatch( DotDimensionNumbers new_dim_numbers; const int64_t lhs_contracting_idx = ContractingDimensionIndex(*dot, 0); - TF_ASSIGN_OR_RETURN( - HloInstruction * lhs, - MakeSplitKOperand(*dot, analysis, tiling, lhs_contracting_idx, 0)); CopyIncrementingAboveThreshold( old_dim_numbers.lhs_contracting_dimensions(), *new_dim_numbers.mutable_lhs_contracting_dimensions(), @@ -896,9 +1013,6 @@ Status MakeDotComputationSplitKBatch( *new_dim_numbers.mutable_lhs_batch_dimensions(), lhs_contracting_idx); const int64_t rhs_contracting_idx = ContractingDimensionIndex(*dot, 1); - TF_ASSIGN_OR_RETURN( - HloInstruction * rhs, - MakeSplitKOperand(*dot, analysis, tiling, rhs_contracting_idx, 1)); CopyIncrementingAboveThreshold( old_dim_numbers.rhs_contracting_dimensions(), *new_dim_numbers.mutable_rhs_contracting_dimensions(), @@ -908,16 +1022,67 @@ Status MakeDotComputationSplitKBatch( old_dim_numbers.rhs_batch_dimensions(), *new_dim_numbers.mutable_rhs_batch_dimensions(), rhs_contracting_idx); - HloInstruction* new_dot = - MakeDotHlo(lhs, rhs, new_dim_numbers, dot->precision_config(), - dot->shape().element_type()) - .value(); - // `new_dot` will have default output layout even if `dot` had a custom one. - // We will set the original output layout on the reduce operation. + // Collect HLOs to transform between dot output and root. These will + // get a new major most batch dimension sized as split K factor. Other inputs + // of these HLOs will get broadcasted. + std::stack to_process; + // Store the same HLOs also in a hash set for quick lookups. + absl::flat_hash_set to_process_set; + HloInstruction* current = dot; + do { + to_process.push(current); + CHECK(to_process_set.insert(current).second); + if (current->users().empty()) { + break; + } + CHECK_EQ(current->user_count(), 1); + current = current->users()[0]; + if (!IsDistributiveOverAddition(*current)) { + return Cancelled("Operation non-distributive over addition after dot."); + } + } while (true); + + // Process the collected HLOs from computation root to dot. + while (!to_process.empty()) { + HloInstruction* current = to_process.top(); + to_process.pop(); + // Add split-K dimension to `current`. + HloInstruction* expanded; + if (current == dot) { + TF_ASSIGN_OR_RETURN( + HloInstruction * lhs, + MakeSplitKOperand(*dot, analysis, tiling, lhs_contracting_idx, 0)); + TF_ASSIGN_OR_RETURN( + HloInstruction * rhs, + MakeSplitKOperand(*dot, analysis, tiling, rhs_contracting_idx, 1)); + expanded = MakeDotHlo(lhs, rhs, new_dim_numbers, dot->precision_config(), + dot->shape().element_type()) + .value(); + dot->SetupDerivedInstruction(expanded); + } else { + expanded = computation->AddInstruction( + current->CloneWithNewShape(ShapeUtil::PrependMajorDimension( + tiling.split_k(), current->shape()))); + } + TF_RETURN_IF_ERROR(current->ReplaceAllUsesWithDifferentShape(expanded)); + TF_RETURN_IF_ERROR(computation->RemoveInstruction(current)); + // Broadcast operands. + if (current == dot) { + continue; + } + for (int i = 0; i < expanded->operands().size(); ++i) { + HloInstruction* operand = expanded->mutable_operand(i); + if (!to_process_set.contains(operand)) { + std::vector broadcast_dimensions(operand->shape().rank()); + absl::c_iota(broadcast_dimensions, 1); + TF_RETURN_IF_ERROR(expanded->ReplaceOperandWithDifferentShape( + i, MakeBroadcastHlo(operand, broadcast_dimensions, + ShapeUtil::PrependMajorDimension( + tiling.split_k(), operand->shape())))); + } + } + } - dot->SetupDerivedInstruction(new_dot); - TF_RETURN_IF_ERROR(dot->ReplaceAllUsesWithDifferentShape(new_dot)); - TF_RETURN_IF_ERROR(dot->parent()->RemoveInstruction(dot)); if (disable_reduced_precision_reduction) { PrimitiveType output_type = computation->root_instruction()->shape().element_type(); @@ -1064,7 +1229,8 @@ DotFusionAnalysis::DotFusionAnalysis(const HloComputation* dot_computation, absl::flat_hash_map dim_orders; DimensionOrder dot_operand_dim_order = DimensionOrder::FromDotOperand(*dot, operand_number, split_k); - CHECK(dot_operand_dim_order.HandleInstruction(dot_operand)); + CHECK(dot_operand_dim_order.HandleInstruction( + dot_operand, DimensionOrder::TransformDirection::kOutputToInput)); CHECK(RequireTritonGemmSupportedDimOrder(dot_operand_dim_order)) << dot_computation->ToString(); dim_orders.insert({dot_operand, dot_operand_dim_order}); @@ -1086,7 +1252,8 @@ DotFusionAnalysis::DotFusionAnalysis(const HloComputation* dot_computation, {hlo_operand, DimensionOrder(dim_orders.at(hlo))}); CHECK(inserted); DimensionOrder& hlo_operand_dim_order = it->second; - CHECK(hlo_operand_dim_order.HandleInstruction(hlo_operand)); + CHECK(hlo_operand_dim_order.HandleInstruction( + hlo_operand, DimensionOrder::TransformDirection::kOutputToInput)); CHECK(RequireTritonGemmSupportedDimOrder(hlo_operand_dim_order)) << " " << dot_computation->ToString(); to_process.push(hlo_operand); @@ -1104,8 +1271,17 @@ DotFusionAnalysis::DotFusionAnalysis(const HloComputation* dot_computation, } DimensionOrder dim_order = DimensionOrder::FromDotOutput(*dot); + const HloInstruction* output = dot; + // Currently supported is one fusion output and one path from dot to it. + while (!output->IsRoot()) { + CHECK_EQ(output->user_count(), 1); + output = output->users()[0]; + CHECK(dim_order.HandleInstruction( + output, DimensionOrder::TransformDirection::kInputToOutput)); + CHECK(RequireTritonGemmSupportedDimOrder(dim_order)); + } CHECK(iter_specs_[Scope::OUTPUT] - .insert({dot, DimensionOrderToTensorIterationSpec(dim_order)}) + .insert({output, DimensionOrderToTensorIterationSpec(dim_order)}) .second); } @@ -1180,10 +1356,19 @@ bool ShouldTritonHandleGEMM(const HloInstruction& dot, return true; } + // Data-narrowing conversion after the dot is profitable to fuse. + if (dot.user_count() == 1 && + dot.users()[0]->opcode() == HloOpcode::kConvert && + InputMinusOutputBytes(*dot.users()[0]) > -kIoToleranceBytes) { + return true; + } + // Traverse HLO graph part checking that it both can be fused // and is worth fusing. auto has_triton_fusible_inputs = [&gpu_version](const HloInstruction& dot, const int operand_number) { + absl::flat_hash_map + old_to_new_mapping; DimensionOrder dim_order = DimensionOrder::FromDotOperand(dot, operand_number); std::queue queue; @@ -1191,9 +1376,12 @@ bool ShouldTritonHandleGEMM(const HloInstruction& dot, while (!queue.empty()) { const HloInstruction* current = queue.front(); queue.pop(); - if (!CanFuse(*current, dim_order, gpu_version)) { + if (!CanFuse(*current, /*as_input=*/true, dim_order, old_to_new_mapping, + gpu_version)) { continue; } + // The values in the map are not used by CanFuse(). + old_to_new_mapping.insert({current, nullptr}); // Stop as soon as a profitable operation is fused. if (current->opcode() == HloOpcode::kConvert || current->opcode() == HloOpcode::kTranspose) { @@ -1207,9 +1395,6 @@ bool ShouldTritonHandleGEMM(const HloInstruction& dot, }; return has_triton_fusible_inputs(dot, 0) || has_triton_fusible_inputs(dot, 1); - - // TODO(b/266857789): either check that no output fusion (axpy, relu etc) - // is expected or actually support it. } StatusOr GemmRewriterTriton::Run( diff --git a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton_test.cc b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton_test.cc index 95eaf51915d2e5..53f5e8c163e3d5 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton_test.cc @@ -404,6 +404,41 @@ ENTRY e { /*subfragments=*/ElementsAre(3)))); } +TEST_F(TritonDotAnalysisTest, TransposeOutput) { + const std::string hlo_text = R"( +HloModule t + +triton_dot { + p0 = bf16[24,4]{1,0} parameter(0) + p1 = bf16[4,3]{1,0} parameter(1) + dot = bf16[24,3]{1,0} dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + bc = bf16[12,2,3]{2,1,0} bitcast(dot) + ROOT t = bf16[3,12,2]{2,1,0} transpose(bc), dimensions={2,0,1} +} + +ENTRY e { + p0 = bf16[24,4]{1,0} parameter(0) + p1 = bf16[4,3]{1,0} parameter(1) + ROOT r = bf16[3,12,2]{2,1,0} fusion(p0, p1), kind=kCustom, + calls=triton_dot +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + const HloComputation* dot_computation = + module->entry_computation()->root_instruction()->called_computations()[0]; + const HloInstruction* dot_output = dot_computation->root_instruction(); + const DotFusionAnalysis analysis(dot_computation); + EXPECT_THAT( + *analysis.IterSpec(DotFusionAnalysis::Scope::OUTPUT, dot_output, 0), + ElementsAre(FieldsAre(/*stride=*/1, /*count=*/24, + /*subfragments=*/ElementsAre(2, 12)))); + EXPECT_THAT( + *analysis.IterSpec(DotFusionAnalysis::Scope::OUTPUT, dot_output, 1), + ElementsAre(FieldsAre(/*stride=*/24, /*count=*/3, + /*subfragments=*/ElementsAre(3)))); +} + using SplitKTest = HloTestBase; class SplitKTestWithMorePreciseReduction @@ -454,6 +489,79 @@ ENTRY e { HloOpcode::kReduce); } +TEST_F(SplitKTest, MakeSplitKWithOutputFusion) { + const std::string hlo_text = R"( +HloModule t + +triton_gemm_dot { + p0 = f16[480,128]{1,0} parameter(0) + p1 = f16[16,128]{1,0} parameter(1) + d = f16[480,16]{1,0} dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={1} + c = bf16[] constant(123) + n = bf16[] negate(c) + bc = bf16[480,16]{1,0} broadcast(n) + cv = bf16[480,16]{1,0} convert(d) + ROOT a = bf16[480,16]{1,0} multiply(bc, cv) +} + +ENTRY e { + p0 = f16[480,128]{1,0} parameter(0) + p1 = f16[16,128]{1,0} parameter(1) + ROOT fusion = bf16[480,16]{1,0} fusion(p0, p1), + kind=kCustom, calls=triton_gemm_dot, backend_config="__triton_gemm" +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + AutotuneResult::TritonGemmKey key; + key.set_block_m(16); + key.set_block_n(16); + key.set_block_k(16); + key.set_split_k(4); + key.set_num_stages(1); + key.set_num_warps(4); + TF_EXPECT_OK( + MakeDotSplitKBatch(module->entry_computation()->root_instruction(), key)); + EXPECT_EQ(module->entry_computation()->root_instruction()->opcode(), + HloOpcode::kReduce); +} + +TEST_F(SplitKTest, PreventSplitKWithNonDistributiveOperations) { + const std::string hlo_text = R"( +HloModule t + +triton_gemm_dot { + p0 = f16[480,128]{1,0} parameter(0) + p1 = f16[16,128]{1,0} parameter(1) + d = f16[480,16]{1,0} dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={1} + c = f32[480,16]{1,0} convert(d) + ROOT s = f32[480,16]{1,0} tanh(c) +} + +ENTRY e { + p0 = f16[480,128]{1,0} parameter(0) + p1 = f16[16,128]{1,0} parameter(1) + ROOT fusion = f32[480,16]{1,0} fusion(p0, p1), + kind=kCustom, calls=triton_gemm_dot, backend_config="__triton_gemm" +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + AutotuneResult::TritonGemmKey key; + key.set_block_m(16); + key.set_block_n(16); + key.set_block_k(16); + key.set_split_k(4); + key.set_num_stages(1); + key.set_num_warps(4); + EXPECT_THAT( + MakeDotSplitKBatch(module->entry_computation()->root_instruction(), key), + tsl::testing::StatusIs( + tsl::error::CANCELLED, + absl::StrFormat( + "Operation non-distributive over addition after dot."))); +} + TEST_F(SplitKTest, MakeSplitKWithNonStandardOutputLayout) { const std::string kHloText = R"( HloModule t @@ -570,8 +678,6 @@ ENTRY e { } TEST_F(SplitKTestWithMorePreciseReduction, MakeSplitKWithOutputFusion) { - GTEST_SKIP() << "Output fusion support is temporarily rolled back."; - const std::string hlo_text = R"( HloModule t diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_triton_test.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_triton_test.cc index e87b973b4a60c8..35ed530b9932f0 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_triton_test.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_triton_test.cc @@ -876,6 +876,124 @@ ENTRY e { EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/2e-3, /*arel=*/2e-3})); } +TEST_F(TritonGemmTest, SineOutputIsNotFused) { + const std::string kHloText = R"( +HloModule m + +ENTRY e { + p0 = s8[7,101] parameter(0) + p1 = f32[101,16] parameter(1) + c = f32[7,101] convert(p0) + d = f32[7,16] dot(c, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT r = f32[7,16] sine(d) +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetOptimizedModule(kHloText)); + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Sin( + m::Fusion(m::Parameter(), m::Parameter()) + .WithFusionKind(HloInstruction::FusionKind::kCustom)))); + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-1, /*arel=*/1e-2})); +} + +TEST_F(TritonGemmLevel2Test, NarrowingConvertOutputIsFused) { + const std::string kHloText = R"( +HloModule m + +ENTRY e { + p0 = s8[22,80] parameter(0) + p1 = f32[80,54] parameter(1) + c = f32[22,80] convert(p0) + d = f32[54,22] dot(p1, c), + lhs_contracting_dims={0}, rhs_contracting_dims={1} + ROOT r = f16[54,22] convert(d) +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetOptimizedModule(kHloText)); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch(m::Fusion(m::Parameter(), m::Parameter()) + .WithFusionKind(HloInstruction::FusionKind::kCustom))); + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/3e-2, /*arel=*/3e-2})); +} + +TEST_F(TritonGemmLevel2Test, ParameterAfterDotIsFused) { + if (!GetCudaComputeCapability().IsAtLeast( + se::CudaComputeCapability::AMPERE)) { + GTEST_SKIP() << "No BF16 before Ampere."; + } + + const std::string kHloText = R"( +HloModule m + +ENTRY e { + p0 = bf16[350,1280]{1,0} parameter(0) + p1 = s16[1280,690]{0,1} parameter(1) + p1c = bf16[1280,690]{0,1} convert(p1) + dot.21 = bf16[350,690]{1,0} dot(p0, p1c), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + p2 = bf16[350,690]{1,0} parameter(2) + ROOT r = bf16[350,690]{1,0} multiply(p2, dot.21) +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetOptimizedModule(kHloText)); + const HloInstruction* instr = module->entry_computation()->root_instruction(); + if (!instr->IsCustomFusion()) { + instr = instr->operand(0); + ASSERT_TRUE(instr->IsCustomFusion()); + } + EXPECT_THAT( + instr, + GmockMatch(m::Fusion(m::Parameter(), m::Parameter(), m::Parameter()) + .WithFusionKind(HloInstruction::FusionKind::kCustom))); + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/2e-2, /*arel=*/2e-2})); +} + +TEST_F(TritonGemmLevel2Test, OutputFusionExecutesCorrectly) { + if (!GetCudaComputeCapability().IsAtLeast( + se::CudaComputeCapability::AMPERE)) { + GTEST_SKIP() << "No BF16 before Ampere."; + } + + const std::string kHloText = R"( +HloModule m + +ENTRY e { + p0 = f16[350,1280]{1,0} parameter(0) + p0c = bf16[350,1280]{1,0} convert(p0) + p1 = bf16[1280,690]{0,1} parameter(1) + d = bf16[350,690]{1,0} dot(p0c, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + p3 = bf16[350,690]{1,0} parameter(3) + multiply.8811 = bf16[350,690]{1,0} multiply(d, p3) + neg.484 = bf16[350,690]{1,0} negate(multiply.8811) + p2 = bf16[350,690]{1,0} parameter(2) + ROOT multiply.8808 = bf16[350,690]{1,0} multiply(neg.484, p2) +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetOptimizedModule(kHloText)); + const HloInstruction* instr = module->entry_computation()->root_instruction(); + if (!instr->IsCustomFusion()) { + instr = instr->operand(0); + ASSERT_TRUE(instr->IsCustomFusion()); + } + EXPECT_THAT( + instr, + GmockMatch(m::Fusion(m::Parameter(), m::Parameter(), m::Parameter(), + m::Parameter()) + .WithFusionKind(HloInstruction::FusionKind::kCustom))); + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/2e-2, /*arel=*/2e-2})); +} + TEST_F(TritonGemmTest, Naming) { const char* hlo_text = R"( HloModule t diff --git a/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc index 2da870489e4d70..24be7fdbccd2b0 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc @@ -1267,6 +1267,7 @@ class LegacyCublasGemmRewriteTest : public GemmRewriteTest { public: DebugOptions GetDebugOptionsForTest() override { DebugOptions debug_options = GemmRewriteTest::GetDebugOptionsForTest(); + debug_options.set_xla_gpu_enable_triton_gemm(false); debug_options.set_xla_gpu_enable_cublaslt(false); return debug_options; } diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 250c5d7bd32e5a..66b3b497139905 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -515,6 +515,22 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( TF_DCHECK_OK(ValidateShape(*shape)); } +// Prepend new major-most dimension sized `bound` to the shape. +Shape ShapeUtil::PrependMajorDimension(int64_t bound, Shape shape) { + Shape new_shape(shape.element_type(), {}, {}, {}); + new_shape.add_dimensions(bound); + for (const int64_t dim : shape.dimensions()) { + new_shape.add_dimensions(dim); + } + if (shape.has_layout()) { + for (const int64_t dim : shape.layout().minor_to_major()) { + new_shape.mutable_layout()->add_minor_to_major(dim + 1); + } + new_shape.mutable_layout()->add_minor_to_major(0); + } + return new_shape; +} + /* static */ void ShapeUtil::AppendMinorDimension(int bound, Shape* shape) { CHECK(LayoutUtil::IsDenseArray(*shape)); shape->add_dimensions(bound); diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index adc93dc5408b1c..82345566e79f8e 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -321,6 +321,9 @@ class ShapeUtil { // Appends a major dimension to the shape with the given bound. static void AppendMajorDimension(int bound, Shape* shape); + // Prepends a major dimension sized `bound` to the shape. + static Shape PrependMajorDimension(int64_t bound, Shape shape); + // Appends a minor dimension to the shape with the given bound. static void AppendMinorDimension(int bound, Shape* shape); diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index 84e4cee1c13a7c..a079c93af35771 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -973,6 +973,22 @@ TEST(ShapeUtilTest, PermuteDynamicDimensions) { } while (std::next_permutation(permutation.begin(), permutation.end())); } +TEST(ShapeUtilTest, PrependMajorDimension) { + Shape shape = ShapeUtil::MakeShape(F32, {10, 20, 30}); + EXPECT_EQ(ShapeUtil::PrependMajorDimension(40, shape), + ShapeUtil::MakeShape(F32, {40, 10, 20, 30})); + + shape = ShapeUtil::MakeShapeWithDenseLayout(F32, {10, 20, 30}, {0, 2, 1}); + EXPECT_EQ( + ShapeUtil::PrependMajorDimension(40, shape), + ShapeUtil::MakeShapeWithDenseLayout(F32, {40, 10, 20, 30}, {1, 3, 2, 0})); + + shape = ShapeUtil::MakeShapeWithDenseLayout(F32, {10, 20, 30}, {2, 1, 0}); + EXPECT_EQ( + ShapeUtil::PrependMajorDimension(40, shape), + ShapeUtil::MakeShapeWithDenseLayout(F32, {40, 10, 20, 30}, {3, 2, 1, 0})); +} + TEST(ShapeUtilTest, AppendMinorDimension) { Shape shape = ShapeUtil::MakeShape(F32, {10, 20, 30}); ShapeUtil::AppendMinorDimension(40, &shape); From 4e5c09c944b43ff6297d0385943285d11eed773c Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Thu, 13 Jul 2023 06:15:23 -0700 Subject: [PATCH 251/376] Avoid using invalidated iterator after erase. We need to be careful while erasing an element from a flat_hash_map during iteration. Erase invalidates the existing iterator, so erase(it++) is invalid. PiperOrigin-RevId: 547776041 --- tensorflow/compiler/xla/service/xla_debug_info_manager.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/xla/service/xla_debug_info_manager.cc b/tensorflow/compiler/xla/service/xla_debug_info_manager.cc index e59922a9a569f1..33411199568f60 100644 --- a/tensorflow/compiler/xla/service/xla_debug_info_manager.cc +++ b/tensorflow/compiler/xla/service/xla_debug_info_manager.cc @@ -69,12 +69,12 @@ void XlaDebugInfoManager::StopTracing( modules_to_serialize.reserve(modules_.size()); for (auto it = modules_.begin(); it != modules_.end();) { auto& m = it->second; + auto cur_it = it++; if (!m.active) { modules_to_serialize.emplace_back(std::move(m)); - modules_.erase(it++); + modules_.erase(cur_it); } else { modules_to_serialize.emplace_back(m); - ++it; } } } From 83bbee775f381a481cbfad4e910ce7fbafa93df1 Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Thu, 13 Jul 2023 06:37:22 -0700 Subject: [PATCH 252/376] [XLA:GPU] Disable cudnn runtime fusion. We have now found cases where cudnn runtime fusion fails for both kRelu6 and kLeakyRelu. We haven't found cases for kElu, but I suspect this is just because we don't test kElu much at Google. The failing testcases are described in debug_options_flags.cc. PiperOrigin-RevId: 547781039 --- .../compiler/xla/debug_options_flags.cc | 25 ++++++++++++++++- .../service/gpu/cudnn_fused_conv_rewriter.cc | 12 +++------ .../gpu/cudnn_fused_conv_rewriter_test.cc | 27 ++++++++++++++----- 3 files changed, 48 insertions(+), 16 deletions(-) diff --git a/tensorflow/compiler/xla/debug_options_flags.cc b/tensorflow/compiler/xla/debug_options_flags.cc index 5b8c7ce5d7ce04..eef4cc1ac0a5f0 100644 --- a/tensorflow/compiler/xla/debug_options_flags.cc +++ b/tensorflow/compiler/xla/debug_options_flags.cc @@ -42,7 +42,30 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_cpu_multi_thread_eigen(true); opts.set_xla_gpu_cuda_data_dir("./cuda_sdk_lib"); opts.set_xla_gpu_asm_extra_flags(""); - opts.set_xla_gpu_use_runtime_fusion(true); + + // As of cudnn 8.9.0, enabling cudnn runtime fusion sometimes causes a + // situation where cudnn returns 0 algorithms for an otherwise-valid conv, + // causing compilation to fail. Examples of failing convs: + // + // // failing kLeakyRelu, b/290967578 + // (f16[2,256,768,16]{3,2,1,0}, u8[0]{0}) + // custom-call(f16[2,256,768,3]{3,2,1,0} %a, f16[16,3,3,3]{3,2,1,0} %b, + // f16[16]{0} %c), window={size=3x3 pad=1_1x1_1}, + // dim_labels=b01f_o01i->b01f, operand_precision={highest,highest}, + // custom_call_target="__cudnn$convBiasActivationForward", + // backend_config={"activation_mode":"kLeakyRelu","conv_result_scale":1, + // "side_input_scale":0,"leakyrelu_alpha":0.199951171875} + // + // // failing kRelu6, b/291011396 + // (f16[1,384,1024,32]{3,2,1,0}, u8[0]{0}) + // custom-call(f16[1,769,2049,3]{3,2,1,0} %a, f16[32,3,3,3]{3,2,1,0} %b, + // f16[32]{0} %c), window={size=3x3 stride=2x2}, dim_labels=b01f_o01i->b01f, + // operand_precision={highest,highest}, + // custom_call_target="__cudnn$convBiasActivationForward", + // backend_config={"activation_mode":"kRelu6","conv_result_scale":1, + // "side_input_scale":0,"leakyrelu_alpha":0} + opts.set_xla_gpu_use_runtime_fusion(false); + opts.set_xla_eliminate_hlo_implicit_broadcast(true); opts.set_xla_dump_hlo_as_html(false); opts.set_xla_dump_fusion_visualization(false); diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc index 220b53c46011a0..8bf17cca183d0f 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc @@ -71,6 +71,9 @@ bool IsNonDepthwiseConvCustomCall(const HloInstruction* instr) { // // nvidia currently recommends that we enable this only on Ampere+, but we've // tested on Turing (sm75) and it seems to work fine. +// +// Note that as of writing, xla_gpu_use_runtime_fusion is disabled by default +// due to apparent bugs in cudnn 8.9.0. See debug_options_flags.cc for details. bool ShouldUseCudnnRuntimeFusion(const DebugOptions& debug_opts, se::CudaComputeCapability cc) { return debug_opts.xla_gpu_use_runtime_fusion() && cc.IsAtLeast(7, 5); @@ -674,15 +677,6 @@ StatusOr FuseRelu6(HloComputation* comp, se::CudaComputeCapability cc) { StatusOr FuseLeakyRelu(HloComputation* comp, se::CudaComputeCapability cc) { - // TODO(jlebar): Disabled due to bugs in cudnn 8.9.0. In particular, the - // following convolution gets 0 algorithms available, so it fails to run. - // - // (f16[2,256,768,16]{3,2,1,0}, u8[0]{0}) - // custom-call(f16[2,256,768,3]{3,2,1,0} %a, f16[16,3,3,3]{3,2,1,0} %b, - // f16[16]{0} %c), window={size=3x3 pad=1_1x1_1}, - // dim_labels=b01f_o01i->b01f, operand_precision={highest,highest} - return false; - if (!ShouldUseCudnnRuntimeFusion(comp->parent()->config().debug_options(), cc)) { return false; diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc index 0701cc5e19e0a1..0ac90d1ca480aa 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc @@ -80,6 +80,7 @@ class CudnnFusedConvRewriterTest : public GpuCodegenTest { HloModuleConfig config = GetModuleConfigForTest(); DebugOptions debug_opts = config.debug_options(); debug_opts.add_xla_disable_hlo_passes("cudnn_vectorize_convolutions"); + debug_opts.set_xla_gpu_use_runtime_fusion(true); config.set_debug_options(debug_opts); auto result = backend().compiler()->RunHloPasses( @@ -285,9 +286,7 @@ TEST_F(CudnnFusedConvRewriterTest, TestRelu6) { })"); } -// TODO(jlebar): leaky-relu fusion is disabled because some convolutions have 0 -// algorithm choices. See the cc file. -TEST_F(CudnnFusedConvRewriterTest, DISABLED_TestLeakyRelu) { +TEST_F(CudnnFusedConvRewriterTest, TestLeakyRelu) { if (!GetCudaComputeCapability().IsAtLeast( se::CudaComputeCapability::AMPERE)) { GTEST_SKIP() @@ -994,6 +993,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseElu) { ROOT elu = select(cmp, sum, expm1) })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); + DebugOptions debug_opts = m->config().debug_options(); + debug_opts.set_xla_gpu_use_runtime_fusion(true); + m->config().set_debug_options(debug_opts); GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); @@ -1038,6 +1040,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseEluIfMultipleUses) { ROOT root = tuple(elu, not_elu) })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); + DebugOptions debug_opts = m->config().debug_options(); + debug_opts.set_xla_gpu_use_runtime_fusion(true); + m->config().set_debug_options(debug_opts); GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); @@ -1086,6 +1091,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseRelu6) { ROOT relu = clamp(zeros, sum, sixes) })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); + DebugOptions debug_opts = m->config().debug_options(); + debug_opts.set_xla_gpu_use_runtime_fusion(true); + m->config().set_debug_options(debug_opts); GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); @@ -1125,6 +1133,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseRelu6IfMultipleUses) { ROOT root = tuple(relu, not_relu) })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); + DebugOptions debug_opts = m->config().debug_options(); + debug_opts.set_xla_gpu_use_runtime_fusion(true); + m->config().set_debug_options(debug_opts); GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); @@ -1150,9 +1161,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseRelu6IfMultipleUses) { EXPECT_EQ(config.activation_mode(), se::dnn::kNone); } -// TODO(jlebar): leaky-relu fusion is disabled because some convolutions have 0 -// algorithm choices. See the cc file. -TEST_F(CudnnFusedConvRewriterHloTest, DISABLED_FuseLeakyRelu) { +TEST_F(CudnnFusedConvRewriterHloTest, FuseLeakyRelu) { const std::string module_str = R"( HloModule Test ENTRY Test { @@ -1170,6 +1179,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, DISABLED_FuseLeakyRelu) { ROOT leaky_relu = select(cmp, sum, mul) })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); + DebugOptions debug_opts = m->config().debug_options(); + debug_opts.set_xla_gpu_use_runtime_fusion(true); + m->config().set_debug_options(debug_opts); GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); @@ -1212,6 +1224,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseLeakyReluIfMultipleUses) { ROOT root = tuple(leaky_relu, not_leaky_relu) })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); + DebugOptions debug_opts = m->config().debug_options(); + debug_opts.set_xla_gpu_use_runtime_fusion(true); + m->config().set_debug_options(debug_opts); GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); From 25c87dbd523d10210accb33015ead7fd48b4406d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 13 Jul 2023 07:23:44 -0700 Subject: [PATCH 253/376] Fix typo: `s/the the/the /`. PiperOrigin-RevId: 547792636 --- tensorflow/python/framework/tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/framework/tensor.py b/tensorflow/python/framework/tensor.py index 1a5cb7df766fc0..61dbae3417036c 100644 --- a/tensorflow/python/framework/tensor.py +++ b/tensorflow/python/framework/tensor.py @@ -925,7 +925,7 @@ class TensorSpec(DenseSpec, type_spec.BatchableTypeSpec, >>> tf.TensorSpec.from_tensor(t) TensorSpec(shape=(2, 3), dtype=tf.int32, name=None) - Contains metadata for describing the the nature of `tf.Tensor` objects + Contains metadata for describing the nature of `tf.Tensor` objects accepted or returned by some TensorFlow APIs. For example, it can be used to constrain the type of inputs accepted by From cf4fb63fcebf4536583480706b8d675b8c19ddc0 Mon Sep 17 00:00:00 2001 From: Ashish Shenoy Date: Thu, 13 Jul 2023 07:24:10 -0700 Subject: [PATCH 254/376] Remove unused autograph converter in `tensorflow/python/autograph/converters/control_flow_deprecated_py2.py`. PiperOrigin-RevId: 547792722 --- tensorflow/python/autograph/converters/BUILD | 22 - .../converters/control_flow_deprecated_py2.py | 635 ------------------ 2 files changed, 657 deletions(-) delete mode 100644 tensorflow/python/autograph/converters/control_flow_deprecated_py2.py diff --git a/tensorflow/python/autograph/converters/BUILD b/tensorflow/python/autograph/converters/BUILD index 0b1dcdd0ca3900..d45a9a330ea7a1 100644 --- a/tensorflow/python/autograph/converters/BUILD +++ b/tensorflow/python/autograph/converters/BUILD @@ -84,28 +84,6 @@ py_strict_library( ], ) -py_strict_library( - name = "control_flow_deprecated_py2", - srcs = ["control_flow_deprecated_py2.py"], - visibility = ["//tensorflow:__subpackages__"], - deps = [ - "//tensorflow/python/autograph/core:converter", - "//tensorflow/python/autograph/lang:directives", - "//tensorflow/python/autograph/pyct:anno", - "//tensorflow/python/autograph/pyct:ast_util", - "//tensorflow/python/autograph/pyct:cfg", - "//tensorflow/python/autograph/pyct:parser", - "//tensorflow/python/autograph/pyct:qual_names", - "//tensorflow/python/autograph/pyct:templates", - "//tensorflow/python/autograph/pyct/static_analysis:activity", - "//tensorflow/python/autograph/pyct/static_analysis:annos", - "//tensorflow/python/autograph/pyct/static_analysis:liveness", - "//tensorflow/python/autograph/pyct/static_analysis:reaching_definitions", - "//tensorflow/python/autograph/pyct/static_analysis:reaching_fndefs", - "@gast_archive//:gast", - ], -) - py_strict_library( name = "directives", srcs = ["directives.py"], diff --git a/tensorflow/python/autograph/converters/control_flow_deprecated_py2.py b/tensorflow/python/autograph/converters/control_flow_deprecated_py2.py deleted file mode 100644 index 6fa1deee76b61f..00000000000000 --- a/tensorflow/python/autograph/converters/control_flow_deprecated_py2.py +++ /dev/null @@ -1,635 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Handles control flow statements: while, for, if. - -Python 2 compatibility version. Not maintained. -""" - -import gast - -from tensorflow.python.autograph.core import converter -from tensorflow.python.autograph.lang import directives -from tensorflow.python.autograph.pyct import anno -from tensorflow.python.autograph.pyct import ast_util -from tensorflow.python.autograph.pyct import cfg -from tensorflow.python.autograph.pyct import parser -from tensorflow.python.autograph.pyct import qual_names -from tensorflow.python.autograph.pyct import templates -from tensorflow.python.autograph.pyct.static_analysis import activity -from tensorflow.python.autograph.pyct.static_analysis import annos -from tensorflow.python.autograph.pyct.static_analysis import liveness -from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions -from tensorflow.python.autograph.pyct.static_analysis import reaching_fndefs - - -# TODO(mdan): Refactor functions to make them smaller. - - -class ControlFlowTransformer(converter.Base): - """Transforms control flow structures like loops an conditionals.""" - - def _create_cond_branch(self, body_name, aliased_orig_names, - aliased_new_names, body, returns): - if len(returns) == 1: - template = """ - return retval - """ - return_stmt = templates.replace(template, retval=returns[0]) - else: - template = """ - return (retvals,) - """ - return_stmt = templates.replace(template, retvals=returns) - - if aliased_orig_names: - alias_declarations = [] - for new_name, old_name in zip(aliased_new_names, aliased_orig_names): - template = """ - try: - aliased_new_name = aliased_orig_name - except NameError: - aliased_new_name = ag__.Undefined(symbol_name) - """ - - alias_declarations.extend( - templates.replace( - template, - aliased_new_name=new_name, - aliased_orig_name=old_name, - symbol_name=gast.Constant(str(old_name), kind=None))) - - template = """ - def body_name(): - alias_declarations - body - return_stmt - """ - return templates.replace( - template, - alias_declarations=alias_declarations, - body_name=body_name, - body=body, - return_stmt=return_stmt) - else: - template = """ - def body_name(): - body - return_stmt - """ - return templates.replace( - template, body_name=body_name, body=body, return_stmt=return_stmt) - - def _create_cond_expr(self, results, test, body_name, orelse_name, - state_getter_name, state_setter_name, - basic_symbol_names, composite_symbol_names): - if results is not None: - template = """ - results = ag__.if_stmt(test, body_name, orelse_name, - state_getter_name, state_setter_name, - (basic_symbol_names,), - (composite_symbol_names,)) - """ - return templates.replace( - template, - test=test, - results=results, - body_name=body_name, - orelse_name=orelse_name, - state_getter_name=state_getter_name, - state_setter_name=state_setter_name, - basic_symbol_names=basic_symbol_names, - composite_symbol_names=composite_symbol_names) - else: - template = """ - ag__.if_stmt(test, body_name, orelse_name, getter_name, setter_name, - (basic_symbol_names,), (composite_symbol_names,)) - """ - return templates.replace( - template, - test=test, - body_name=body_name, - orelse_name=orelse_name, - getter_name=state_getter_name, - setter_name=state_setter_name, - basic_symbol_names=basic_symbol_names, - composite_symbol_names=composite_symbol_names) - - def _fmt_symbols(self, symbol_set): - if not symbol_set: - return 'no variables' - return ', '.join(map(str, symbol_set)) - - def _determine_aliased_symbols(self, scope, node_defined_in): - modified_live = scope.modified & node_defined_in - # Composite symbols are handled elsewhere see _create_state_functions - return {s for s in modified_live if not s.is_composite()} - - def _create_state_functions(self, composites, state_getter_name, - state_setter_name): - - if composites: - composite_tuple = tuple(composites) - - template = """ - def state_getter_name(): - return composite_tuple, - def state_setter_name(vals): - composite_tuple, = vals - """ - node = templates.replace( - template, - state_getter_name=state_getter_name, - state_setter_name=state_setter_name, - composite_tuple=composite_tuple) - else: - template = """ - def state_getter_name(): - return () - def state_setter_name(_): - pass - """ - node = templates.replace( - template, - state_getter_name=state_getter_name, - state_setter_name=state_setter_name) - - return node - - def _create_loop_options(self, node): - if not anno.hasanno(node, anno.Basic.DIRECTIVES): - return gast.Dict([], []) - - loop_directives = anno.getanno(node, anno.Basic.DIRECTIVES) - if directives.set_loop_options not in loop_directives: - return gast.Dict([], []) - - opts_dict = loop_directives[directives.set_loop_options] - str_keys, values = zip(*opts_dict.items()) - keys = [gast.Constant(s, kind=None) for s in str_keys] - values = list(values) # ast and gast don't play well with tuples. - return gast.Dict(keys, values) - - def _create_undefined_assigns(self, undefined_symbols): - assignments = [] - for s in undefined_symbols: - template = ''' - var = ag__.Undefined(symbol_name) - ''' - assignments += templates.replace( - template, - var=s, - symbol_name=gast.Constant(s.ssf(), kind=None)) - return assignments - - def visit_If(self, node): - body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE) - orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE) - defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN) - live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT) - - # Note: this information needs to be extracted before the body conversion - # that happens in the call to generic_visit below, because the conversion - # generates nodes that lack static analysis annotations. - need_alias_in_body = self._determine_aliased_symbols( - body_scope, defined_in) - need_alias_in_orelse = self._determine_aliased_symbols( - orelse_scope, defined_in) - - node = self.generic_visit(node) - - modified_in_cond = body_scope.modified | orelse_scope.modified - returned_from_cond = set() - composites = set() - for s in modified_in_cond: - if s in live_out and not s.is_composite(): - returned_from_cond.add(s) - if s.is_composite(): - # Special treatment for compound objects, always return them. - # This allows special handling within the if_stmt itself. - # For example, in TensorFlow we need to restore the state of composite - # symbols to ensure that only effects from the executed branch are seen. - composites.add(s) - - created_in_body = body_scope.modified & returned_from_cond - defined_in - created_in_orelse = orelse_scope.modified & returned_from_cond - defined_in - - basic_created_in_body = tuple( - s for s in created_in_body if not s.is_composite()) - basic_created_in_orelse = tuple( - s for s in created_in_orelse if not s.is_composite()) - - # These variables are defined only in a single branch. This is fine in - # Python so we pass them through. Another backend, e.g. Tensorflow, may need - # to handle these cases specially or throw an Error. - possibly_undefined = (set(basic_created_in_body) ^ - set(basic_created_in_orelse)) - - # Alias the closure variables inside the conditional functions, to allow - # the functions access to the respective variables. - # We will alias variables independently for body and orelse scope, - # because different branches might write different variables. - aliased_body_orig_names = tuple(need_alias_in_body) - aliased_orelse_orig_names = tuple(need_alias_in_orelse) - aliased_body_new_names = tuple( - self.ctx.namer.new_symbol(s.ssf(), body_scope.referenced) - for s in aliased_body_orig_names) - aliased_orelse_new_names = tuple( - self.ctx.namer.new_symbol(s.ssf(), orelse_scope.referenced) - for s in aliased_orelse_orig_names) - - alias_body_map = dict(zip(aliased_body_orig_names, aliased_body_new_names)) - alias_orelse_map = dict( - zip(aliased_orelse_orig_names, aliased_orelse_new_names)) - - node_body = ast_util.rename_symbols(node.body, alias_body_map) - node_orelse = ast_util.rename_symbols(node.orelse, alias_orelse_map) - - cond_var_name = self.ctx.namer.new_symbol('cond', body_scope.referenced) - body_name = self.ctx.namer.new_symbol('if_true', body_scope.referenced) - orelse_name = self.ctx.namer.new_symbol('if_false', orelse_scope.referenced) - all_referenced = body_scope.referenced | orelse_scope.referenced - state_getter_name = self.ctx.namer.new_symbol('get_state', all_referenced) - state_setter_name = self.ctx.namer.new_symbol('set_state', all_referenced) - - returned_from_cond = tuple(returned_from_cond) - composites = tuple(composites) - - if returned_from_cond: - if len(returned_from_cond) == 1: - cond_results = returned_from_cond[0] - else: - cond_results = gast.Tuple([s.ast() for s in returned_from_cond], None) - - returned_from_body = tuple( - alias_body_map[s] if s in need_alias_in_body else s - for s in returned_from_cond) - returned_from_orelse = tuple( - alias_orelse_map[s] if s in need_alias_in_orelse else s - for s in returned_from_cond) - - else: - # When the cond would return no value, we leave the cond called without - # results. That in turn should trigger the side effect guards. The - # branch functions will return a dummy value that ensures cond - # actually has some return value as well. - cond_results = None - # TODO(mdan): Replace with None once side_effect_guards is retired. - returned_from_body = (templates.replace_as_expression( - 'ag__.match_staging_level(1, cond_var_name)', - cond_var_name=cond_var_name),) - returned_from_orelse = (templates.replace_as_expression( - 'ag__.match_staging_level(1, cond_var_name)', - cond_var_name=cond_var_name),) - - cond_assign = self.create_assignment(cond_var_name, node.test) - body_def = self._create_cond_branch( - body_name, - aliased_orig_names=aliased_body_orig_names, - aliased_new_names=aliased_body_new_names, - body=node_body, - returns=returned_from_body) - orelse_def = self._create_cond_branch( - orelse_name, - aliased_orig_names=aliased_orelse_orig_names, - aliased_new_names=aliased_orelse_new_names, - body=node_orelse, - returns=returned_from_orelse) - undefined_assigns = self._create_undefined_assigns(possibly_undefined) - composite_defs = self._create_state_functions( - composites, state_getter_name, state_setter_name) - - basic_symbol_names = tuple( - gast.Constant(str(symbol), kind=None) for symbol in returned_from_cond) - composite_symbol_names = tuple( - gast.Constant(str(symbol), kind=None) for symbol in composites) - - cond_expr = self._create_cond_expr(cond_results, cond_var_name, body_name, - orelse_name, state_getter_name, - state_setter_name, basic_symbol_names, - composite_symbol_names) - - if_ast = ( - undefined_assigns + composite_defs + body_def + orelse_def + - cond_assign + cond_expr) - return if_ast - - def _get_basic_loop_vars(self, modified_symbols, live_in, live_out): - # The loop variables corresponding to simple symbols (e.g. `x`). - basic_loop_vars = [] - for s in modified_symbols: - if s.is_composite(): - # TODO(mdan): Raise an error when this happens for a TF loop. - continue - # Variables not live into or out of the loop are considered local to the - # loop. - if s not in live_in and s not in live_out: - continue - basic_loop_vars.append(s) - return frozenset(basic_loop_vars) - - def _get_composite_loop_vars(self, modified_symbols, live_in): - # The loop variables corresponding to composite symbols (e.g. `self.x`). - composite_loop_vars = [] - for s in modified_symbols: - if not s.is_composite(): - continue - # Mutations made to objects created inside the loop will appear as writes - # to composite symbols. Because these mutations appear as modifications - # made to composite symbols, we check whether the composite's parent is - # actually live into the loop. - # Example: - # while cond: - # x = Foo() - # x.foo = 2 * x.foo # x.foo is live into the loop, but x is not. - # - # Note that some parents might not be symbols - for example, in x['foo'], - # 'foo' is a parent, but it's a literal, not a symbol. We don't check the - # liveness of literals. - support_set_symbols = tuple( - sss for sss in s.support_set if sss.is_symbol()) - if not all(sss in live_in for sss in support_set_symbols): - continue - composite_loop_vars.append(s) - return frozenset(composite_loop_vars) - - def _get_loop_vars(self, node, modified_symbols): - body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE) - defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN) - live_in = anno.getanno(node, anno.Static.LIVE_VARS_IN) - live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT) - reserved_symbols = body_scope.referenced - - basic_loop_vars = self._get_basic_loop_vars( - modified_symbols, live_in, live_out) - composite_loop_vars = self._get_composite_loop_vars( - modified_symbols, live_in) - - # Variable that are used or defined inside the loop, but not defined - # before entering the loop. Only simple variables must be defined. The - # composite ones will be implicitly checked at runtime. - undefined_lives = basic_loop_vars - defined_in - - return (basic_loop_vars, composite_loop_vars, reserved_symbols, - undefined_lives) - - def _loop_var_constructs(self, basic_loop_vars): - loop_vars = tuple(basic_loop_vars) - loop_vars_ast_tuple = gast.Tuple([n.ast() for n in loop_vars], None) - - if len(loop_vars) == 1: - loop_vars = loop_vars[0] - - return loop_vars, loop_vars_ast_tuple - - def visit_While(self, node): - node = self.generic_visit(node) - - (basic_loop_vars, composite_loop_vars, reserved_symbols, - possibly_undefs) = self._get_loop_vars( - node, - anno.getanno(node, annos.NodeAnno.BODY_SCOPE).modified) - loop_vars, loop_vars_ast_tuple = self._loop_var_constructs( - basic_loop_vars) - - state_getter_name = self.ctx.namer.new_symbol('get_state', reserved_symbols) - state_setter_name = self.ctx.namer.new_symbol('set_state', reserved_symbols) - state_functions = self._create_state_functions( - composite_loop_vars, state_getter_name, state_setter_name) - - basic_symbol_names = tuple( - gast.Constant(str(symbol), kind=None) for symbol in basic_loop_vars) - composite_symbol_names = tuple( - gast.Constant(str(symbol), kind=None) for symbol in composite_loop_vars) - - opts = self._create_loop_options(node) - - # TODO(mdan): Use a single template. - # If the body and test functions took a single tuple for loop_vars, instead - # of *loop_vars, then a single template could be used. - if loop_vars: - template = """ - state_functions - def body_name(loop_vars): - body - return loop_vars, - def test_name(loop_vars): - return test - loop_vars_ast_tuple = ag__.while_stmt( - test_name, - body_name, - state_getter_name, - state_setter_name, - (loop_vars,), - (basic_symbol_names,), - (composite_symbol_names,), - opts) - """ - node = templates.replace( - template, - loop_vars=loop_vars, - loop_vars_ast_tuple=loop_vars_ast_tuple, - test_name=self.ctx.namer.new_symbol('loop_test', reserved_symbols), - test=node.test, - body_name=self.ctx.namer.new_symbol('loop_body', reserved_symbols), - body=node.body, - state_functions=state_functions, - state_getter_name=state_getter_name, - state_setter_name=state_setter_name, - basic_symbol_names=basic_symbol_names, - composite_symbol_names=composite_symbol_names, - opts=opts) - else: - template = """ - state_functions - def body_name(): - body - return () - def test_name(): - return test - ag__.while_stmt( - test_name, - body_name, - state_getter_name, - state_setter_name, - (), - (), - (composite_symbol_names,), - opts) - """ - node = templates.replace( - template, - test_name=self.ctx.namer.new_symbol('loop_test', reserved_symbols), - test=node.test, - body_name=self.ctx.namer.new_symbol('loop_body', reserved_symbols), - body=node.body, - state_functions=state_functions, - state_getter_name=state_getter_name, - state_setter_name=state_setter_name, - composite_symbol_names=composite_symbol_names, - opts=opts) - - undefined_assigns = self._create_undefined_assigns(possibly_undefs) - return undefined_assigns + node - - def visit_For(self, node): - node = self.generic_visit(node) - - (basic_loop_vars, composite_loop_vars, - reserved_symbols, possibly_undefs) = self._get_loop_vars( - node, (anno.getanno(node, annos.NodeAnno.BODY_SCOPE).modified - | anno.getanno(node, annos.NodeAnno.ITERATE_SCOPE).modified)) - loop_vars, loop_vars_ast_tuple = self._loop_var_constructs( - basic_loop_vars) - body_name = self.ctx.namer.new_symbol('loop_body', reserved_symbols) - - state_getter_name = self.ctx.namer.new_symbol('get_state', reserved_symbols) - state_setter_name = self.ctx.namer.new_symbol('set_state', reserved_symbols) - state_functions = self._create_state_functions( - composite_loop_vars, state_getter_name, state_setter_name) - - if anno.hasanno(node, anno.Basic.EXTRA_LOOP_TEST): - extra_test = anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST) - extra_test_name = self.ctx.namer.new_symbol( - 'extra_test', reserved_symbols) - template = """ - def extra_test_name(loop_vars): - return extra_test_expr - """ - extra_test_function = templates.replace( - template, - extra_test_name=extra_test_name, - loop_vars=loop_vars, - extra_test_expr=extra_test) - else: - extra_test_name = parser.parse_expression('None') - extra_test_function = [] - - # Workaround for PEP-3113 - # iterates_var holds a single variable with the iterates, which may be a - # tuple. - iterates_var_name = self.ctx.namer.new_symbol( - 'iterates', reserved_symbols) - template = """ - iterates = iterates_var_name - """ - iterate_expansion = templates.replace( - template, - iterates=node.target, - iterates_var_name=iterates_var_name) - - undefined_assigns = self._create_undefined_assigns(possibly_undefs) - - basic_symbol_names = tuple( - gast.Constant(str(symbol), kind=None) for symbol in basic_loop_vars) - composite_symbol_names = tuple( - gast.Constant(str(symbol), kind=None) for symbol in composite_loop_vars) - - opts = self._create_loop_options(node) - - # TODO(mdan): Use a single template. - # If the body and test functions took a single tuple for loop_vars, instead - # of *loop_vars, then a single template could be used. - if loop_vars: - template = """ - undefined_assigns - state_functions - def body_name(iterates_var_name, loop_vars): - iterate_expansion - body - return loop_vars, - extra_test_function - loop_vars_ast_tuple = ag__.for_stmt( - iter_, - extra_test_name, - body_name, - state_getter_name, - state_setter_name, - (loop_vars,), - (basic_symbol_names,), - (composite_symbol_names,), - opts) - """ - return templates.replace( - template, - undefined_assigns=undefined_assigns, - loop_vars=loop_vars, - loop_vars_ast_tuple=loop_vars_ast_tuple, - iter_=node.iter, - iterate_expansion=iterate_expansion, - iterates_var_name=iterates_var_name, - extra_test_name=extra_test_name, - extra_test_function=extra_test_function, - body_name=body_name, - body=node.body, - state_functions=state_functions, - state_getter_name=state_getter_name, - state_setter_name=state_setter_name, - basic_symbol_names=basic_symbol_names, - composite_symbol_names=composite_symbol_names, - opts=opts) - else: - template = """ - undefined_assigns - state_functions - def body_name(iterates_var_name): - iterate_expansion - body - return () - extra_test_function - ag__.for_stmt( - iter_, - extra_test_name, - body_name, - state_getter_name, - state_setter_name, - (), - (), - (composite_symbol_names,), - opts) - """ - return templates.replace( - template, - undefined_assigns=undefined_assigns, - iter_=node.iter, - iterate_expansion=iterate_expansion, - iterates_var_name=iterates_var_name, - extra_test_name=extra_test_name, - extra_test_function=extra_test_function, - body_name=body_name, - body=node.body, - state_functions=state_functions, - state_getter_name=state_getter_name, - state_setter_name=state_setter_name, - composite_symbol_names=composite_symbol_names, - opts=opts) - - -class AnnotatedDef(reaching_definitions.Definition): - - def __init__(self): - super(AnnotatedDef, self).__init__() - self.directives = {} - - -def transform(node, ctx): - graphs = cfg.build(node) - node = qual_names.resolve(node) - node = activity.resolve(node, ctx, None) - node = reaching_definitions.resolve(node, ctx, graphs) - node = reaching_fndefs.resolve(node, ctx, graphs) - node = liveness.resolve(node, ctx, graphs) - - node = ControlFlowTransformer(ctx).visit(node) - return node From 6635250da33fda610bd557df7afbc88958ed3389 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 13 Jul 2023 07:24:13 -0700 Subject: [PATCH 255/376] Fix typo: `s/the the/the /`. PiperOrigin-RevId: 547792729 --- tensorflow/python/ops/embedding_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py index fc46a849d86f5c..2fbd2643da45d1 100644 --- a/tensorflow/python/ops/embedding_ops.py +++ b/tensorflow/python/ops/embedding_ops.py @@ -576,7 +576,7 @@ def embedding_lookup_sparse_v2( Since row 1 and 2 of `sp_ids` only have one value each, they simply select the corresponding row from `params` as the output row. Row 1 has value `3` so it selects the `params` elements `[7, 8]` and row 2 has the value 2 so it - selects the the `params` elements `[5, 6]`. + selects the `params` elements `[5, 6]`. If `sparse_weights` is specified, it must have the same shape as `sp_ids`. `sparse_weights` is used to assign a weight to each slice of `params`. For From 651876f3356c88c14b9237c5c2fbad9486e11f2e Mon Sep 17 00:00:00 2001 From: Jean-Baptiste Lespiau Date: Thu, 13 Jul 2023 08:10:23 -0700 Subject: [PATCH 256/376] Remove a potential duplicate of `RETURN_IF_ERROR`. PiperOrigin-RevId: 547805235 --- tensorflow/cc/experimental/libexport/load.cc | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/tensorflow/cc/experimental/libexport/load.cc b/tensorflow/cc/experimental/libexport/load.cc index c045dbd4e78058..be9319b066d74d 100644 --- a/tensorflow/cc/experimental/libexport/load.cc +++ b/tensorflow/cc/experimental/libexport/load.cc @@ -23,12 +23,6 @@ limitations under the License. #include "tensorflow/core/protobuf/meta_graph.pb.h" #include "tensorflow/core/util/tensor_bundle/tensor_bundle.h" -#define RETURN_IF_ERROR(s) \ - { \ - auto c = (s); \ - if (!c.ok()) return c; \ - } - namespace tensorflow { namespace libexport { @@ -41,11 +35,11 @@ tensorflow::StatusOr TFPackage::Load(const std::string& path) { const string saved_model_pbtxt_path = io::JoinPath(path, kSavedModelFilenamePbTxt); if (Env::Default()->FileExists(saved_model_pb_path).ok()) { - RETURN_IF_ERROR(ReadBinaryProto(Env::Default(), saved_model_pb_path, - &tf_package.saved_model_proto_)); + TF_RETURN_IF_ERROR(ReadBinaryProto(Env::Default(), saved_model_pb_path, + &tf_package.saved_model_proto_)); } else if (Env::Default()->FileExists(saved_model_pbtxt_path).ok()) { - RETURN_IF_ERROR(ReadTextProto(Env::Default(), saved_model_pbtxt_path, - &tf_package.saved_model_proto_)); + TF_RETURN_IF_ERROR(ReadTextProto(Env::Default(), saved_model_pbtxt_path, + &tf_package.saved_model_proto_)); } else { return Status(absl::StatusCode::kNotFound, "Could not find SavedModel .pb or .pbtxt at supplied export " @@ -65,7 +59,7 @@ tensorflow::StatusOr TFPackage::Load(const std::string& path) { tf_package.variable_reader_ = std::make_unique( tensorflow::Env::Default(), tf_package.variables_filepath_); tensorflow::Tensor object_graph_tensor; - RETURN_IF_ERROR(tf_package.variable_reader_->Lookup( + TF_RETURN_IF_ERROR(tf_package.variable_reader_->Lookup( tensorflow::kObjectGraphProtoKey, &object_graph_tensor)); const auto* object_graph_string = reinterpret_cast( From ade1ff06383ad39aa71a7d2320dfb4b47781bbf0 Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Thu, 13 Jul 2023 09:03:48 -0700 Subject: [PATCH 257/376] [XLA:GPU] Support specific kind of output transposes in Triton GEMM. This is about transposes that split the output dimension originating from the non-contracting dimension of the LHS operand at the same ratio. PiperOrigin-RevId: 547817630 --- .../xla/service/gpu/gemm_rewriter_triton.cc | 98 ++++++++++++++++--- .../xla/service/gpu/ir_emitter_triton.cc | 27 +++-- .../xla/service/gpu/ir_emitter_triton_test.cc | 60 ++++++++++++ 3 files changed, 163 insertions(+), 22 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.cc b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.cc index 1c26b87c8ab9b2..9e0bfe08f4d3de 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.cc @@ -199,9 +199,12 @@ class DimensionOrder { // `hlo` is currently supposed to be an operand of dot(); // dimension indices describing the operand // are stored along with the dimension order for later analysis. - explicit DimensionOrder(const HloInstruction* hlo, - const int64_t splittable_dimension_index = -1) - : splittable_dimension_index_(splittable_dimension_index) { + explicit DimensionOrder( + const HloInstruction* hlo, const int64_t splittable_dimension_index = -1, + const int64_t splittable_dimension_supported_major_size = 0) + : splittable_dimension_index_(splittable_dimension_index), + splittable_dimension_supported_major_part_size_( + splittable_dimension_supported_major_size) { dim_order_.reserve(hlo->shape().rank()); for (const int64_t i : hlo->shape().layout().minor_to_major()) { dim_order_.push_back({i, 0, hlo->shape().dimensions(i)}); @@ -214,7 +217,9 @@ class DimensionOrder { int operand_number, int64_t split_k = 1); // Create dimension order describing dot's output. - static DimensionOrder FromDotOutput(const HloInstruction& dot); + static DimensionOrder FromDotOutput( + const HloInstruction& dot, int64_t split_k = 1, + int64_t splittable_dimension_supported_major_part_size = 0); enum class TransformDirection { kInputToOutput, kOutputToInput }; @@ -258,6 +263,13 @@ class DimensionOrder { return splittable_dimension_index_; } + // Tells whether `size` major part of a dimension can be physically split. + bool IsSupportedSplittableDimensionMajorPartSize(int64_t size) const { + // 0 means no specific size requirement. + return splittable_dimension_supported_major_part_size_ == 0 || + splittable_dimension_supported_major_part_size_ == size; + } + // Tells that two dimension orders describe the same tensor physical layout. bool IsPhysicallyEquivalent(const DimensionOrder& other) const; @@ -276,6 +288,7 @@ class DimensionOrder { DimOrderVector dim_order_; const int64_t splittable_dimension_index_; + const int64_t splittable_dimension_supported_major_part_size_; }; using DimIterationSpec = TensorIterationSpec::DimIterationSpec; @@ -348,8 +361,20 @@ DimensionOrder DimensionOrder::FromDotOperand(const HloInstruction& dot, return DimensionOrder(operand); } -DimensionOrder DimensionOrder::FromDotOutput(const HloInstruction& dot) { - return DimensionOrder(&dot); +DimensionOrder DimensionOrder::FromDotOutput( + const HloInstruction& dot, const int64_t split_k, + const int64_t splittable_dimension_supported_major_part_size) { + // Allow non-contracting dimension originating from LHS to split if + // this dimension is split at the output at the same ratio as + // at the input. + int64_t splittable_dimension_index = -1; + if (splittable_dimension_supported_major_part_size > 1) { + // Split-K dimension is the first one in the output if present; + // LHS non-contracting follows (batch is absent in this case). + splittable_dimension_index = (split_k > 1) ? 1 : 0; + } + return DimensionOrder(&dot, splittable_dimension_index, + splittable_dimension_supported_major_part_size); } FusionDecision DimensionOrder::HandleBitcast(const HloInstruction* hlo, @@ -522,7 +547,8 @@ FusionDecision RequireTritonGemmSupportedDimOrder(const DimensionOrder& order) { } if (i == 0 || dim_order_vector[i - 1].target_dim_number != dim_number) { ++split_counters[dim_number]; - if (dim_number == order.SplittableDimensionIndex()) { + if (dim_number == order.SplittableDimensionIndex() && + order.IsSupportedSplittableDimensionMajorPartSize(size)) { if (split_counters[dim_number] > 1) { return "2nd split of a splittable dimension."; } @@ -813,7 +839,13 @@ class GemmRewriterTritonVisitor : public DfsHloRewriteVisitor { absl::flat_hash_map old_to_new_mapping; - auto fuse_inputs = [&](int operand_number) { + // Separate traversal from LHS and RHS inputs of the dot: they use + // differently shaped tiles but may go through same HLO graph nodes. + // Direct dot inputs have well defined dimension orders. + + auto fuse_inputs = [&](int operand_number) + -> StatusOr< + absl::flat_hash_map> { absl::flat_hash_map dim_orders; int operand_count_before = call_operands.size(); // Direct dot inputs have well defined dimension orders. @@ -821,12 +853,36 @@ class GemmRewriterTritonVisitor : public DfsHloRewriteVisitor { dot->mutable_operand(operand_number), DimensionOrder::FromDotOperand(*dot, operand_number), dim_orders, gpu_version_, old_to_new_mapping, call_operands, builder); - return call_operands.size() - operand_count_before; + TF_RET_CHECK(call_operands.size() - operand_count_before <= + DotFusionAnalysis::kMaxParameterPerScope); + return dim_orders; }; - // Separate traversal from LHS and RHS inputs of the dot: they use - // differently shaped tiles but may go through same HLO graph nodes. - TF_RET_CHECK(fuse_inputs(0) <= DotFusionAnalysis::kMaxParameterPerScope); - TF_RET_CHECK(fuse_inputs(1) <= DotFusionAnalysis::kMaxParameterPerScope); + // Check if non-contracting dimension originating from LHS operand in the + // output can be split. This currently requires this dimension being split + // in the operand the same way. + int64_t lhs_nc_split_major_part = -1; + { + TF_ASSIGN_OR_RETURN(const auto lhs_dim_orders, fuse_inputs(0)); + // Looking at first LHS parameter to find split non-contracting dimension + // is sufficient because currently all parameters of one scope have to use + // the same tiling. + auto first_lhs_parameter_it = lhs_dim_orders.cbegin(); + while (first_lhs_parameter_it != lhs_dim_orders.cend()) { + if (first_lhs_parameter_it->first->opcode() == HloOpcode::kParameter) { + break; + } + ++first_lhs_parameter_it; + } + if (first_lhs_parameter_it != lhs_dim_orders.cend()) { + const auto lhs_nc_iter_spec = DimensionOrderToTensorIterationSpec( + first_lhs_parameter_it + ->second)[NonContractingDimensionIndex(*dot, 0)]; + if (lhs_nc_iter_spec.size() > 1) { + lhs_nc_split_major_part = lhs_nc_iter_spec.at(1).count; + } + } + } + TF_RET_CHECK(fuse_inputs(1).ok()); Fuse(*dot, old_to_new_mapping, call_operands, builder); @@ -834,7 +890,9 @@ class GemmRewriterTritonVisitor : public DfsHloRewriteVisitor { // These describe _outputs_ of corresponding HLOs. absl::flat_hash_map out_dim_orders; - out_dim_orders.insert({dot, DimensionOrder::FromDotOutput(*dot)}); + out_dim_orders.insert( + {dot, DimensionOrder::FromDotOutput(*dot, /*split_k=*/1, + lhs_nc_split_major_part)}); HloInstruction* fusion_output = dot; bool output_changed = true; while (output_changed) { @@ -1270,7 +1328,17 @@ DotFusionAnalysis::DotFusionAnalysis(const HloComputation* dot_computation, } } - DimensionOrder dim_order = DimensionOrder::FromDotOutput(*dot); + int64_t lhs_nc_split_major_part_size = -1; + if (!ScopeParameters(Scope::LHS).empty()) { + const TensorIterationSpec::DimIterationSpec* lhs_nc_iter_spec = + IterSpec(Scope::LHS, *ScopeParameters(Scope::LHS).cbegin(), + NonContractingDimensionIndex(*dot, 0)); + if (lhs_nc_iter_spec->size() > 1) { + lhs_nc_split_major_part_size = lhs_nc_iter_spec->at(1).count; + } + } + DimensionOrder dim_order = DimensionOrder::FromDotOutput( + *dot, split_k, lhs_nc_split_major_part_size); const HloInstruction* output = dot; // Currently supported is one fusion output and one path from dot to it. while (!output->IsRoot()) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_triton.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_triton.cc index 73a3a378f6fa73..6065cf3236dcb2 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_triton.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_triton.cc @@ -905,13 +905,26 @@ StatusOr MatMulImpl( ->at(0) .stride; CHECK_GE(stride_out_batch, 1); - } else if (lhs_nc_split) { - // Dimension of the output produced by the non-contracting LHS one - // is physically contiguous even if the producing LHS one is split. - // Because the major part of the split is implemented using the batch - // logic stride_out_batch is populated here as the stride of the minor - // part times its size. - stride_out_batch = stride_out_m * m; + } + { + const TensorIterationSpec::DimIterationSpec* spec = analysis.IterSpec( + DotFusionAnalysis::Scope::OUTPUT, root, lhs_nc_out_idx); + if (spec->size() > 1) { + CHECK_EQ(spec->size(), 2); + // Support one specific kind of output transpose that splits the dimension + // originating from the split LHS non-contracting one. + CHECK(!have_batch); + CHECK(lhs_nc_split); + CHECK_EQ(spec->at(1).count, batch_size); + stride_out_batch = spec->at(1).stride; + } else if (lhs_nc_split) { + // Dimension of the output produced by the non-contracting LHS one + // is physically contiguous though the producing LHS one is split. + // Because the major part of the split is implemented using the batch + // logic stride_out_batch is populated here as the stride of the minor + // part times its size. + stride_out_batch = stride_out_m * m; + } } const int block_m = config.block_m(); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_triton_test.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_triton_test.cc index 35ed530b9932f0..05f8619c08fce3 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_triton_test.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_triton_test.cc @@ -994,6 +994,66 @@ ENTRY e { EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/2e-2, /*arel=*/2e-2})); } +TEST_F(TritonGemmLevel2Test, SplitLHSOutputTransposeAloneIsNotFused) { + if (!GetCudaComputeCapability().IsAtLeast( + se::CudaComputeCapability::AMPERE)) { + GTEST_SKIP() << "No BF16 before Ampere."; + } + + const std::string kHloText = R"( +HloModule m + +ENTRY e { + p0 = s8[18,15000] parameter(0) + p0c = bf16[18,15000] convert(p0) + p1 = bf16[42,18] parameter(1) + d = bf16[15000,42] dot(p0c, p1), + lhs_contracting_dims={0}, rhs_contracting_dims={1} + r1 = bf16[5,200,15,42] reshape(d) + ROOT t1 = bf16[5,42,200,15] transpose(r1), dimensions={0,3,1,2} +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetOptimizedModule(kHloText)); + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Transpose( + m::Fusion(m::Parameter(), m::Parameter()) + .WithFusionKind(HloInstruction::FusionKind::kCustom)))); + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); +} + +TEST_F(TritonGemmLevel2Test, SplitLHSInputOutputIsFused) { + if (!GetCudaComputeCapability().IsAtLeast( + se::CudaComputeCapability::AMPERE)) { + GTEST_SKIP() << "No BF16 before Ampere."; + } + + const std::string kHloText = R"( +HloModule m + +ENTRY e { + p0 = s8[5,18,20,150] parameter(0) + p0c = bf16[5,18,20,150] convert(p0) + t0 = bf16[18,5,20,150] transpose(p0c), dimensions={1,0,2,3} + r0 = bf16[18,15000] reshape(t0) + p1 = bf16[42,18] parameter(1) + d = bf16[15000,42] dot(r0, p1), + lhs_contracting_dims={0}, rhs_contracting_dims={1} + r1 = bf16[5,20,150,42] reshape(d) + ROOT t1 = bf16[5,42,20,150] transpose(r1), dimensions={0,3,1,2} +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetOptimizedModule(kHloText)); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch(m::Fusion(m::Parameter(), m::Parameter()) + .WithFusionKind(HloInstruction::FusionKind::kCustom))); + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); +} + TEST_F(TritonGemmTest, Naming) { const char* hlo_text = R"( HloModule t From 9fa5e774d4f86c2e41b04a6ac326578bdd516691 Mon Sep 17 00:00:00 2001 From: George Karpenkov Date: Thu, 13 Jul 2023 09:12:25 -0700 Subject: [PATCH 258/376] [XLA:GPU] [NFC] In Triton Gemm autotuning, do not copy the returned buffer: compare with the returned value directly PiperOrigin-RevId: 547819846 --- .../xla/service/gpu/autotuner_compile_util.cc | 25 ++------ .../xla/service/gpu/autotuner_compile_util.h | 12 +++- .../xla/service/gpu/triton_autotuner.cc | 62 ++++++++----------- 3 files changed, 39 insertions(+), 60 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.cc b/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.cc index 7ee4c0b2fba5fd..85859f88d64cb2 100644 --- a/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.cc +++ b/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.cc @@ -126,11 +126,11 @@ AutotunerCompileUtil::AutotunerCompileUtil(const AutotuneConfig& config, opts_.set_xla_gpu_cuda_graph_level(0); } -StatusOr> +StatusOr> AutotunerCompileUtil::GenerateAndProfileExecutable( const AutotuneResult& config, const AutotuneCacheKey& cache_key, se::Stream* stream, absl::Span input_buffers, - ShapedBuffer output_buffer, GenerateModuleFn extractor) { + GenerateModuleFn extractor) { TF_ASSIGN_OR_RETURN(Executable * executable, Compile(config, cache_key, std::move(extractor))); @@ -154,25 +154,8 @@ AutotunerCompileUtil::GenerateAndProfileExecutable( Execute(*executable, std::move(execution_inputs))); TF_ASSIGN_OR_RETURN(absl::Duration timer_duration, timer.GetElapsedDuration()); - ScopedShapedBuffer result = execution_output.ConsumeResult(); - - // TODO(cheshire): Copying should not be required. Instead, we can add a new - // aliased parameter. - Shape shape = result.on_device_shape(); - TF_RET_CHECK(shape == output_buffer.on_device_shape()); - if (shape.IsTuple()) { - for (int64_t i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { - TF_RET_CHECK(!shape.tuple_shapes(i).IsTuple()); - stream->ThenMemcpy(output_buffer.buffers().mutable_element(ShapeIndex{i}), - result.buffer(ShapeIndex{i}), - ShapeUtil::ByteSizeOf(shape.tuple_shapes(i))); - } - } else { - stream->ThenMemcpy(output_buffer.buffers().mutable_element(ShapeIndex{}), - result.buffer(ShapeIndex{}), - ShapeUtil::ByteSizeOf(shape)); - } - return std::make_optional(timer_duration); + return std::make_optional( + timer_duration, execution_output.Commit().ConsumeResult()); } StatusOr AutotunerCompileUtil::Compile( diff --git a/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.h b/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.h index 7de054f6841221..41330dcc1ba4a2 100644 --- a/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.h +++ b/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.h @@ -61,16 +61,24 @@ class AutotunerCompileUtil { static StatusOr> Create( const AutotuneConfig& config, const DebugOptions& opts); + struct ProfilingOutput { + ProfilingOutput(absl::Duration duration, ScopedShapedBuffer&& buffer) + : duration(duration), output(std::move(buffer)) {} + + absl::Duration duration; + ScopedShapedBuffer output; + }; + // Generates an executable first, given the module generator function in // `extractor`. // // Runs the resulting executable with the given extractor, cached with // `(cache_key, config)`. Returns `std::nullopt` on expected failure, bad // `Status` otherwise. - StatusOr> GenerateAndProfileExecutable( + StatusOr> GenerateAndProfileExecutable( const AutotuneResult& config, const AutotuneCacheKey& cache_key, se::Stream* stream, absl::Span input_buffers, - ShapedBuffer output_buffer, GenerateModuleFn extractor); + GenerateModuleFn extractor); // Generic method to compile a generated module from `extractor` in isolation. // diff --git a/tensorflow/compiler/xla/service/gpu/triton_autotuner.cc b/tensorflow/compiler/xla/service/gpu/triton_autotuner.cc index ccaf09e6816c54..467bbf41718dd7 100644 --- a/tensorflow/compiler/xla/service/gpu/triton_autotuner.cc +++ b/tensorflow/compiler/xla/service/gpu/triton_autotuner.cc @@ -70,6 +70,8 @@ limitations under the License. namespace xla { namespace gpu { +using ProfilingOutput = AutotunerCompileUtil::ProfilingOutput; + namespace { // Constructs an autotuning key for a gemm performed in Triton. @@ -152,13 +154,7 @@ class TritonAutotunerVisitor : public DfsHloRewriteVisitor { se::RedzoneAllocator rz_allocator, AutotunerUtil::CreateRedzoneAllocator(config_, debug_opts)); - se::DeviceMemoryBase reference_buffer; - if (config_.should_check_correctness()) { - TF_ASSIGN_OR_RETURN( - reference_buffer, - rz_allocator.AllocateBytes(ShapeUtil::ByteSizeOf(root->shape()))); - } - + std::optional reference_buffer; BufferComparator comparator(root->shape(), fusion.parent()->config()); const std::vector configurations = @@ -201,14 +197,11 @@ class TritonAutotunerVisitor : public DfsHloRewriteVisitor { } if (config_.should_check_correctness()) { - TF_RETURN_IF_ERROR(RunMatmulWithCublas(fusion, stream, allocator, inputs, - reference_buffer, cache_key)); + TF_ASSIGN_OR_RETURN( + reference_buffer, + RunMatmulWithCublas(fusion, stream, allocator, inputs, cache_key)); } - TF_ASSIGN_OR_RETURN( - se::DeviceMemoryBase output_buffer, - rz_allocator.AllocateBytes(ShapeUtil::ByteSizeOf(root->shape()))); - std::vector results; for (const AutotuneResult::TritonGemmKey& conf : configurations) { VLOG(1) << "Trying triton tiling: " << conf.ShortDebugString(); @@ -216,17 +209,18 @@ class TritonAutotunerVisitor : public DfsHloRewriteVisitor { AutotuneResult res; *res.mutable_triton() = conf; - TF_ASSIGN_OR_RETURN(std::optional duration, - RunMatmulWithConfig(fusion, conf, stream, inputs, - output_buffer, cache_key)); + TF_ASSIGN_OR_RETURN( + std::optional profiling_output, + RunMatmulWithConfig(fusion, conf, stream, inputs, cache_key)); - if (!duration) { + if (!profiling_output) { VLOG(1) << "Skipping this tiling."; continue; } - VLOG(1) << "Running the kernel took: " << *duration; - *res.mutable_run_time() = tsl::proto_utils::ToDurationProto(*duration); + VLOG(1) << "Running the kernel took: " << profiling_output->duration; + *res.mutable_run_time() = + tsl::proto_utils::ToDurationProto(profiling_output->duration); if (config_.should_check_correctness()) { TF_ASSIGN_OR_RETURN( @@ -243,8 +237,9 @@ class TritonAutotunerVisitor : public DfsHloRewriteVisitor { TF_ASSIGN_OR_RETURN( bool outputs_match, - comparator.CompareEqual(stream, /*current=*/output_buffer, - /*expected=*/reference_buffer)); + comparator.CompareEqual( + stream, /*current=*/profiling_output->output.root_buffer(), + /*expected=*/reference_buffer->root_buffer())); if (!outputs_match) { LOG(ERROR) << "Results do not match the reference. " << "This is likely a bug/unexpected loss of precision."; @@ -269,19 +264,16 @@ class TritonAutotunerVisitor : public DfsHloRewriteVisitor { // // `cache_key`: The cache key corresponding to the code of the fusion and the // device type. Passing it to avoid recalculating it everywhere it's needed. - StatusOr> RunMatmulWithConfig( + StatusOr> RunMatmulWithConfig( const HloComputation& hlo_computation, const AutotuneResult::TritonGemmKey& autotune_config, se::Stream* stream, absl::Span input_buffers, - se::DeviceMemoryBase output_buffer, const AutotuneCacheKey& cache_key) { + const AutotuneCacheKey& cache_key) { AutotuneResult config; *config.mutable_triton() = autotune_config; - ShapedBuffer output(hlo_computation.root_instruction()->shape(), 0); - output.set_buffer(output_buffer, ShapeIndex{}); - return autotuner_compile_util_->GenerateAndProfileExecutable( - config, cache_key, stream, input_buffers, std::move(output), [&] { + config, cache_key, stream, input_buffers, [&] { return TritonGemmAutotuneExtractor( autotune_config, GetGpuDeviceInfo(config_.GetExecutor()), hlo_computation.FusionInstruction()); @@ -333,11 +325,11 @@ class TritonAutotunerVisitor : public DfsHloRewriteVisitor { // // `cache_key`: The cache key corresponding to the code of the fusion and the // device type. Passing it to avoid recalculating it everywhere it's needed. - Status RunMatmulWithCublas( + StatusOr RunMatmulWithCublas( const HloComputation& original_computation, se::Stream* stream, se::DeviceMemoryAllocator* allocator, absl::Span input_buffers, - se::DeviceMemoryBase output_buffer, const AutotuneCacheKey& cache_key) { + const AutotuneCacheKey& cache_key) { AutotuneResult res; // We need some value to cache compilation. We associate the compiled module @@ -346,19 +338,15 @@ class TritonAutotunerVisitor : public DfsHloRewriteVisitor { gemm.set_algorithm(0); *res.mutable_gemm() = gemm; - ShapedBuffer output(original_computation.root_instruction()->shape(), 0); - output.set_buffer(output_buffer, ShapeIndex{}); - - TF_ASSIGN_OR_RETURN(std::optional duration, + TF_ASSIGN_OR_RETURN(std::optional output, autotuner_compile_util_->GenerateAndProfileExecutable( - res, cache_key, stream, input_buffers, - std::move(output), [&] { + res, cache_key, stream, input_buffers, [&] { return CublasGemmAutotuneExtractor( GetGpuDeviceInfo(config_.GetExecutor()), &original_computation); })); - TF_RET_CHECK(duration.has_value()); - return OkStatus(); + TF_RET_CHECK(output.has_value()); + return std::move(output->output); } StatusOr> CublasGemmAutotuneExtractor( From 3e0503c2e23d62436e5f78abf5ff88dd703a07b2 Mon Sep 17 00:00:00 2001 From: Swachhand Lokhande Date: Thu, 13 Jul 2023 10:25:24 -0700 Subject: [PATCH 259/376] Delete DeviceCompiler when a new PjRtClient is created for DEVICE_GPU. Also change a python e2e test to use GPU device instead of XLA_GPU device. PiperOrigin-RevId: 547841258 --- tensorflow/compiler/jit/BUILD | 6 ++++ tensorflow/core/common_runtime/gpu/BUILD | 3 ++ .../core/common_runtime/gpu/gpu_device.cc | 32 +++++++++++++++++-- .../python/compiler/xla/pjrt_compile_test.py | 8 +---- 4 files changed, 39 insertions(+), 10 deletions(-) diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index ab84540ec8c683..5241b4a3c5e08b 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -612,6 +612,7 @@ cc_library( hdrs = ["xla_compile_util.h"], visibility = [ ":internal", + "//tensorflow/core/common_runtime/gpu:__pkg__", "//tensorflow/core/common_runtime/next_pluggable_device:__pkg__", ], deps = [ @@ -654,6 +655,7 @@ cc_library( copts = tf_copts(), visibility = [ ":internal", + "//tensorflow/core/common_runtime/gpu:__pkg__", "//tensorflow/core/common_runtime/next_pluggable_device:__pkg__", ], deps = [ @@ -1417,6 +1419,10 @@ cc_library( name = "device_compilation_profiler", srcs = ["device_compilation_profiler.cc"], hdrs = ["device_compilation_profiler.h"], + visibility = [ + ":internal", + "//tensorflow/core/common_runtime/gpu:__pkg__", + ], deps = [ ":xla_activity_listener", ":xla_activity_proto_cc", diff --git a/tensorflow/core/common_runtime/gpu/BUILD b/tensorflow/core/common_runtime/gpu/BUILD index d343d9a9ad9eb9..c7d8d4bc121a6c 100644 --- a/tensorflow/core/common_runtime/gpu/BUILD +++ b/tensorflow/core/common_runtime/gpu/BUILD @@ -201,6 +201,9 @@ tf_cuda_library( "//tensorflow/compiler/tf2xla:layout_util", "//tensorflow/compiler/jit:flags", "//tensorflow/compiler/jit:pjrt_device_context", + "//tensorflow/compiler/jit:device_compilation_profiler", + "//tensorflow/compiler/jit:device_compiler", + "//tensorflow/compiler/jit:xla_compile_util", "//tensorflow/compiler/xla/pjrt/gpu:gpu_helpers", "//tensorflow/compiler/xla/pjrt/gpu:se_gpu_pjrt_client", "//tensorflow/compiler/xla/stream_executor:tf_allocator_adapter", diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc index fe1dbf05f75e72..d13a73102ff260 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_device.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc @@ -78,12 +78,16 @@ limitations under the License. #include "tensorflow/core/platform/rocm.h" #endif #ifdef TF_GPU_USE_PJRT +#include "tensorflow/compiler/jit/device_compilation_profiler.h" +#include "tensorflow/compiler/jit/device_compiler.h" #include "tensorflow/compiler/jit/flags.h" +#include "tensorflow/compiler/jit/xla_compile_util.h" #include "tensorflow/compiler/xla/pjrt/gpu/gpu_helpers.h" #include "tensorflow/compiler/xla/pjrt/gpu/se_gpu_pjrt_client.h" #include "tensorflow/compiler/xla/pjrt/pjrt_client.h" #include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h" #include "tensorflow/compiler/xla/stream_executor/device_host_allocator.h" +#include "tensorflow/core/tfrt/common/global_state.h" #endif // TF_GPU_USE_PJRT #include "tensorflow/compiler/xla/stream_executor/gpu/gpu_stream.h" #include "tensorflow/compiler/xla/stream_executor/platform/dso_loader.h" @@ -112,6 +116,21 @@ limitations under the License. namespace tensorflow { namespace { +#ifdef TF_GPU_USE_PJRT +using PjRtDeviceCompiler = + DeviceCompiler; + +void DeleteDeviceCompiler(const DeviceType& device_type) { + ResourceMgr* rm = tfrt_global::GetTFGlobalResourceMgr(); + rm->Delete(rm->default_container(), + GetPjRtDeviceCompilerResourceName(device_type)) + .IgnoreError(); + rm->Delete( + rm->default_container(), + GetPjRtDeviceCompilationProfilerResourceName(device_type)) + .IgnoreError(); +} +#endif // TF_GPU_USE_PJRT // Returns priority for the given virtual GPU id from the session options. // Returns 0 if no virtual devices are specified. @@ -1755,15 +1774,22 @@ Status BaseGPUDeviceFactory::CreateDevices( /*should_stage_host_to_device_transfers=*/true, /*gpu_run_options=*/std::move(gpu_run_options)); - return SetPjRtClientInTFGlobalResourceManager(DeviceType(DEVICE_GPU), - std::move(pjrt_client)); + TF_RETURN_IF_ERROR(SetPjRtClientInTFGlobalResourceManager( + DeviceType(DEVICE_GPU), std::move(pjrt_client))); + // We don't forsee a realistic scenario where the PjRtClient is deleted and + // replaced by a new one, except in unit tests. However, if this does happen, + // the DeviceCompiler that stores the PjRtLoadedExecutables built by the old + // PjRtClient needs to be deleted. A new DeviceCompiler using the current + // PjRtClient will be created on-demand when compilation is requested (if one + // doesn't exist already). + DeleteDeviceCompiler(DeviceType(DEVICE_GPU)); #else TF_RETURN_IF_ERROR(CreateGPUDevice(options, name_prefix, tf_device_id, /*dev_locality=*/it->second, gpu_allocator, devices)); } - return OkStatus(); #endif // TF_GPU_USE_PJRT + return OkStatus(); } static string GetShortDeviceDescription( diff --git a/tensorflow/python/compiler/xla/pjrt_compile_test.py b/tensorflow/python/compiler/xla/pjrt_compile_test.py index 31ef70b0fd2166..ddfa81e1b9408b 100644 --- a/tensorflow/python/compiler/xla/pjrt_compile_test.py +++ b/tensorflow/python/compiler/xla/pjrt_compile_test.py @@ -61,13 +61,7 @@ def bar(x, y): x.assign(y) y.assign_add([1.0, 1.0]) - # Currently PjRt only supports compilation and execution for the XLA_GPU - # device to unblock development. Support for non-XLA devices (CPU/GPU/single - # core TPU) is going to be added soon, after which support for XLA_* devices - # will be dropped. - # TODO(b/255826209): Modify the test as we progress towards supporting - # non-XLA devices. - with ops.device("/device:XLA_GPU:0"): + with ops.device("/device:GPU:0"): # Function call with scalars self.assertEqual(self.evaluate(foo(1, 2)), 4) From fff678b3925c5a0fccf9d0f3f64d7c357b0cb17e Mon Sep 17 00:00:00 2001 From: Hyeontaek Lim Date: Thu, 13 Jul 2023 10:51:52 -0700 Subject: [PATCH 260/376] [IFRT] Rollback for serialization/deserialization for shardings Breaks some platform's build PiperOrigin-RevId: 547850348 --- tensorflow/compiler/xla/python/ifrt/BUILD | 28 -- tensorflow/compiler/xla/python/ifrt/device.cc | 24 -- tensorflow/compiler/xla/python/ifrt/device.h | 11 - tensorflow/compiler/xla/python/ifrt/shape.cc | 25 -- tensorflow/compiler/xla/python/ifrt/shape.h | 8 - .../compiler/xla/python/ifrt/sharding.cc | 26 +- .../compiler/xla/python/ifrt/sharding.h | 15 +- .../compiler/xla/python/ifrt/sharding.proto | 46 ---- .../xla/python/ifrt/sharding_serdes.cc | 240 ------------------ .../xla/python/ifrt/sharding_serdes.h | 48 ---- .../xla/python/ifrt/sharding_serdes_test.cc | 157 ------------ .../compiler/xla/python/ifrt/types.proto | 31 --- .../compiler/xla/python/pjrt_ifrt/BUILD | 24 -- .../xla/python/pjrt_ifrt/xla_sharding.proto | 27 -- .../python/pjrt_ifrt/xla_sharding_serdes.cc | 79 ------ .../pjrt_ifrt/xla_sharding_serdes_test.cc | 95 ------- 16 files changed, 19 insertions(+), 865 deletions(-) delete mode 100644 tensorflow/compiler/xla/python/ifrt/sharding.proto delete mode 100644 tensorflow/compiler/xla/python/ifrt/sharding_serdes.cc delete mode 100644 tensorflow/compiler/xla/python/ifrt/sharding_serdes.h delete mode 100644 tensorflow/compiler/xla/python/ifrt/sharding_serdes_test.cc delete mode 100644 tensorflow/compiler/xla/python/ifrt/types.proto delete mode 100644 tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding.proto delete mode 100644 tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_serdes.cc delete mode 100644 tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_serdes_test.cc diff --git a/tensorflow/compiler/xla/python/ifrt/BUILD b/tensorflow/compiler/xla/python/ifrt/BUILD index ba797b6b6b00eb..ec8774cd5ef187 100644 --- a/tensorflow/compiler/xla/python/ifrt/BUILD +++ b/tensorflow/compiler/xla/python/ifrt/BUILD @@ -45,7 +45,6 @@ cc_library( "index_domain.cc", "shape.cc", "sharding.cc", - "sharding_serdes.cc", "tuple.cc", "value.cc", ], @@ -62,21 +61,17 @@ cc_library( "index_domain.h", "shape.h", "sharding.h", - "sharding_serdes.h", "tuple.h", "value.h", ], deps = [ ":serdes", - ":sharding_proto_cc", - ":types_proto_cc", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/pjrt:pjrt_client", "//tensorflow/compiler/xla/python/ifrt/ir", "//tensorflow/tsl/platform:logging", - "//tensorflow/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", @@ -314,26 +309,3 @@ tf_proto_library( name = "serdes_proto", srcs = ["serdes.proto"], ) - -xla_cc_test( - name = "sharding_serdes_test", - srcs = ["sharding_serdes_test.cc"], - deps = [ - ":ifrt", - ":mock", - ":serdes", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_googletest//:gtest_main", - ], -) - -tf_proto_library( - name = "types_proto", - srcs = ["types.proto"], -) - -tf_proto_library( - name = "sharding_proto", - srcs = ["sharding.proto"], - protodeps = [":types_proto"], -) diff --git a/tensorflow/compiler/xla/python/ifrt/device.cc b/tensorflow/compiler/xla/python/ifrt/device.cc index a549a811de6de6..0f02149ae48a64 100644 --- a/tensorflow/compiler/xla/python/ifrt/device.cc +++ b/tensorflow/compiler/xla/python/ifrt/device.cc @@ -15,35 +15,11 @@ limitations under the License. #include "tensorflow/compiler/xla/python/ifrt/device.h" -#include #include -#include "tensorflow/compiler/xla/python/ifrt/client.h" -#include "tensorflow/compiler/xla/python/ifrt/types.pb.h" - namespace xla { namespace ifrt { -StatusOr DeviceList::FromProto(Client* client, - const DeviceListProto& proto) { - DeviceList::Devices devices; - devices.reserve(proto.device_ids_size()); - for (int device_id : proto.device_ids()) { - TF_ASSIGN_OR_RETURN(Device * device, client->LookupDevice(device_id)); - devices.push_back(device); - } - return DeviceList(std::move(devices)); -} - -DeviceListProto DeviceList::ToProto() const { - DeviceListProto proto; - proto.mutable_device_ids()->Reserve(devices().size()); - for (Device* device : devices()) { - proto.mutable_device_ids()->AddAlreadyReserved(device->id()); - } - return proto; -} - std::vector GetDeviceIds(DeviceList device_list) { std::vector ids; ids.reserve(device_list.devices().size()); diff --git a/tensorflow/compiler/xla/python/ifrt/device.h b/tensorflow/compiler/xla/python/ifrt/device.h index d54afa190deaa9..a2d5f61dd35c2a 100644 --- a/tensorflow/compiler/xla/python/ifrt/device.h +++ b/tensorflow/compiler/xla/python/ifrt/device.h @@ -21,13 +21,10 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "tensorflow/compiler/xla/pjrt/pjrt_client.h" -#include "tensorflow/compiler/xla/python/ifrt/types.pb.h" namespace xla { namespace ifrt { -class Client; - // Short-term alias to reuse `xla::PjRtDevice` without a separate abstract type. using Device = ::xla::PjRtDevice; @@ -45,14 +42,6 @@ class DeviceList { explicit DeviceList(Devices devices) : devices_(std::move(devices)) {} - // Constructs `DeviceList` from `DeviceListProto`. Device ids in the proto - // must be consistent with the devices owned by `client'. - static StatusOr FromProto(Client* client, - const DeviceListProto& proto); - - // Returns a `DeviceListProto` representation. - DeviceListProto ToProto() const; - absl::Span devices() const { return devices_; } int size() const { return devices_.size(); } diff --git a/tensorflow/compiler/xla/python/ifrt/shape.cc b/tensorflow/compiler/xla/python/ifrt/shape.cc index 07e8e2b81494a5..bd3ff1fc8e08b6 100644 --- a/tensorflow/compiler/xla/python/ifrt/shape.cc +++ b/tensorflow/compiler/xla/python/ifrt/shape.cc @@ -17,37 +17,12 @@ limitations under the License. #include #include -#include #include "absl/strings/str_join.h" -#include "tensorflow/compiler/xla/python/ifrt/types.pb.h" -#include "tensorflow/compiler/xla/util.h" namespace xla { namespace ifrt { -StatusOr Shape::FromProto(const ShapeProto& proto) { - Shape::Dimensions dims; - dims.reserve(proto.dims_size()); - for (int64_t dim : proto.dims()) { - if (dim < 0) { - return InvalidArgument( - "Shape expects non-negative dimension sizes, but got %d", dim); - } - dims.push_back(dim); - } - return Shape(std::move(dims)); -} - -ShapeProto Shape::ToProto() const { - ShapeProto proto; - proto.mutable_dims()->Reserve(dims().size()); - for (int64_t dim : dims()) { - proto.mutable_dims()->AddAlreadyReserved(dim); - } - return proto; -} - int64_t Shape::num_elements() const { int64_t count = 1; for (int64_t d : dims_) { diff --git a/tensorflow/compiler/xla/python/ifrt/shape.h b/tensorflow/compiler/xla/python/ifrt/shape.h index f3ce028789d5ef..3558e3518ed84d 100644 --- a/tensorflow/compiler/xla/python/ifrt/shape.h +++ b/tensorflow/compiler/xla/python/ifrt/shape.h @@ -22,8 +22,6 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "absl/types/span.h" -#include "tensorflow/compiler/xla/python/ifrt/types.pb.h" -#include "tensorflow/compiler/xla/statusor.h" namespace xla { namespace ifrt { @@ -44,12 +42,6 @@ class Shape { Shape& operator=(const Shape&) = default; Shape& operator=(Shape&&) = default; - // Constructs `Shape` from `ShapeProto`. - static StatusOr FromProto(const ShapeProto& proto); - - // Returns a `ShapeProto` representation. - ShapeProto ToProto() const; - absl::Span dims() const { return dims_; } bool operator==(const Shape& other) const { return dims_ == other.dims_; } diff --git a/tensorflow/compiler/xla/python/ifrt/sharding.cc b/tensorflow/compiler/xla/python/ifrt/sharding.cc index 8caaf9f12a83e4..f057ad53fccf83 100644 --- a/tensorflow/compiler/xla/python/ifrt/sharding.cc +++ b/tensorflow/compiler/xla/python/ifrt/sharding.cc @@ -159,10 +159,8 @@ std::ostream& operator<<(std::ostream& os, const Sharding& sharding) { return os << sharding.DebugString(); } -std::unique_ptr SingleDeviceSharding::Create( - Device* device) { - return std::unique_ptr( - new SingleDeviceSharding(device)); +std::unique_ptr SingleDeviceSharding::Create(Device* device) { + return std::unique_ptr(new SingleDeviceSharding(device)); } StatusOr>>> @@ -189,9 +187,8 @@ std::string SingleDeviceSharding::DebugString() const { devices_.front()->ToString()); } -std::unique_ptr OpaqueSharding::Create(DeviceList devices) { - return std::unique_ptr( - new OpaqueSharding(std::move(devices))); +std::unique_ptr OpaqueSharding::Create(DeviceList devices) { + return std::unique_ptr(new OpaqueSharding(std::move(devices))); } OpaqueSharding::OpaqueSharding(DeviceList devices) @@ -220,10 +217,10 @@ std::string OpaqueSharding::DebugString() const { })); } -std::unique_ptr ConcreteSharding::Create( +std::unique_ptr ConcreteSharding::Create( DeviceList devices, Shape shape, std::vector shard_shapes) { CHECK_EQ(devices.size(), shard_shapes.size()); - return std::unique_ptr(new ConcreteSharding( + return std::unique_ptr(new ConcreteSharding( std::move(devices), std::move(shape), std::move(shard_shapes))); } @@ -273,9 +270,10 @@ std::string ConcreteSharding::DebugString() const { })); } -std::unique_ptr ConcreteEvenSharding::Create( - DeviceList devices, Shape shape, Shape shard_shape) { - return std::unique_ptr(new ConcreteEvenSharding( +std::unique_ptr ConcreteEvenSharding::Create(DeviceList devices, + Shape shape, + Shape shard_shape) { + return std::unique_ptr(new ConcreteEvenSharding( std::move(devices), std::move(shape), std::move(shard_shape))); } @@ -320,7 +318,7 @@ std::string ConcreteEvenSharding::DebugString() const { shape_.DebugString(), shard_shape_.DebugString()); } -StatusOr> ShardingParamSharding::Create( +StatusOr> ShardingParamSharding::Create( ShardingParam sharding_param, DeviceList devices) { int64_t device_count = absl::c_accumulate(sharding_param.minor_to_major().axis_sizes, 1, @@ -331,7 +329,7 @@ StatusOr> ShardingParamSharding::Create( "%d", device_count, devices.size()); } - return std::unique_ptr( + return std::unique_ptr( new ShardingParamSharding(std::move(sharding_param), std::move(devices))); } diff --git a/tensorflow/compiler/xla/python/ifrt/sharding.h b/tensorflow/compiler/xla/python/ifrt/sharding.h index 6e3d30e99d2584..375cedc16a0a68 100644 --- a/tensorflow/compiler/xla/python/ifrt/sharding.h +++ b/tensorflow/compiler/xla/python/ifrt/sharding.h @@ -83,7 +83,7 @@ class SingleDeviceSharding final : public llvm::RTTIExtends { public: // Creates a single-device sharding. - static std::unique_ptr Create(Device* device); + static std::unique_ptr Create(Device* device); // Sharding implementation. @@ -110,7 +110,7 @@ class SingleDeviceSharding final class OpaqueSharding : public llvm::RTTIExtends { public: // Creates an opaque sharding. `Disassemble()` will fail. - static std::unique_ptr Create(DeviceList devices); + static std::unique_ptr Create(DeviceList devices); // Sharding implementation. @@ -138,8 +138,8 @@ class ConcreteSharding : public llvm::RTTIExtends { public: // Creates a concrete sharding that may contain non-identical shard shapes. // REQUIRES: devices.size() == shard_shapes.size() - static std::unique_ptr Create( - DeviceList devices, Shape shape, std::vector shard_shapes); + static std::unique_ptr Create(DeviceList devices, Shape shape, + std::vector shard_shapes); Shape shape() const { DCHECK(this); @@ -179,9 +179,8 @@ class ConcreteEvenSharding : public llvm::RTTIExtends { public: // Creates a concrete even sharding. - static std::unique_ptr Create(DeviceList devices, - Shape shape, - Shape shard_shape); + static std::unique_ptr Create(DeviceList devices, Shape shape, + Shape shard_shape); Shape shape() const { DCHECK(this); @@ -217,7 +216,7 @@ class ConcreteEvenSharding class ShardingParamSharding : public llvm::RTTIExtends { public: - static StatusOr> Create( + static StatusOr> Create( ShardingParam sharding_param, DeviceList devices); StatusOr>>> diff --git a/tensorflow/compiler/xla/python/ifrt/sharding.proto b/tensorflow/compiler/xla/python/ifrt/sharding.proto deleted file mode 100644 index 066bce11413998..00000000000000 --- a/tensorflow/compiler/xla/python/ifrt/sharding.proto +++ /dev/null @@ -1,46 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -syntax = "proto3"; - -package xla.ifrt; - -import "tensorflow/compiler/xla/python/ifrt/types.proto"; - -// Wire format for `SingleDeviceSharding`. -message SingleDeviceShardingProto { - // Serialization and deserialization are expected to ensure that device ids - // are stable across proto construction and consumption. - int32 device_id = 1; -} - -// Wire format for `OpaqueSharding`. -message OpaqueShardingProto { - DeviceListProto devices = 1; -} - -// Wire format for `ConcreteSharding`. -message ConcreteShardingProto { - DeviceListProto devices = 1; - ShapeProto shape = 2; - repeated ShapeProto shard_shapes = 3; -} - -// Wire format for `ConcreteEvenSharding`. -message ConcreteEvenShardingProto { - DeviceListProto devices = 1; - ShapeProto shape = 2; - ShapeProto shard_shape = 3; -} diff --git a/tensorflow/compiler/xla/python/ifrt/sharding_serdes.cc b/tensorflow/compiler/xla/python/ifrt/sharding_serdes.cc deleted file mode 100644 index d9ade8d6a62b96..00000000000000 --- a/tensorflow/compiler/xla/python/ifrt/sharding_serdes.cc +++ /dev/null @@ -1,240 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/python/ifrt/sharding_serdes.h" - -#include -#include -#include -#include - -#include "tensorflow/compiler/xla/python/ifrt/client.h" -#include "tensorflow/compiler/xla/python/ifrt/device.h" -#include "tensorflow/compiler/xla/python/ifrt/serdes.h" -#include "tensorflow/compiler/xla/python/ifrt/shape.h" -#include "tensorflow/compiler/xla/python/ifrt/sharding.h" -#include "tensorflow/compiler/xla/python/ifrt/sharding.pb.h" -#include "tensorflow/tsl/platform/statusor.h" - -namespace xla { -namespace ifrt { - -char DeserializeShardingOptions::ID = 0; - -namespace { - -// Serialization/deserialization for `SingleDeviceSharding`. -class SingleDeviceShardingSerDes - : public llvm::RTTIExtends { - public: - absl::string_view type_name() const override { - return "xla::ifrt::SingleDeviceSharding"; - } - - absl::StatusOr Serialize(Serializable& serializable) override { - const SingleDeviceSharding& sharding = - llvm::cast(serializable); - SingleDeviceShardingProto proto; - proto.set_device_id(sharding.devices().front()->id()); - return proto.SerializeAsString(); - } - - absl::StatusOr> Deserialize( - const std::string& serialized, - std::unique_ptr options) override { - TF_ASSIGN_OR_RETURN(auto deserialize_sharding_options, - GetDeserializeShardingOptions(std::move(options))); - SingleDeviceShardingProto proto; - if (!proto.ParseFromString(serialized)) { - return absl::InvalidArgumentError( - "Failed to parse serialized SimpleDeviceSharding"); - } - TF_ASSIGN_OR_RETURN( - Device * device, - deserialize_sharding_options->client->LookupDevice(proto.device_id())); - return SingleDeviceSharding::Create(device); - } - - static char ID; // NOLINT -}; - -// Serialization/deserialization for `OpaqueSharding`. -class OpaqueShardingSerDes - : public llvm::RTTIExtends { - public: - absl::string_view type_name() const override { - return "xla::ifrt::OpaqueSharding"; - } - - absl::StatusOr Serialize(Serializable& serializable) override { - const OpaqueSharding& sharding = llvm::cast(serializable); - OpaqueShardingProto proto; - *proto.mutable_devices() = sharding.devices().ToProto(); - return proto.SerializeAsString(); - } - - absl::StatusOr> Deserialize( - const std::string& serialized, - std::unique_ptr options) override { - TF_ASSIGN_OR_RETURN(auto deserialize_sharding_options, - GetDeserializeShardingOptions(std::move(options))); - - OpaqueShardingProto proto; - if (!proto.ParseFromString(serialized)) { - return absl::InvalidArgumentError( - "Failed to parse serialized OpaqueSharding"); - } - TF_ASSIGN_OR_RETURN(auto devices, DeviceList::FromProto( - deserialize_sharding_options->client, - proto.devices())); - return OpaqueSharding::Create(std::move(devices)); - } - - static char ID; // NOLINT -}; - -// Serialization/deserialization for `ConcreteSharding`. -class ConcreteShardingSerDes - : public llvm::RTTIExtends { - public: - absl::string_view type_name() const override { - return "xla::ifrt::ConcreteSharding"; - } - - absl::StatusOr Serialize(Serializable& serializable) override { - const ConcreteSharding& sharding = - llvm::cast(serializable); - ConcreteShardingProto proto; - *proto.mutable_devices() = sharding.devices().ToProto(); - *proto.mutable_shape() = sharding.shape().ToProto(); - for (const Shape& shape : sharding.shard_shapes()) { - *proto.add_shard_shapes() = shape.ToProto(); - } - return proto.SerializeAsString(); - } - - absl::StatusOr> Deserialize( - const std::string& serialized, - std::unique_ptr options) override { - TF_ASSIGN_OR_RETURN(auto deserialize_sharding_options, - GetDeserializeShardingOptions(std::move(options))); - - ConcreteShardingProto proto; - if (!proto.ParseFromString(serialized)) { - return absl::InvalidArgumentError( - "Failed to parse serialized ConcreteSharding"); - } - TF_ASSIGN_OR_RETURN(auto devices, DeviceList::FromProto( - deserialize_sharding_options->client, - proto.devices())); - TF_ASSIGN_OR_RETURN(auto shape, Shape::FromProto(proto.shape())); - std::vector shard_shapes; - shard_shapes.reserve(proto.shard_shapes_size()); - for (const auto& shard_shape_proto : proto.shard_shapes()) { - TF_ASSIGN_OR_RETURN(auto shard_shape, - Shape::FromProto(shard_shape_proto)); - shard_shapes.push_back(std::move(shard_shape)); - } - return ConcreteSharding::Create(std::move(devices), std::move(shape), - std::move(shard_shapes)); - } - - static char ID; // NOLINT -}; - -// Serialization/deserialization for `ConcreteEvenSharding`. -class ConcreteEvenShardingSerDes - : public llvm::RTTIExtends { - public: - absl::string_view type_name() const override { - return "xla::ifrt::ConcreteEvenSharding"; - } - - absl::StatusOr Serialize(Serializable& serializable) override { - const ConcreteEvenSharding& sharding = - llvm::cast(serializable); - ConcreteEvenShardingProto proto; - *proto.mutable_devices() = sharding.devices().ToProto(); - *proto.mutable_shape() = sharding.shape().ToProto(); - *proto.mutable_shard_shape() = sharding.shard_shape().ToProto(); - return proto.SerializeAsString(); - } - - absl::StatusOr> Deserialize( - const std::string& serialized, - std::unique_ptr options) override { - TF_ASSIGN_OR_RETURN(auto deserialize_sharding_options, - GetDeserializeShardingOptions(std::move(options))); - - ConcreteEvenShardingProto proto; - if (!proto.ParseFromString(serialized)) { - return absl::InvalidArgumentError( - "Failed to parse serialized ConcreteEvenSharding"); - } - TF_ASSIGN_OR_RETURN(auto devices, DeviceList::FromProto( - deserialize_sharding_options->client, - proto.devices())); - TF_ASSIGN_OR_RETURN(auto shape, Shape::FromProto(proto.shape())); - TF_ASSIGN_OR_RETURN(auto shard_shape, - Shape::FromProto(proto.shard_shape())); - return ConcreteEvenSharding::Create(std::move(devices), std::move(shape), - std::move(shard_shape)); - } - - static char ID; // NOLINT -}; - -// TODO(hyeontaek): Implement `ShardingParamShardingSerDes`. - -[[maybe_unused]] char SingleDeviceShardingSerDes::ID = 0; // NOLINT -[[maybe_unused]] char OpaqueShardingSerDes::ID = 0; // NOLINT -[[maybe_unused]] char ConcreteShardingSerDes::ID = 0; // NOLINT -[[maybe_unused]] char ConcreteEvenShardingSerDes::ID = 0; // NOLINT - -// clang-format off -bool register_single_device_sharding_serdes = ([]{ - RegisterSerDes( - std::make_unique()); -}(), true); - -bool register_opaque_sharding_serdes = ([]{ - RegisterSerDes( - std::make_unique()); -}(), true); - -bool register_concrete_sharding_serdes = ([]{ - RegisterSerDes( - std::make_unique()); -}(), true); - -bool register_concrete_even_sharding_serdes = ([]{ - RegisterSerDes( - std::make_unique()); -}(), true); -// clang-format on - -} // namespace - -StatusOr> -GetDeserializeShardingOptions(std::unique_ptr options) { - if (!llvm::isa(options.get())) { - return xla::InvalidArgument("options must be DeserializeShardingOptions"); - } - return std::unique_ptr( - static_cast(options.release())); -} - -} // namespace ifrt -} // namespace xla diff --git a/tensorflow/compiler/xla/python/ifrt/sharding_serdes.h b/tensorflow/compiler/xla/python/ifrt/sharding_serdes.h deleted file mode 100644 index 965670bcbc3401..00000000000000 --- a/tensorflow/compiler/xla/python/ifrt/sharding_serdes.h +++ /dev/null @@ -1,48 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_IFRT_SHARDING_SERDES_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_IFRT_SHARDING_SERDES_H_ - -#include - -#include "llvm/Support/ExtensibleRTTI.h" -#include "tensorflow/compiler/xla/python/ifrt/serdes.h" -#include "tensorflow/compiler/xla/statusor.h" - -namespace xla { -namespace ifrt { - -class Client; - -// Options for deserializing shardings. -struct DeserializeShardingOptions - : llvm::RTTIExtends { - explicit DeserializeShardingOptions(Client* client) : client(client) {} - - static char ID; // NOLINT - - // The client whose devices will be used by deserialized shardings. - Client* client; -}; - -// Casts `DeserializeOptions` into `DeserializeShardingOptions`. -StatusOr> -GetDeserializeShardingOptions(std::unique_ptr options); - -} // namespace ifrt -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_IFRT_SHARDING_SERDES_H_ diff --git a/tensorflow/compiler/xla/python/ifrt/sharding_serdes_test.cc b/tensorflow/compiler/xla/python/ifrt/sharding_serdes_test.cc deleted file mode 100644 index 90efc6d9667167..00000000000000 --- a/tensorflow/compiler/xla/python/ifrt/sharding_serdes_test.cc +++ /dev/null @@ -1,157 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/python/ifrt/sharding_serdes.h" - -#include -#include -#include - -#include -#include -#include "absl/container/flat_hash_map.h" -#include "tensorflow/compiler/xla/python/ifrt/mock.h" -#include "tensorflow/compiler/xla/python/ifrt/serdes.h" -#include "tensorflow/compiler/xla/python/ifrt/sharding.h" - -namespace xla { -namespace ifrt { -namespace { - -using ::testing::ElementsAreArray; - -// Test fixture for sharding serialization and deserialization. It makes a mock -// client with a number of fake devices. Client implements `devices()` and -// `LookupDevice()`, and Device implements `id()`, with an arbitrary device ids -// assigned. -class ShardingSerDesTest : public ::testing::TestWithParam { - public: - void SetUp() override { - const int num_devices = GetParam(); - device_map_.reserve(num_devices); - devices_.reserve(num_devices); - for (int i = 0; i < num_devices; ++i) { - auto device = std::make_unique(); - ON_CALL(*device, id).WillByDefault([i]() { return i + 10; }); - devices_.push_back(device.get()); - device_map_.insert({i + 10, std::move(device)}); - } - client_ = std::make_unique(); - ON_CALL(*client_, devices) - .WillByDefault( - [this]() -> absl::Span { return devices_; }); - ON_CALL(*client_, LookupDevice) - .WillByDefault([this](int device_id) -> StatusOr { - auto it = device_map_.find(device_id); - if (it == device_map_.end()) { - return InvalidArgument("Unexpected device id: %d", device_id); - } - return it->second.get(); - }); - } - Client* client() { return client_.get(); } - - private: - std::unique_ptr client_; - absl::flat_hash_map> device_map_; - std::vector devices_; -}; - -TEST_P(ShardingSerDesTest, SingleDeviceShardingRoundTrip) { - auto sharding = SingleDeviceSharding::Create(client()->devices().front()); - - TF_ASSERT_OK_AND_ASSIGN(auto serialized, Serialize(*sharding)); - - auto deserialized_options = - std::make_unique(client()); - TF_ASSERT_OK_AND_ASSIGN( - auto deserialized, - Deserialize(serialized, std::move(deserialized_options))); - - const auto* out_sharding = - llvm::dyn_cast(deserialized.get()); - ASSERT_NE(out_sharding, nullptr); - EXPECT_THAT(out_sharding->devices(), ElementsAreArray(sharding->devices())); -} - -TEST_P(ShardingSerDesTest, OpaqueShardingRoundTrip) { - auto sharding = OpaqueSharding::Create(DeviceList(DeviceList::Devices( - client()->devices().begin(), client()->devices().end()))); - - TF_ASSERT_OK_AND_ASSIGN(auto serialized, Serialize(*sharding)); - - auto deserialized_options = - std::make_unique(client()); - TF_ASSERT_OK_AND_ASSIGN( - auto deserialized, - Deserialize(serialized, std::move(deserialized_options))); - - const auto* out_sharding = llvm::dyn_cast(deserialized.get()); - ASSERT_NE(out_sharding, nullptr); - EXPECT_THAT(out_sharding->devices(), ElementsAreArray(sharding->devices())); -} - -TEST_P(ShardingSerDesTest, ConcreteShardingRoundTrip) { - auto sharding = ConcreteSharding::Create( - DeviceList(DeviceList::Devices(client()->devices().begin(), - client()->devices().end())), - /*shape=*/Shape({10, 20}), - /*shard_shapes=*/{Shape({3, 20}), Shape({7, 20})}); - - TF_ASSERT_OK_AND_ASSIGN(auto serialized, Serialize(*sharding)); - - auto deserialized_options = - std::make_unique(client()); - TF_ASSERT_OK_AND_ASSIGN( - auto deserialized, - Deserialize(serialized, std::move(deserialized_options))); - - const auto* out_sharding = - llvm::dyn_cast(deserialized.get()); - ASSERT_NE(out_sharding, nullptr); - EXPECT_THAT(out_sharding->devices(), ElementsAreArray(sharding->devices())); - EXPECT_THAT(out_sharding->shape(), sharding->shape()); - EXPECT_THAT(out_sharding->shard_shapes(), - ElementsAreArray(sharding->shard_shapes())); -} - -TEST_P(ShardingSerDesTest, ConcreteEvenShardingRoundTrip) { - auto sharding = ConcreteEvenSharding::Create( - DeviceList(DeviceList::Devices(client()->devices().begin(), - client()->devices().end())), - /*shape=*/Shape({10, 20}), - /*shard_shape=*/Shape({5, 20})); - - TF_ASSERT_OK_AND_ASSIGN(auto serialized, Serialize(*sharding)); - - auto deserialized_options = - std::make_unique(client()); - TF_ASSERT_OK_AND_ASSIGN( - auto deserialized, - Deserialize(serialized, std::move(deserialized_options))); - - const auto* out_sharding = - llvm::dyn_cast(deserialized.get()); - ASSERT_NE(out_sharding, nullptr); - EXPECT_THAT(out_sharding->devices(), ElementsAreArray(sharding->devices())); - EXPECT_THAT(out_sharding->shape(), sharding->shape()); - EXPECT_THAT(out_sharding->shard_shape(), sharding->shard_shape()); -} - -INSTANTIATE_TEST_SUITE_P(NumDevices, ShardingSerDesTest, testing::Values(2)); - -} // namespace -} // namespace ifrt -} // namespace xla diff --git a/tensorflow/compiler/xla/python/ifrt/types.proto b/tensorflow/compiler/xla/python/ifrt/types.proto deleted file mode 100644 index e9c799bcc1ed6c..00000000000000 --- a/tensorflow/compiler/xla/python/ifrt/types.proto +++ /dev/null @@ -1,31 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -syntax = "proto3"; - -package xla.ifrt; - -// Wire format for `DeviceList`. -message DeviceListProto { - // Serialization and deserialization are expected to ensure that device ids - // are stable across proto construction and consumption. - repeated int32 device_ids = 1; -} - -// Wire format for `Shape`. Currently support static shapes with all dimension -// sizes greater than or equal to 0. -message ShapeProto { - repeated int64 dims = 1; -} diff --git a/tensorflow/compiler/xla/python/pjrt_ifrt/BUILD b/tensorflow/compiler/xla/python/pjrt_ifrt/BUILD index ba7f217ec28e3f..e1c1a36bb4ef62 100644 --- a/tensorflow/compiler/xla/python/pjrt_ifrt/BUILD +++ b/tensorflow/compiler/xla/python/pjrt_ifrt/BUILD @@ -36,7 +36,6 @@ cc_library( srcs = [ "xla_compiler.cc", "xla_sharding.cc", - "xla_sharding_serdes.cc", ], hdrs = [ "xla_compiler.h", @@ -44,9 +43,7 @@ cc_library( ], deps = [ ":xla_compiler_proto_cc", - ":xla_sharding_proto_cc", "//tensorflow/compiler/xla:xla_data_proto_cc", - "//tensorflow/compiler/xla/hlo/ir:hlo", "//tensorflow/compiler/xla/pjrt:pjrt_executable", "//tensorflow/compiler/xla/python/ifrt", "//tensorflow/compiler/xla/python/ifrt:serdes", @@ -113,27 +110,6 @@ xla_cc_test( ], ) -tf_proto_library( - name = "xla_sharding_proto", - srcs = ["xla_sharding.proto"], - protodeps = [ - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/python/ifrt:types_proto", - ], -) - -xla_cc_test( - name = "xla_sharding_serdes_test", - srcs = ["xla_sharding_serdes_test.cc"], - deps = [ - ":xla_ifrt", - "//tensorflow/compiler/xla/hlo/ir:hlo", - "//tensorflow/compiler/xla/python/ifrt", - "//tensorflow/compiler/xla/python/ifrt:mock", - "@com_google_googletest//:gtest_main", - ], -) - # TODO(hyeontaek): Move this target out of pjrt_ifrt. cc_library( name = "xla_executable_impl_test_lib", diff --git a/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding.proto b/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding.proto deleted file mode 100644 index 0ff8040b66233e..00000000000000 --- a/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding.proto +++ /dev/null @@ -1,27 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -syntax = "proto3"; - -package xla.ifrt; - -import "tensorflow/compiler/xla/python/ifrt/types.proto"; -import "tensorflow/compiler/xla/xla_data.proto"; - -// Wire format for `HloSharding`. -message HloShardingProto { - DeviceListProto devices = 1; - xla.OpSharding xla_op_sharding = 2; -} diff --git a/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_serdes.cc b/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_serdes.cc deleted file mode 100644 index c3d8d2470600b9..00000000000000 --- a/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_serdes.cc +++ /dev/null @@ -1,79 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include - -#include "tensorflow/compiler/xla/hlo/ir/hlo_sharding.h" -#include "tensorflow/compiler/xla/python/ifrt/serdes.h" -#include "tensorflow/compiler/xla/python/ifrt/sharding_serdes.h" -#include "tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding.h" -#include "tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding.pb.h" - -namespace xla { -namespace ifrt { - -namespace { - -// Serialization/deserialization for `HloSharding`. -class HloShardingSerDes : public llvm::RTTIExtends { - public: - absl::string_view type_name() const override { - return "xla::ifrt::HloSharding"; - } - - absl::StatusOr Serialize(Serializable& serializable) override { - const HloSharding& sharding = llvm::cast(serializable); - HloShardingProto proto; - *proto.mutable_devices() = sharding.devices().ToProto(); - *proto.mutable_xla_op_sharding() = sharding.xla_hlo_sharding().ToProto(); - return proto.SerializeAsString(); - } - - absl::StatusOr> Deserialize( - const std::string& serialized, - std::unique_ptr options) override { - TF_ASSIGN_OR_RETURN(auto deserialize_sharding_options, - GetDeserializeShardingOptions(std::move(options))); - - HloShardingProto proto; - if (!proto.ParseFromString(serialized)) { - return absl::InvalidArgumentError( - "Failed to parse serialized HloSharding"); - } - TF_ASSIGN_OR_RETURN(auto devices, DeviceList::FromProto( - deserialize_sharding_options->client, - proto.devices())); - TF_ASSIGN_OR_RETURN(auto xla_hlo_sharding, - xla::HloSharding::FromProto(proto.xla_op_sharding())); - return HloSharding::Create(std::move(devices), std::move(xla_hlo_sharding)); - } - - static char ID; // NOLINT -}; - -[[maybe_unused]] char HloShardingSerDes::ID = 0; // NOLINT - -// clang-format off -bool register_hlo_sharding_serdes = ([] { - RegisterSerDes( - std::make_unique()); -}(), true); -// clang-format on - -} // namespace -} // namespace ifrt -} // namespace xla diff --git a/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_serdes_test.cc b/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_serdes_test.cc deleted file mode 100644 index e043fb7e575f2b..00000000000000 --- a/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_serdes_test.cc +++ /dev/null @@ -1,95 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include - -#include -#include -#include "tensorflow/compiler/xla/hlo/ir/hlo_sharding.h" -#include "tensorflow/compiler/xla/python/ifrt/mock.h" -#include "tensorflow/compiler/xla/python/ifrt/sharding_serdes.h" -#include "tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding.h" - -namespace xla { -namespace ifrt { -namespace { - -using ::testing::ElementsAreArray; - -// Test fixture for sharding serialization and deserialization. It makes a mock -// client with a number of fake devices. Client implements `devices()` and -// `LookupDevice()`, and Device implements `id()`, with an arbitrary device ids -// assigned. -class XlaShardingSerDesTest : public ::testing::TestWithParam { - public: - void SetUp() override { - const int num_devices = GetParam(); - device_map_.reserve(num_devices); - devices_.reserve(num_devices); - for (int i = 0; i < num_devices; ++i) { - auto device = std::make_unique(); - ON_CALL(*device, id).WillByDefault([i]() { return i + 10; }); - devices_.push_back(device.get()); - device_map_.insert({i + 10, std::move(device)}); - } - client_ = std::make_unique(); - ON_CALL(*client_, devices) - .WillByDefault( - [this]() -> absl::Span { return devices_; }); - ON_CALL(*client_, LookupDevice) - .WillByDefault([this](int device_id) -> StatusOr { - auto it = device_map_.find(device_id); - if (it == device_map_.end()) { - return InvalidArgument("Unexpected device id: %d", device_id); - } - return it->second.get(); - }); - } - Client* client() { return client_.get(); } - - private: - std::unique_ptr client_; - absl::flat_hash_map> device_map_; - std::vector devices_; -}; - -TEST_P(XlaShardingSerDesTest, HloShardingRoundTrip) { - auto xla_hlo_sharding = xla::HloSharding::Tile(xla::TileAssignment({2, 1})); - auto sharding = HloSharding::Create( - DeviceList(DeviceList::Devices(client()->devices().begin(), - client()->devices().end())), - /*xla_hlo_sharding=*/xla_hlo_sharding); - - TF_ASSERT_OK_AND_ASSIGN(auto serialized, Serialize(*sharding)); - - auto deserialized_options = - std::make_unique(client()); - TF_ASSERT_OK_AND_ASSIGN( - auto deserialized, - Deserialize(serialized, std::move(deserialized_options))); - - const auto* out_sharding = llvm::dyn_cast(deserialized.get()); - ASSERT_NE(out_sharding, nullptr); - EXPECT_THAT(out_sharding->devices(), ElementsAreArray(sharding->devices())); - EXPECT_EQ(out_sharding->xla_hlo_sharding(), sharding->xla_hlo_sharding()); -} - -INSTANTIATE_TEST_SUITE_P(NumDevices, XlaShardingSerDesTest, testing::Values(2)); - -} // namespace -} // namespace ifrt -} // namespace xla From e72477ce2bcb4736948f6346959bdf214b004442 Mon Sep 17 00:00:00 2001 From: Armando Ugalde Velasco Date: Thu, 13 Jul 2023 11:36:22 -0700 Subject: [PATCH 261/376] Collect target processing time in ClientHeartbeat Collect target processing times and send them in ClientHeartbeat. Also make copies of Parameters when creating a model snapshot. PiperOrigin-RevId: 547867518 --- tensorflow/core/data/service/client/data_service_client.cc | 5 +++++ tensorflow/core/framework/model.cc | 7 ++++++- tensorflow/core/framework/model.h | 7 +++++++ 3 files changed, 18 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/data/service/client/data_service_client.cc b/tensorflow/core/data/service/client/data_service_client.cc index 13eebf9a32b27b..c9fc81902b5362 100644 --- a/tensorflow/core/data/service/client/data_service_client.cc +++ b/tensorflow/core/data/service/client/data_service_client.cc @@ -427,6 +427,11 @@ void DataServiceClient::Heartbeat() TF_LOCKS_EXCLUDED(mu_) { req.set_blocked_round(round_robin_round_limit_.value()); } } + { + mutex_lock l(mu_); + double target_processing_time_nsec = ctx_->GetTargetProcessingTimeNsec(); + req.set_target_processing_time_nsec(target_processing_time_nsec); + } ClientHeartbeatResponse resp; Status s = dispatcher_->ClientHeartbeat(req, resp); if (!s.ok()) { diff --git a/tensorflow/core/framework/model.cc b/tensorflow/core/framework/model.cc index f613e96f7bd8b7..f6341ed75a2162 100644 --- a/tensorflow/core/framework/model.cc +++ b/tensorflow/core/framework/model.cc @@ -2007,7 +2007,12 @@ std::shared_ptr Node::SnapshotHelper( cloned_current->processing_time_.store(processing_time_); { mutex_lock l2(cloned_current->mu_); - cloned_current->parameters_ = parameters_; + cloned_current->parameters_ = + absl::flat_hash_map>(); + for (const auto& [parameter_name, parameter_ptr] : parameters_) { + cloned_current->parameters_[parameter_name] = + std::make_shared(parameter_ptr); + } cloned_current->previous_processing_time_ = previous_processing_time_; cloned_current->processing_time_ema_ = processing_time_ema_; } diff --git a/tensorflow/core/framework/model.h b/tensorflow/core/framework/model.h index 58569f9773bef2..8e815eafb1abe9 100644 --- a/tensorflow/core/framework/model.h +++ b/tensorflow/core/framework/model.h @@ -107,6 +107,13 @@ struct Parameter { max(max), state(std::move(state)) {} + explicit Parameter(const std::shared_ptr parameter) + : name(parameter->name), + value(parameter->value), + min(parameter->min), + max(parameter->max), + state(parameter->state) {} + // Human-readable name of the parameter. const string name; From 8d4e21985a6cbf69e9a8d0a07dd293bdba8da12e Mon Sep 17 00:00:00 2001 From: James Mullenbach Date: Thu, 13 Jul 2023 11:56:17 -0700 Subject: [PATCH 262/376] Catch another type of error that indicates coordination service is down, which was causing some flakiness. PiperOrigin-RevId: 547874218 --- .../python/distribute/coordinator/cluster_coordinator.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/distribute/coordinator/cluster_coordinator.py b/tensorflow/python/distribute/coordinator/cluster_coordinator.py index 879c9e24921e4c..ca4bbbc2d8c2c1 100644 --- a/tensorflow/python/distribute/coordinator/cluster_coordinator.py +++ b/tensorflow/python/distribute/coordinator/cluster_coordinator.py @@ -766,11 +766,16 @@ def _log_ps_failure_and_raise(self, e, ps_index): raise PSUnavailableError(e) def _get_task_states(self): + """Get task states and reset to None if coordination service is down.""" try: self._task_states = context.context().get_task_states( [("worker", self._num_workers), ("ps", self._num_ps)] ) - except errors.UnavailableError: + except (errors.UnavailableError, errors.InternalError) as e: + if isinstance( + e, errors.InternalError + ) and "coordination service is not enabled" not in str(e).lower(): + raise # Coordination service is down self._task_states = None with self._next_task_state_cond: From 59e4c693a10741d2c95fc7bbbe8141297083f304 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 13 Jul 2023 12:00:58 -0700 Subject: [PATCH 263/376] Add logging information to Ph1 call sites PiperOrigin-RevId: 547875568 --- tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc index 06877ad8e42a2b..6eaa7a4b04980f 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/core/framework/metrics.h" #include "tensorflow/core/platform/error_payloads.h" +#include "tensorflow/core/platform/stacktrace.h" #include "tensorflow/core/protobuf/core_platform_payloads.pb.h" #include "tensorflow/core/util/debug_data_dumper.h" @@ -290,6 +291,8 @@ void CreateTPUBridgePipelineV1(OpPassManager &pm) { tensorflow::Status TPUBridge(ModuleOp module, bool fallback_enabled, llvm::StringRef module_name) { + VLOG(1) << "TPU Bridge called stack trace is :" + << tensorflow::CurrentStackTrace(); Status status = RunTFXLABridge( module, [module_name](OpPassManager &pm) { @@ -313,6 +316,8 @@ tensorflow::Status TPUBridge(ModuleOp module, bool fallback_enabled, return status; } tensorflow::Status TPUBridgeV1Compat(ModuleOp module, bool fallback_enabled) { + VLOG(1) << "TPU V1 Compat Bridge called stack trace is :" + << tensorflow::CurrentStackTrace(); Status status = RunTFXLABridge(module, [](OpPassManager &pm) { CreateTPUBridgePipelineV1(pm); // Add set of passes to lower back to graph (from tf_executor). @@ -487,6 +492,8 @@ void CreateTFXLABridgePipeline(OpPassManager &pm) { tensorflow::Status RunTFXLABridge(ModuleOp module, llvm::StringRef module_name) { + VLOG(1) << "CPU/GPU Bridge called stack trace is :" + << tensorflow::CurrentStackTrace(); Status status = mlir::TFTPU::RunTFXLABridge( module, [](OpPassManager &pm) { From 0ed8e904f4cb58826e2bd861ae508e38f49d3ea8 Mon Sep 17 00:00:00 2001 From: Fiona Lang Date: Thu, 13 Jul 2023 12:17:30 -0700 Subject: [PATCH 264/376] Delete unused tensor_priority_test.py. PiperOrigin-RevId: 547880537 --- .../kernel_tests/tensor_priority_test.py | 84 ------------------- 1 file changed, 84 deletions(-) delete mode 100644 tensorflow/python/kernel_tests/tensor_priority_test.py diff --git a/tensorflow/python/kernel_tests/tensor_priority_test.py b/tensorflow/python/kernel_tests/tensor_priority_test.py deleted file mode 100644 index 4111fa080f592a..00000000000000 --- a/tensorflow/python/kernel_tests/tensor_priority_test.py +++ /dev/null @@ -1,84 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for the binary ops priority mechanism.""" -import numpy as np - -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor as tensor_lib -from tensorflow.python.framework import tensor_conversion_registry -from tensorflow.python.platform import test as test_lib - - -class TensorPriorityTest(test_lib.TestCase): - - def testSupportedRhsWithoutDelegation(self): - - class NumpyArraySubclass(np.ndarray): - pass - - supported_rhs_without_delegation = (3, 3.0, [1.0, 2.0], np.array( - [1.0, 2.0]), NumpyArraySubclass( - shape=(1, 2), buffer=np.array([1.0, 2.0])), - ops.convert_to_tensor([[1.0, 2.0]])) - for rhs in supported_rhs_without_delegation: - tensor = ops.convert_to_tensor([[10.0, 20.0]]) - res = tensor + rhs - self.assertIsInstance(res, tensor_lib.Tensor) - - def testUnsupportedRhsWithoutDelegation(self): - - class WithoutReverseAdd(object): - pass - - tensor = ops.convert_to_tensor([[10.0, 20.0]]) - rhs = WithoutReverseAdd() - with self.assertRaisesWithPredicateMatch( - TypeError, lambda e: "Expected float" in str(e)): - # pylint: disable=pointless-statement - tensor + rhs - - def testUnsupportedRhsWithDelegation(self): - - class WithReverseAdd(object): - - def __radd__(self, lhs): - return "Works!" - - tensor = ops.convert_to_tensor([[10.0, 20.0]]) - rhs = WithReverseAdd() - res = tensor + rhs - self.assertEqual(res, "Works!") - - def testFullDelegationControlUsingRegistry(self): - - class NumpyArraySubclass(np.ndarray): - - def __radd__(self, lhs): - return "Works!" - - def raise_to_delegate(value, dtype=None, name=None, as_ref=False): - del value, dtype, name, as_ref # Unused. - raise TypeError - - tensor_conversion_registry.register_tensor_conversion_function( - NumpyArraySubclass, raise_to_delegate, priority=0) - tensor = ops.convert_to_tensor([[10.0, 20.0]]) - rhs = NumpyArraySubclass(shape=(1, 2), buffer=np.array([1.0, 2.0])) - res = tensor + rhs - self.assertEqual(res, "Works!") - - -if __name__ == "__main__": - test_lib.main() From f8613978c410f46a0edf69d99589bdad9fcf6e00 Mon Sep 17 00:00:00 2001 From: Jean-Baptiste Lespiau Date: Thu, 13 Jul 2023 12:22:07 -0700 Subject: [PATCH 265/376] Fix the propagation of the errors through TF_ASSIGN_OR_RETURN. PiperOrigin-RevId: 547881677 --- tensorflow/tsl/platform/BUILD | 4 ++- tensorflow/tsl/platform/default/BUILD | 17 ++++++++++ .../tsl/platform/default/build_config.bzl | 1 + tensorflow/tsl/platform/default/statusor.h | 33 +++++++++++++++++++ tensorflow/tsl/platform/statusor.h | 23 ++++++------- tensorflow/tsl/platform/statusor_test.cc | 29 ++++++++++++++++ 6 files changed, 93 insertions(+), 14 deletions(-) create mode 100644 tensorflow/tsl/platform/default/statusor.h diff --git a/tensorflow/tsl/platform/BUILD b/tensorflow/tsl/platform/BUILD index dd5d7b3ace120c..fd61677798e865 100644 --- a/tensorflow/tsl/platform/BUILD +++ b/tensorflow/tsl/platform/BUILD @@ -328,13 +328,14 @@ cc_library( ":errors", ":logging", ":macros", + ":platform", ":status", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", - ], + ] + tf_platform_deps("statusor"), ) cc_library( @@ -732,6 +733,7 @@ filegroup( "setround.h", "snappy.h", "status.h", + "statusor.h", "tracing.h", "unbounded_work_queue.h", ], diff --git a/tensorflow/tsl/platform/default/BUILD b/tensorflow/tsl/platform/default/BUILD index d9cbafe7710aa7..cc59928cb393ba 100644 --- a/tensorflow/tsl/platform/default/BUILD +++ b/tensorflow/tsl/platform/default/BUILD @@ -564,6 +564,22 @@ cc_library( visibility = set_external_visibility(["//tensorflow:__subpackages__"]), ) +cc_library( + name = "statusor", + tags = [ + "manual", + "no_oss", + "nobuilder", + ], + textual_hdrs = ["statusor.h"], + visibility = set_external_visibility(["//tensorflow:__subpackages__"]), + deps = [ + "//tensorflow/tsl/platform:macros", + "//tensorflow/tsl/platform:status", + "@com_google_absl//absl/status:statusor", + ], +) + bzl_library( name = "cuda_build_defs_bzl", srcs = ["cuda_build_defs.bzl"], @@ -595,6 +611,7 @@ filegroup( "posix_file_system.h", "stacktrace.h", "status.h", + "statusor.h", "tracing_impl.h", "//tensorflow/tsl/platform/profile_utils:cpu_utils.h", "//tensorflow/tsl/platform/profile_utils:i_cpu_utils_helper.h", diff --git a/tensorflow/tsl/platform/default/build_config.bzl b/tensorflow/tsl/platform/default/build_config.bzl index 7863928644e66e..d8d4ee476b3614 100644 --- a/tensorflow/tsl/platform/default/build_config.bzl +++ b/tensorflow/tsl/platform/default/build_config.bzl @@ -663,6 +663,7 @@ def tf_additional_lib_hdrs(): clean_dep("//tensorflow/tsl/platform/default:notification.h"), clean_dep("//tensorflow/tsl/platform/default:stacktrace.h"), clean_dep("//tensorflow/tsl/platform/default:status.h"), + clean_dep("//tensorflow/tsl/platform/default:statusor.h"), clean_dep("//tensorflow/tsl/platform/default:tracing_impl.h"), clean_dep("//tensorflow/tsl/platform/default:unbounded_work_queue.h"), ] + select({ diff --git a/tensorflow/tsl/platform/default/statusor.h b/tensorflow/tsl/platform/default/statusor.h new file mode 100644 index 00000000000000..300b4906f0f8db --- /dev/null +++ b/tensorflow/tsl/platform/default/statusor.h @@ -0,0 +1,33 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_TSL_PLATFORM_DEFAULT_STATUSOR_H_ +#define TENSORFLOW_TSL_PLATFORM_DEFAULT_STATUSOR_H_ + +#include "absl/status/statusor.h" +#include "tensorflow/tsl/platform/macros.h" +#include "tensorflow/tsl/platform/status.h" + +#define TF_ASSIGN_OR_RETURN(lhs, rexpr) \ + TF_ASSIGN_OR_RETURN_IMPL( \ + TF_STATUS_MACROS_CONCAT_NAME(_status_or_value, __COUNTER__), lhs, rexpr) + +#define TF_ASSIGN_OR_RETURN_IMPL(statusor, lhs, rexpr) \ + auto statusor = (rexpr); \ + if (TF_PREDICT_FALSE(!statusor.ok())) { \ + return statusor.status(); \ + } \ + lhs = std::move(statusor).value() + +#endif // TENSORFLOW_TSL_PLATFORM_DEFAULT_STATUSOR_H_ diff --git a/tensorflow/tsl/platform/statusor.h b/tensorflow/tsl/platform/statusor.h index 34bf3e38d20e6e..cf7a95d45a7ec8 100644 --- a/tensorflow/tsl/platform/statusor.h +++ b/tensorflow/tsl/platform/statusor.h @@ -72,12 +72,22 @@ limitations under the License. #include "absl/status/statusor.h" #include "tensorflow/tsl/platform/errors.h" #include "tensorflow/tsl/platform/macros.h" +#include "tensorflow/tsl/platform/platform.h" #include "tensorflow/tsl/platform/status.h" +// Include appropriate platform-dependent `TF_ASSIGN_OR_RETURN`. +#if defined(PLATFORM_GOOGLE) +#include "tensorflow/tsl/platform/google/statusor.h" // IWYU pragma: export +#else +#include "tensorflow/tsl/platform/default/statusor.h" // IWYU pragma: export +#endif + namespace tsl { using absl::StatusOr; +} // namespace tsl + #define TF_ASSERT_OK_AND_ASSIGN(lhs, rexpr) \ TF_ASSERT_OK_AND_ASSIGN_IMPL( \ TF_STATUS_MACROS_CONCAT_NAME(_status_or_value, __COUNTER__), lhs, \ @@ -91,17 +101,4 @@ using absl::StatusOr; #define TF_STATUS_MACROS_CONCAT_NAME(x, y) TF_STATUS_MACROS_CONCAT_IMPL(x, y) #define TF_STATUS_MACROS_CONCAT_IMPL(x, y) x##y -#define TF_ASSIGN_OR_RETURN(lhs, rexpr) \ - TF_ASSIGN_OR_RETURN_IMPL( \ - TF_STATUS_MACROS_CONCAT_NAME(_status_or_value, __COUNTER__), lhs, rexpr) - -#define TF_ASSIGN_OR_RETURN_IMPL(statusor, lhs, rexpr) \ - auto statusor = (rexpr); \ - if (TF_PREDICT_FALSE(!statusor.ok())) { \ - return statusor.status(); \ - } \ - lhs = std::move(statusor).value() - -} // namespace tsl - #endif // TENSORFLOW_TSL_PLATFORM_STATUSOR_H_ diff --git a/tensorflow/tsl/platform/statusor_test.cc b/tensorflow/tsl/platform/statusor_test.cc index 9dca858fa2f854..c6eb12c782a5e0 100644 --- a/tensorflow/tsl/platform/statusor_test.cc +++ b/tensorflow/tsl/platform/statusor_test.cc @@ -697,5 +697,34 @@ void BM_StatusOrFactoryFailLongMsg(::testing::benchmark::State& state) { } BENCHMARK(BM_StatusOrFactoryFailLongMsg); +#if defined(PLATFORM_GOOGLE) + +StatusOr GetError() { + return absl::InvalidArgumentError("An invalid argument error"); +} + +StatusOr PropagateError() { + TF_ASSIGN_OR_RETURN(int a, GetError()); + return a; +} + +StatusOr PropagateError2() { + TF_ASSIGN_OR_RETURN(int a, PropagateError()); + return a; +} + +TEST(Status, StackTracePropagation) { + StatusOr s = PropagateError2(); + auto sources = s.status().GetSourceLocations(); + ASSERT_EQ(sources.size(), 3); + + for (int i = 0; i < 3; ++i) { + ASSERT_EQ(sources[i].file_name(), + "third_party/tensorflow/tsl/platform/statusor_test.cc"); + } +} + +#endif + } // namespace } // namespace tsl From 24c79b6f902bf6554c356c32744f540cffa57aa6 Mon Sep 17 00:00:00 2001 From: James Mullenbach Date: Thu, 13 Jul 2023 12:27:35 -0700 Subject: [PATCH 266/376] Log ParameterServerStrategy variable placements at a higher verbosity when an env var is passed PiperOrigin-RevId: 547883350 --- .../python/distribute/parameter_server_strategy_v2.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/distribute/parameter_server_strategy_v2.py b/tensorflow/python/distribute/parameter_server_strategy_v2.py index 2c39613505a27c..8bf08b2dd8efdc 100644 --- a/tensorflow/python/distribute/parameter_server_strategy_v2.py +++ b/tensorflow/python/distribute/parameter_server_strategy_v2.py @@ -824,7 +824,11 @@ def _create_variable_round_robin(self, next_creator, **kwargs): with ops.device("/job:ps/task:%d/device:CPU:0" % (self._variable_count % self._num_ps)): var = next_creator(**kwargs) - logging.debug( + log_method = ( + logging.info if os.getenv("TF_PSS_VERBOSE_VARIABLE_PLACEMENT") + else logging.debug + ) + log_method( "Creating variable (name:%s, shape:%r) on " "/job:ps/task:%d/device:CPU:0", var.name, var.shape, (self._variable_count % self._num_ps)) From f49705df15031903ea40136ea3ea000e0db660e2 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 13 Jul 2023 12:32:14 -0700 Subject: [PATCH 267/376] Fix typo that causes CPU kernel to be registered twice. PiperOrigin-RevId: 547884686 --- tensorflow/core/kernels/stochastic_cast_op.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/core/kernels/stochastic_cast_op.cc b/tensorflow/core/kernels/stochastic_cast_op.cc index 626a00894da311..ba2954760610a6 100644 --- a/tensorflow/core/kernels/stochastic_cast_op.cc +++ b/tensorflow/core/kernels/stochastic_cast_op.cc @@ -91,9 +91,9 @@ REGISTER_CAST_TO_INT_CPU_KERNEL(double, int16); REGISTER_CAST_TO_INT_CPU_KERNEL(double, int32); #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -REGISTER_CAST_TO_INT_CPU_KERNEL(half, int8); -REGISTER_CAST_TO_INT_CPU_KERNEL(half, int16); -REGISTER_CAST_TO_INT_CPU_KERNEL(half, int32); +REGISTER_CAST_TO_INT_GPU_KERNEL(half, int8); +REGISTER_CAST_TO_INT_GPU_KERNEL(half, int16); +REGISTER_CAST_TO_INT_GPU_KERNEL(half, int32); REGISTER_CAST_TO_INT_GPU_KERNEL(bfloat16, int8); REGISTER_CAST_TO_INT_GPU_KERNEL(bfloat16, int16); From 253dab848e233368604de84c1b51521dcf4b557d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 13 Jul 2023 12:36:34 -0700 Subject: [PATCH 268/376] Build changes for AArch64. PiperOrigin-RevId: 547886019 --- tensorflow/compiler/tf2tensorrt/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/compiler/tf2tensorrt/BUILD b/tensorflow/compiler/tf2tensorrt/BUILD index 24a6dc43bd1096..96e07f337dee90 100644 --- a/tensorflow/compiler/tf2tensorrt/BUILD +++ b/tensorflow/compiler/tf2tensorrt/BUILD @@ -85,6 +85,7 @@ alias( "@local_config_tensorrt//:use_static_tensorrt": "@local_config_tensorrt//:tensorrt", "//conditions:default": ":tensorrt_stub", }), + visibility = ["//visibility:private"], ) tf_cuda_cc_test( From 182bec854f12b6ee3e9f54eff9df54ba399dc013 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 13 Jul 2023 12:54:27 -0700 Subject: [PATCH 269/376] Pattern to fuse/fold the reshape ops around TFL_BatchMatMulOp Python code to test the numerical equivalence- ``` import tensorflow as tf import numpy as np a = tf.random.uniform(shape=[2,3,4], maxval=100, dtype=tf.int32) b = tf.random.uniform(shape=[4,5], maxval=100, dtype=tf.int32) r1 = tf.reshape(a, [6,4]) bmm1 = tf.matmul(r1, b) r2 = tf.reshape(bmm1, [2,3,5]) bmm2 = tf.matmul(a, b) print(bmm2==r2) ``` PiperOrigin-RevId: 547891798 --- .../compiler/mlir/lite/tests/optimize.mlir | 62 +++++++++++++++++++ .../compiler/mlir/lite/transforms/optimize.cc | 16 +++++ .../mlir/lite/transforms/optimize_patterns.td | 55 ++++++++++++++++ 3 files changed, 133 insertions(+) diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir index 4993a8babfbb79..2f1561855e9d4e 100644 --- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir @@ -602,6 +602,68 @@ func.func @FuseFullyConnectedAddWithUnfusableRhs(%arg0: tensor<4x37xf32>, %arg1: // CHECK: return %[[add_result]] } +// CHECK-LABEL: @FuseReshapeAroundBMMLHS +func.func @FuseReshapeAroundBMMLHS(%arg0: tensor<6x5x1024xf32>) -> tensor<6x5x8192xf32> { + %cst = arith.constant dense_resource<__elided__> : tensor<1024x8192xf32> + %cst_0 = arith.constant dense_resource<__elided__> : tensor<3xi32> + %cst_1 = arith.constant dense_resource<__elided__> : tensor<2xi32> + %0 = "tfl.reshape"(%arg0, %cst_1) : (tensor<6x5x1024xf32>, tensor<2xi32>) -> tensor<30x1024xf32> + %1 = "tfl.batch_matmul"(%0, %cst) {adj_x = false, adj_y = false} : (tensor<30x1024xf32>, tensor<1024x8192xf32>) -> tensor<30x8192xf32> + %2 = "tfl.reshape"(%1, %cst_0) : (tensor<30x8192xf32>, tensor<3xi32>) -> tensor<6x5x8192xf32> + return %2 : tensor<6x5x8192xf32> + // CHECK: %cst = arith.constant dense_resource<__elided__> : tensor<1024x8192xf32> + // CHECK: %0 = "tfl.batch_matmul"(%arg0, %cst) {adj_x = false, adj_y = false} : (tensor<6x5x1024xf32>, tensor<1024x8192xf32>) -> tensor<6x5x8192xf32> + // CHECK: return %0 : tensor<6x5x8192xf32> +} + +// CHECK-LABEL: @FuseReshapeAroundBMMNagativeTest +func.func @FuseReshapeAroundBMMNagativeTest(%arg0: tensor<5x4x1x1024xf32>, %arg1: tensor<5x1024x8192xf32>) -> tensor<5x4x1x8192xf32> { + %cst = arith.constant dense_resource<__elided__> : tensor<3xi32> + %cst_0 = arith.constant dense_resource<__elided__> : tensor<4xi32> + %0 = "tfl.reshape"(%arg0, %cst) : (tensor<5x4x1x1024xf32>, tensor<3xi32>) -> tensor<5x4x1024xf32> + %1 = "tfl.batch_matmul"(%0, %arg1) {adj_x = false, adj_y = false} : (tensor<5x4x1024xf32>, tensor<5x1024x8192xf32>) -> tensor<5x4x8192xf32> + %2 = "tfl.reshape"(%1, %cst_0) : (tensor<5x4x8192xf32>, tensor<4xi32>) -> tensor<5x4x1x8192xf32> + return %2 : tensor<5x4x1x8192xf32> + // CHECK: %cst = arith.constant dense_resource<__elided__> : tensor<3xi32> + // CHECK: %cst_0 = arith.constant dense_resource<__elided__> : tensor<4xi32> + // CHECK: %0 = "tfl.reshape"(%arg0, %cst) : (tensor<5x4x1x1024xf32>, tensor<3xi32>) -> tensor<5x4x1024xf32> + // CHECK: %1 = "tfl.batch_matmul"(%0, %arg1) {adj_x = false, adj_y = false} : (tensor<5x4x1024xf32>, tensor<5x1024x8192xf32>) -> tensor<5x4x8192xf32> + // CHECK: %2 = "tfl.reshape"(%1, %cst_0) : (tensor<5x4x8192xf32>, tensor<4xi32>) -> tensor<5x4x1x8192xf32> + // CHECK: return %2 : tensor<5x4x1x8192xf32> +} + +// CHECK-LABEL: @FuseReshapeAroundBMMNagativeTest2 +func.func @FuseReshapeAroundBMMNagativeTest2(%arg0: tensor<2x1536xf32>) -> tensor<2x768xf32> { + %cst = arith.constant dense_resource<__elided__> : tensor<3xi32> + %cst_0 = arith.constant dense_resource<__elided__> : tensor<2xi32> + %402 = "tfl.reshape"(%arg0, %cst) : (tensor<2x1536xf32>, tensor<3xi32>) -> tensor<2x12x128xf32> + %403 = "tfl.pseudo_qconst"() {qtype = tensor<128x64x!quant.uniform>, value = dense<9> : tensor<128x64xi8>} : () -> tensor<128x64x!quant.uniform> + %404 = "tfl.batch_matmul"(%402, %403) {adj_x = false, adj_y = false, asymmetric_quantize_inputs = true} : (tensor<2x12x128xf32>, tensor<128x64x!quant.uniform>) -> tensor<2x12x64xf32> + %405 = "tfl.reshape"(%404, %cst_0) : (tensor<2x12x64xf32>, tensor<2xi32>) -> tensor<2x768xf32> + return %405 : tensor<2x768xf32> + // CHECK: %cst = arith.constant dense_resource<__elided__> : tensor<3xi32> + // CHECK: %cst_0 = arith.constant dense_resource<__elided__> : tensor<2xi32> + // CHECK: %0 = "tfl.reshape"(%arg0, %cst) : (tensor<2x1536xf32>, tensor<3xi32>) -> tensor<2x12x128xf32> + // CHECK: %1 = "tfl.pseudo_qconst"() {qtype = tensor<128x64x!quant.uniform>, value = dense<9> : tensor<128x64xi8>} : () -> tensor<128x64x!quant.uniform> + // CHECK: %2 = "tfl.batch_matmul"(%0, %1) {adj_x = false, adj_y = false, asymmetric_quantize_inputs = true} : (tensor<2x12x128xf32>, tensor<128x64x!quant.uniform>) -> tensor<2x12x64xf32> + // CHECK: %3 = "tfl.reshape"(%2, %cst_0) : (tensor<2x12x64xf32>, tensor<2xi32>) -> tensor<2x768xf32> + // CHECK: return %3 : tensor<2x768xf32> +} + +// CHECK-LABEL: @FuseReshapeAroundBMMRHS +func.func @FuseReshapeAroundBMMRHS(%arg0: tensor<1x3x6x5x1024xf32>) -> tensor<1x3x6x5x8192xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "inputs", outputs = "Identity_1"}} { + %cst = arith.constant dense_resource<__elided__> : tensor<1x1024x8192xf32> + %cst_0 = arith.constant dense_resource<__elided__> : tensor<5xi32> + %cst_1 = arith.constant dense_resource<__elided__> : tensor<3xi32> + %0 = "tfl.reshape"(%arg0, %cst_1) : (tensor<1x3x6x5x1024xf32>, tensor<3xi32>) -> tensor<1x90x1024xf32> + %1 = "tfl.batch_matmul"(%0, %cst) {adj_x = false, adj_y = false} : (tensor<1x90x1024xf32>, tensor<1x1024x8192xf32>) -> tensor<1x90x8192xf32> + %2 = "tfl.reshape"(%1, %cst_0) : (tensor<1x90x8192xf32>, tensor<5xi32>) -> tensor<1x3x6x5x8192xf32> + return %2 : tensor<1x3x6x5x8192xf32> + // CHECK: %cst = arith.constant dense_resource<__elided__> : tensor<1x1024x8192xf32> + // CHECK: %0 = "tfl.batch_matmul"(%arg0, %cst) {adj_x = false, adj_y = false} : (tensor<1x3x6x5x1024xf32>, tensor<1x1024x8192xf32>) -> tensor<1x3x6x5x8192xf32> + // CHECK: return %0 : tensor<1x3x6x5x8192xf32> +} + // CHECK-LABEL: @FuseFullyConnectedReshapeAddConst // FOLD-LABEL: @FuseFullyConnectedReshapeAddConst func.func @FuseFullyConnectedReshapeAddConst(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> { diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/transforms/optimize.cc index 1c9d7bbb002d2c..03a162f98af533 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize.cc @@ -123,6 +123,22 @@ class OptimizePass : public impl::OptimizePassBase { void runOnOperation() override; }; +// Return true if the product of dimension values of a subsection of the tensor +// is equal to the non-contracting dimension after a reshape +bool BroadcastDimsProductEqual(Value input, Value output, + size_t agg_start_idx) { + ArrayRef input_shape = input.getType().cast().getShape(); + ArrayRef output_shape = + output.getType().cast().getShape(); + + int64_t agg_value = 1; + for (size_t i = agg_start_idx; i < input_shape.size() - 1; ++i) { + agg_value *= input_shape[i]; + } + + return (agg_value == output_shape[agg_start_idx]); +} + // Returns whether the given type `a` is broadcast-compatible with `b`. bool IsBroadcastableElementsAttrAndType(Type a, Type b) { return OpTrait::util::getBroadcastedType(a, b) != Type(); diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td index 66d625e970b31c..1a826cd75f3020 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td @@ -1430,3 +1430,58 @@ def FuseLeakyReluConst : Pat< (HasOneUse $geq_out), (HasOneUse $mul_out), ]>; + +// Return true if the product of dimension values of a subsection of the tensor +// is equal to the non-contracting dimension after a reshape +class BroadcastDimsProductEqual : Constraint>; + +// Returns true if the dimensions of a subsection of two tensors is equal +// and the subsections are not empty +class AreTensorSubSectionShapesEqual : Constraint().getShape()" + ".drop_back("#skip_last#").drop_front("#skip_first#") ==" + "$1.getType().dyn_cast().getShape()" + ".drop_back("#skip_last#").drop_front("#skip_first#"))" + "&& !$0.getType().dyn_cast().getShape()" + ".drop_back("#skip_last#").drop_front("#skip_first#").empty()">>; + +// Returns true if the broadcast dimension of a tensor is [1] +// here- broadcast dimension is first prefix dimension +// excluding the last two dimensions +def IsBroadcastDimEqualToOne : Constraint().getShape()[0] == 1">>; + +// Pattern to fuse/fold the reshape ops around TFL_BatchMatMulOp +// This pattern is applied when the rank of rhs is 2 +// which means it has empty broadcast dimensions +def FuseReshapesAroundBatchMatMulLHS: Pat< + (TFL_ReshapeOp:$final_shape_change + (TFL_BatchMatMulOp:$bmm_tmp_output + (TFL_ReshapeOp:$initial_shape_change $input, (Arith_ConstantOp $s0)), + $rhs, $adj_x, $adj_y, $bool_attr), + (Arith_ConstantOp $s1)), + (TFL_BatchMatMulOp $input, $rhs, $adj_x, $adj_y, $bool_attr), + [(HasRank<2> $rhs), + (HasRank<2> $initial_shape_change), + (BroadcastDimsProductEqual<0> $input, $initial_shape_change), + (BroadcastDimsProductEqual<0> $final_shape_change, $bmm_tmp_output), + (AreTensorSubSectionShapesEqual<0, 1> $input, $final_shape_change)]>; + +// Pattern to fuse/fold the reshape ops around TFL_BatchMatMulOp +// This pattern is applied when the rank of rhs is 3 +// and the broadcast dimension is [1] +def FuseReshapesAroundBatchMatMulLHS1: Pat< + (TFL_ReshapeOp:$final_shape_change + (TFL_BatchMatMulOp:$bmm_tmp_output + (TFL_ReshapeOp:$initial_shape_change $input, (Arith_ConstantOp $s0)), + $rhs, $adj_x, $adj_y, $bool_attr), + (Arith_ConstantOp $s1)), + (TFL_BatchMatMulOp $input, $rhs, $adj_x, $adj_y, $bool_attr), + [(HasRank<3> $rhs), + (HasRank<3> $initial_shape_change), + (IsBroadcastDimEqualToOne $rhs), + (IsBroadcastDimEqualToOne $input), + (BroadcastDimsProductEqual<1> $input, $initial_shape_change), + (BroadcastDimsProductEqual<1> $final_shape_change, $bmm_tmp_output), + (AreTensorSubSectionShapesEqual<1, 1> $input, $final_shape_change)]>; From 8f37b93813326215f54ea76fafd399dcec80b1dc Mon Sep 17 00:00:00 2001 From: Yang Chen Date: Thu, 13 Jul 2023 13:44:53 -0700 Subject: [PATCH 270/376] #tf-data-service Enable bufferedio for loading tf.data snapshots. PiperOrigin-RevId: 547908488 --- tensorflow/core/data/service/snapshot/BUILD | 1 + .../core/data/service/snapshot/snapshot_chunk_dataset_op.cc | 5 +++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/data/service/snapshot/BUILD b/tensorflow/core/data/service/snapshot/BUILD index cdb6e3d7ffb184..b4d5fa11433571 100644 --- a/tensorflow/core/data/service/snapshot/BUILD +++ b/tensorflow/core/data/service/snapshot/BUILD @@ -226,6 +226,7 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/data:name_utils", "//tensorflow/core/data:snapshot_utils", + "//tensorflow/core/data:utils", "//tensorflow/tsl/platform:env", "//tensorflow/tsl/platform:errors", "//tensorflow/tsl/platform:status", diff --git a/tensorflow/core/data/service/snapshot/snapshot_chunk_dataset_op.cc b/tensorflow/core/data/service/snapshot/snapshot_chunk_dataset_op.cc index 78d24dfcb9ff27..0b4c179560b7a2 100644 --- a/tensorflow/core/data/service/snapshot/snapshot_chunk_dataset_op.cc +++ b/tensorflow/core/data/service/snapshot/snapshot_chunk_dataset_op.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/data/name_utils.h" #include "tensorflow/core/data/snapshot_utils.h" +#include "tensorflow/core/data/utils.h" #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -116,8 +117,8 @@ class SnapshotChunkDatasetOp::Dataset : public DatasetBase { Status Initialize(IteratorContext* ctx) override { reader_ = std::make_unique( - dataset()->chunk_file_, dataset()->compression_, dataset()->dtypes_, - kTFRecordReaderOutputBufferSize); + TranslateFileName(dataset()->chunk_file_), dataset()->compression_, + dataset()->dtypes_, kTFRecordReaderOutputBufferSize); return reader_->Initialize(ctx->env()); } From 2c4a76e24531e94d1ed6ab88b434ad7ea7fdc130 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 13 Jul 2023 14:01:14 -0700 Subject: [PATCH 271/376] Adds check for Optional Tensors before increasing the reference count for graph outputs in arena_planner PiperOrigin-RevId: 547914023 --- tensorflow/lite/arena_planner.cc | 4 +++- tensorflow/lite/arena_planner_test.cc | 19 +++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/tensorflow/lite/arena_planner.cc b/tensorflow/lite/arena_planner.cc index 0284beb72a4134..549c447f7413a9 100644 --- a/tensorflow/lite/arena_planner.cc +++ b/tensorflow/lite/arena_planner.cc @@ -236,7 +236,9 @@ TfLiteStatus ArenaPlanner::PlanAllocations() { // artificially adding one to their ref-counts so they are never selected // for deallocation. for (int tensor_index : graph_info_->outputs()) { - ++refcounts_[tensor_index]; + if (tensor_index != kTfLiteOptionalTensor) { + ++refcounts_[tensor_index]; + } } // Variable tensors also should be ensured to be never overwritten and need to diff --git a/tensorflow/lite/arena_planner_test.cc b/tensorflow/lite/arena_planner_test.cc index 9cbbb61563f159..ded2345c93e8c7 100644 --- a/tensorflow/lite/arena_planner_test.cc +++ b/tensorflow/lite/arena_planner_test.cc @@ -709,6 +709,25 @@ TEST_F(ArenaPlannerTest, SimpleGraphWithOptionals) { EXPECT_EQ(GetOffset(2), GetOffsetAfter(4)); } +TEST_F(ArenaPlannerTest, SimpleGraphWithOptionalOutput) { + TestGraph graph({0, -1, 1}, + { + /* in, out, tmp */ + {{0, 1}, {2}, {}}, // First op + {{2, 0}, {4, 5}, {}}, // Second op + {{4, 5}, {3}, {}} // Third op, with optional + }, + {-1, 3}); + SetGraph(&graph); + Execute(0, graph.nodes().size() - 1); + + // Alloc(+) and dealloc(-) order: +0 +1 +2 +4 +5 -2 +3 -4 -5 + EXPECT_EQ(GetOffset(5), 12); + EXPECT_EQ(GetOffset(4), GetOffsetAfter(5)); + EXPECT_EQ(GetOffset(3), GetOffsetAfter(4)); + EXPECT_EQ(GetOffset(2), GetOffsetAfter(4)); +} + TEST_F(ArenaPlannerTest, SimpleGraphWithLargeTensor) { TestGraph graph({0, -1}, { From 7d2345dc687c95f7820ac77d3da21dea72dc6bb2 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 13 Jul 2023 14:13:06 -0700 Subject: [PATCH 272/376] Correct shader generation for different cases of the MUL operation. PiperOrigin-RevId: 547917603 --- .../lite/delegates/gpu/gl/kernels/mul.cc | 102 +++++++----- .../lite/delegates/gpu/gl/kernels/mul_test.cc | 151 +++++++++++++++--- 2 files changed, 189 insertions(+), 64 deletions(-) diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/mul.cc b/tensorflow/lite/delegates/gpu/gl/kernels/mul.cc index 3d21a0aee8e4c5..fb30986290b4b4 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/mul.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/mul.cc @@ -15,17 +15,12 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/gl/kernels/mul.h" -#include -#include -#include -#include #include #include #include #include #include -#include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "tensorflow/lite/delegates/gpu/common/convert.h" #include "tensorflow/lite/delegates/gpu/common/status.h" @@ -37,38 +32,48 @@ namespace gl { namespace { -bool IsApplyMaskSupported(const NodeShader::GenerationContext& ctx) { - if (ctx.input_shapes.size() != 2) return false; - - // [H, W, C] x [H, W, 0][0] - if (ctx.input_shapes[0][1] == ctx.input_shapes[1][1] && - ctx.input_shapes[0][2] == ctx.input_shapes[1][2] && - ctx.input_shapes[1][3] == 1) { - return true; +// Returns the coordinate to iterate over the second runtime tensor. +absl::Status GetCoordinate(const NodeShader::GenerationContext& ctx, int dim, + const std::string& default_coord, + std::string* coord) { + std::string result; + if (ctx.input_shapes[1][dim] == 1 && ctx.input_shapes[0][dim] != 1) { + result = "0"; + } else if (ctx.input_shapes[0][dim] == ctx.input_shapes[1][dim]) { + result = default_coord; + } else { + return absl::InvalidArgumentError( + absl::StrCat("Second runtime tensor dimension ", dim, + " must either match " + "first tensor's dimensions or be 1.")); } - - // [H, W, C] x [H, W, C] - if (ctx.input_shapes[0] == ctx.input_shapes[1]) return true; - - // [H, W, C] x [0, 0, C] - return ctx.input_shapes[1][1] == 1 && ctx.input_shapes[1][2] == 1 && - ctx.input_shapes[0][3] == ctx.input_shapes[1][3]; + *coord = result; + return absl::OkStatus(); } -absl::Status GenerateApplyMaskCode(const NodeShader::GenerationContext& ctx, - GeneratedCode* generated_code) { - std::string source = "value_0 = $input_data_0[gid.x, gid.y, gid.z]$ * "; - if (ctx.input_shapes[1][3] == 1) { - // [H, W, C] x [H, W, 0][0] - absl::StrAppend(&source, "$input_data_1[gid.x, gid.y, 0]$.x;"); - } else if (ctx.input_shapes[0][1] == ctx.input_shapes[1][1] && - ctx.input_shapes[0][2] == ctx.input_shapes[1][2]) { - // [H, W, C] x [H, W, C] - absl::StrAppend(&source, "$input_data_1[gid.x, gid.y, gid.z]$;"); - } else { - // [H, W, C] x [0, 0, C] - absl::StrAppend(&source, "$input_data_1[0, 0, gid.z]$;"); +absl::Status GenerateMultiplyRuntimeTensorCode( + const NodeShader::GenerationContext& ctx, GeneratedCode* generated_code) { + std::string x_coord, y_coord, z_coord; + RETURN_IF_ERROR( + GetCoordinate(ctx, /*dim=*/2, /*default_coord=*/"gid.x", &x_coord)); + RETURN_IF_ERROR( + GetCoordinate(ctx, /*dim=*/1, /*default_coord=*/"gid.y", &y_coord)); + RETURN_IF_ERROR( + GetCoordinate(ctx, /*dim=*/3, /*default_coord=*/"gid.z", &z_coord)); + + std::string source = + absl::StrCat("vec4 input1_value = $input_data_1[", x_coord, ", ", y_coord, + ", ", z_coord, "]$;"); + // Single channel mask support. Without this duplication, the rest of channels + // will be zeros, which will make the mul operation produce incorrect result. + if (ctx.input_shapes[1][3] == 1 && ctx.input_shapes[0][3] != 1) { + absl::StrAppend( + &source, + "\ninput1_value = vec4(input1_value.x, input1_value.x, input1_value.x, " + "input1_value.x);\n"); } + absl::StrAppend( + &source, "value_0 = $input_data_0[gid.x, gid.y, gid.z]$ * input1_value;"); *generated_code = { /*parameters=*/{}, @@ -83,7 +88,7 @@ absl::Status GenerateApplyMaskCode(const NodeShader::GenerationContext& ctx, return absl::OkStatus(); } -absl::Status GenerateMultiplyScalarCode( +absl::Status GenerateMultiplyConstantTensorCode( const NodeShader::GenerationContext& ctx, GeneratedCode* generated_code) { const auto& attr = std::any_cast(ctx.op_attr); @@ -123,6 +128,26 @@ absl::Status GenerateMultiplyScalarCode( } if (std::holds_alternative>(attr.param)) { + bool single_channel_mask = + std::get>(attr.param).shape.c == 1; + std::string source; + if (single_channel_mask) { + source = "vec4 const_val = $hwc_buffer[gid.x, gid.y, 0]$;"; + // Single channel mask support. Without this duplication, the rest of + // channels will be zeros, which will make the mul operation produce + // incorrect result. + if (ctx.input_shapes[0][3] != 1) { + absl::StrAppend( + &source, + "\nconst_val = vec4(const_val.x, const_val.x, const_val.x, " + "const_val.x);\n"); + } + } else { + source = "vec4 const_val = $hwc_buffer[gid.x, gid.y, gid.z]$;"; + } + + absl::StrAppend(&source, "value_0 *= const_val;"); + *generated_code = { /*parameters=*/{}, /*objects=*/ @@ -140,7 +165,8 @@ absl::Status GenerateMultiplyScalarCode( static_cast(ctx.input_shapes[0][1]), DivideRoundUp(static_cast(ctx.input_shapes[0][3]), 4)), /*workgroup=*/uint3(), - /*source_code=*/"value_0 *= $hwc_buffer[gid.x, gid.y, gid.z]$;", + /*source_code=*/ + std::move(source), /*input=*/IOStructure::AUTO, /*output=*/IOStructure::AUTO, }; @@ -154,10 +180,10 @@ class Multiply : public NodeShader { public: absl::Status GenerateCode(const GenerationContext& ctx, GeneratedCode* generated_code) const final { - if (IsApplyMaskSupported(ctx)) { - return GenerateApplyMaskCode(ctx, generated_code); + if (ctx.input_shapes.size() == 2) { + return GenerateMultiplyRuntimeTensorCode(ctx, generated_code); } else { - return GenerateMultiplyScalarCode(ctx, generated_code); + return GenerateMultiplyConstantTensorCode(ctx, generated_code); } } }; diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/mul_test.cc b/tensorflow/lite/delegates/gpu/gl/kernels/mul_test.cc index 3d931df45247f4..e8379610912c5b 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/mul_test.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/mul_test.cc @@ -31,27 +31,59 @@ namespace gpu { namespace gl { namespace { -TEST(MulTest, Scalar) { +TEST(MulTest, ConstantTensorMatchingShape) { TensorRef input; input.type = DataType::FLOAT32; input.ref = 0; - input.shape = BHWC(1, 2, 2, 1); + input.shape = BHWC(1, 1, 2, 2); TensorRef output; output.type = DataType::FLOAT32; output.ref = 1; - output.shape = BHWC(1, 2, 2, 1); + output.shape = input.shape; ElementwiseAttributes attr; - attr.param = 2.f; + Tensor tensor_3d; + tensor_3d.shape.h = input.shape.h; + tensor_3d.shape.w = input.shape.w; + tensor_3d.shape.c = input.shape.c; + tensor_3d.id = 2; + tensor_3d.data = {-2, 2, -3, 3}; + attr.param = std::move(tensor_3d); SingleOpModel model({ToString(OperationType::MUL), attr}, {input}, {output}); ASSERT_TRUE(model.PopulateTensor(0, {1, 2, 3, 4})); ASSERT_OK(model.Invoke(*NewMultiplyNodeShader())); - EXPECT_THAT(model.GetOutput(0), Pointwise(FloatNear(1e-6), {2, 4, 6, 8})); + EXPECT_THAT(model.GetOutput(0), Pointwise(FloatNear(1e-6), {-2, 4, -9, 12})); +} + +TEST(MulTest, ConstantTensorSingleChannel) { + TensorRef input; + input.type = DataType::FLOAT32; + input.ref = 0; + input.shape = BHWC(1, 1, 2, 2); + + TensorRef output; + output.type = DataType::FLOAT32; + output.ref = 1; + output.shape = input.shape; + + ElementwiseAttributes attr; + Tensor tensor_3d; + tensor_3d.shape.h = input.shape.h; + tensor_3d.shape.w = input.shape.w; + tensor_3d.shape.c = 1; + tensor_3d.id = 2; + tensor_3d.data = {-2, 2}; + attr.param = std::move(tensor_3d); + + SingleOpModel model({ToString(OperationType::MUL), attr}, {input}, {output}); + ASSERT_TRUE(model.PopulateTensor(0, {1, 2, 3, 4})); + ASSERT_OK(model.Invoke(*NewMultiplyNodeShader())); + EXPECT_THAT(model.GetOutput(0), Pointwise(FloatNear(1e-6), {-2, -4, 6, 8})); } -TEST(MulTest, Linear) { +TEST(MulTest, ConstantTensorLinear) { TensorRef input; input.type = DataType::FLOAT32; input.ref = 0; @@ -60,7 +92,7 @@ TEST(MulTest, Linear) { TensorRef output; output.type = DataType::FLOAT32; output.ref = 1; - output.shape = BHWC(1, 1, 2, 2); + output.shape = input.shape; ElementwiseAttributes attr; Tensor tensor; @@ -75,33 +107,76 @@ TEST(MulTest, Linear) { EXPECT_THAT(model.GetOutput(0), Pointwise(FloatNear(1e-6), {2, 6, 6, 12})); } -TEST(MulTest, ConstTensor3D) { +TEST(MulTest, ConstantTensorScalar) { TensorRef input; input.type = DataType::FLOAT32; input.ref = 0; - input.shape = BHWC(1, 1, 2, 2); + input.shape = BHWC(1, 2, 2, 1); TensorRef output; output.type = DataType::FLOAT32; output.ref = 1; - output.shape = BHWC(1, 1, 2, 2); + output.shape = input.shape; ElementwiseAttributes attr; - Tensor tensor_3d; - tensor_3d.shape.h = 1; - tensor_3d.shape.w = 2; - tensor_3d.shape.c = 2; - tensor_3d.id = 2; - tensor_3d.data = {-2, 2, -3, 3}; - attr.param = std::move(tensor_3d); + attr.param = 2.f; SingleOpModel model({ToString(OperationType::MUL), attr}, {input}, {output}); ASSERT_TRUE(model.PopulateTensor(0, {1, 2, 3, 4})); ASSERT_OK(model.Invoke(*NewMultiplyNodeShader())); - EXPECT_THAT(model.GetOutput(0), Pointwise(FloatNear(1e-6), {-2, 4, -9, 12})); + EXPECT_THAT(model.GetOutput(0), Pointwise(FloatNear(1e-6), {2, 4, 6, 8})); +} + +TEST(MulTest, RuntimeTensorMatchingShapeNonOnes) { + TensorRef input; + input.type = DataType::FLOAT32; + input.ref = 0; + input.shape = BHWC(1, 2, 2, 2); + + TensorRef mask; + mask.type = DataType::FLOAT32; + mask.ref = 1; + mask.shape = input.shape; + + TensorRef output; + output.type = DataType::FLOAT32; + output.ref = 2; + output.shape = input.shape; + + SingleOpModel model({ToString(OperationType::MUL), {}}, {input, mask}, + {output}); + ASSERT_TRUE(model.PopulateTensor(0, {1, 2, 3, 4, -1, -2, -3, -4})); + ASSERT_TRUE(model.PopulateTensor(1, {5, 6, 7, 8, 9, 10, 11, 12})); + ASSERT_OK(model.Invoke(*NewMultiplyNodeShader())); + EXPECT_THAT(model.GetOutput(0), + Pointwise(FloatNear(1e-6), {5, 12, 21, 32, -9, -20, -33, -48})); +} + +TEST(MulTest, RuntimeTensorMatchingShapeHeightOne) { + TensorRef input; + input.type = DataType::FLOAT32; + input.ref = 0; + input.shape = BHWC(1, 1, 2, 2); + + TensorRef mask; + mask.type = DataType::FLOAT32; + mask.ref = 1; + mask.shape = input.shape; + + TensorRef output; + output.type = DataType::FLOAT32; + output.ref = 2; + output.shape = input.shape; + + SingleOpModel model({ToString(OperationType::MUL), {}}, {input, mask}, + {output}); + ASSERT_TRUE(model.PopulateTensor(0, {1, 2, 3, 4})); + ASSERT_TRUE(model.PopulateTensor(1, {1, 2, 3, 4})); + ASSERT_OK(model.Invoke(*NewMultiplyNodeShader())); + EXPECT_THAT(model.GetOutput(0), Pointwise(FloatNear(1e-6), {1, 4, 9, 16})); } -TEST(MulTest, MaskChannel1) { +TEST(MulTest, RuntimeTensorSingleChannel) { TensorRef input; input.type = DataType::FLOAT32; input.ref = 0; @@ -110,12 +185,12 @@ TEST(MulTest, MaskChannel1) { TensorRef mask; mask.type = DataType::FLOAT32; mask.ref = 1; - mask.shape = BHWC(1, 1, 2, 1); + mask.shape = BHWC(1, input.shape.h, input.shape.w, 1); TensorRef output; output.type = DataType::FLOAT32; output.ref = 2; - output.shape = BHWC(1, 1, 2, 2); + output.shape = input.shape; SingleOpModel model({ToString(OperationType::MUL), {}}, {input, mask}, {output}); @@ -125,7 +200,7 @@ TEST(MulTest, MaskChannel1) { EXPECT_THAT(model.GetOutput(0), Pointwise(FloatNear(1e-6), {2, 4, 9, 12})); } -TEST(MulTest, MaskChannelEqualsToInputChannel) { +TEST(MulTest, RuntimeTensorLinear) { TensorRef input; input.type = DataType::FLOAT32; input.ref = 0; @@ -134,19 +209,43 @@ TEST(MulTest, MaskChannelEqualsToInputChannel) { TensorRef mask; mask.type = DataType::FLOAT32; mask.ref = 1; - mask.shape = BHWC(1, 1, 2, 2); + mask.shape = BHWC(1, 1, 1, input.shape.c); TensorRef output; output.type = DataType::FLOAT32; output.ref = 2; - output.shape = BHWC(1, 1, 2, 2); + output.shape = input.shape; SingleOpModel model({ToString(OperationType::MUL), {}}, {input, mask}, {output}); ASSERT_TRUE(model.PopulateTensor(0, {1, 2, 3, 4})); - ASSERT_TRUE(model.PopulateTensor(1, {1, 2, 3, 4})); + ASSERT_TRUE(model.PopulateTensor(1, {1, 2})); ASSERT_OK(model.Invoke(*NewMultiplyNodeShader())); - EXPECT_THAT(model.GetOutput(0), Pointwise(FloatNear(1e-6), {1, 4, 9, 16})); + EXPECT_THAT(model.GetOutput(0), Pointwise(FloatNear(1e-6), {1, 4, 3, 8})); +} + +TEST(MulTest, RuntimeTensorScalar) { + TensorRef input; + input.type = DataType::FLOAT32; + input.ref = 0; + input.shape = BHWC(1, 1, 2, 2); + + TensorRef mask; + mask.type = DataType::FLOAT32; + mask.ref = 1; + mask.shape = BHWC(1, 1, 1, 1); + + TensorRef output; + output.type = DataType::FLOAT32; + output.ref = 2; + output.shape = input.shape; + + SingleOpModel model({ToString(OperationType::MUL), {}}, {input, mask}, + {output}); + ASSERT_TRUE(model.PopulateTensor(0, {1, 2, 3, 4})); + ASSERT_TRUE(model.PopulateTensor(1, {5})); + ASSERT_OK(model.Invoke(*NewMultiplyNodeShader())); + EXPECT_THAT(model.GetOutput(0), Pointwise(FloatNear(1e-6), {5, 10, 15, 20})); } } // namespace From 8f04131a84f3677a4f81d88071c4b6f058af5c6c Mon Sep 17 00:00:00 2001 From: Adam Cogdell Date: Thu, 13 Jul 2023 14:22:32 -0700 Subject: [PATCH 273/376] Internal change only. PiperOrigin-RevId: 547920627 --- tensorflow/cc/saved_model/BUILD | 98 +++++++++++++++++++ tensorflow/cc/saved_model/fingerprinting.cc | 4 +- tensorflow/core/protobuf/fingerprint.proto | 1 + .../pywrap_saved_model_fingerprinting_test.py | 13 +++ 4 files changed, 114 insertions(+), 2 deletions(-) diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD index e1bc27d3edc7cb..27177124e80b0d 100644 --- a/tensorflow/cc/saved_model/BUILD +++ b/tensorflow/cc/saved_model/BUILD @@ -464,6 +464,7 @@ cc_library( ]) + if_android([ "//tensorflow/core:portable_tensorflow_lib_lite", ]) + if_google([ + ":fingerprinting_utils", "//tensorflow/tools/proto_splitter/cc:util", ]), alwayslink = True, @@ -486,6 +487,103 @@ cc_library( ]) + if_not_mobile(["//tensorflow/core:lib"]) + if_android(["//tensorflow/core:portable_tensorflow_lib_lite"]), ) +# copybara:uncomment_begin(google-only) +# +# cc_library( +# name = "fingerprinting_utils_impl", +# srcs = [ +# "fingerprinting_utils.cc", +# "fingerprinting_utils.h", +# ], +# visibility = [ +# "//tensorflow:__pkg__", +# ], +# deps = [ +# ":constants", +# "@com_google_absl//absl/status", +# "@com_google_absl//absl/status:statusor", +# "@com_google_absl//absl/strings", +# "//third_party/riegeli/bytes:file_reader", +# "//third_party/riegeli/records:record_reader", +# "//tensorflow/core:lib", +# "//tensorflow/core:protos_all_cc", +# "//tensorflow/core/graph/regularization:simple_delete", +# "//tensorflow/core/graph/regularization:util", +# "//tensorflow/core/util/tensor_bundle:naming", +# "//tensorflow/tools/proto_splitter:chunk_proto_cc", +# "//tensorflow/tools/proto_splitter:merge", +# "//tensorflow/tools/proto_splitter/cc:util", +# "//tensorflow/tsl/platform:protobuf", +# ], +# alwayslink = True, +# ) +# +# cc_library( +# name = "fingerprinting_utils", +# hdrs = ["fingerprinting_utils.h"], +# visibility = [ +# "//tensorflow/cc/saved_model:__subpackages__", +# ], +# deps = if_static([ +# ":fingerprinting_utils_impl", +# "@com_google_protobuf//:protobuf_headers", +# "@com_google_absl//absl/status", +# "@com_google_absl//absl/status:statusor", +# "@com_google_absl//absl/strings", +# "//tensorflow/core:protos_all_cc", +# "//tensorflow/tools/proto_splitter:chunk_proto_cc", +# "//tensorflow/tsl/platform:protobuf", +# "//third_party/riegeli/bytes:file_reader", +# "//third_party/riegeli/records:record_reader", +# "//tensorflow/core:lib", +# ]), +# ) +# +# tf_cc_test( +# name = "fingerprinting_utils_test", +# srcs = ["fingerprinting_utils_test.cc"], +# data = [ +# "//tensorflow/tools/proto_splitter/testdata:many-field.cpb", +# "//tensorflow/tools/proto_splitter/testdata:split-standard.cpb", +# ], +# deps = [ +# ":fingerprinting_utils", +# "@com_google_absl//absl/status", +# "@com_google_absl//absl/strings", +# "//third_party/protobuf", +# "//third_party/riegeli/bytes:file_reader", +# "//third_party/riegeli/records:record_reader", +# "//tensorflow/core:protos_all_cc", +# "//tensorflow/core/platform:errors", +# "//tensorflow/core/platform:path", +# "//tensorflow/core/platform:protobuf", +# "//tensorflow/core/platform:test", +# "//tensorflow/tools/proto_splitter:chunk_proto_cc", +# "//tensorflow/tools/proto_splitter/cc:util", +# "//tensorflow/tools/proto_splitter/testdata:test_message_proto_cc", +# "@com_google_googletest//:gtest_main", +# ], +# ) +# +# tf_cc_test( +# name = "fingerprinting_chunked_test", +# size = "small", +# srcs = ["fingerprinting_chunked_test.cc"], +# data = [ +# ":saved_model_fingerprinting_test_files", +# ":saved_model_test_files", +# ], +# deps = [ +# ":fingerprinting", +# "//tensorflow/core:protos_all_cc", +# "//tensorflow/core:test", +# "//tensorflow/core/platform:path", +# "@com_google_googletest//:gtest_main", +# ], +# ) +# +# copybara:uncomment_end + tf_cc_test( name = "fingerprinting_test", size = "small", diff --git a/tensorflow/cc/saved_model/fingerprinting.cc b/tensorflow/cc/saved_model/fingerprinting.cc index 72f23ed745b5e1..d8f91267483f6d 100644 --- a/tensorflow/cc/saved_model/fingerprinting.cc +++ b/tensorflow/cc/saved_model/fingerprinting.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "tensorflow/cc/saved_model/constants.h" +// Placeholder for protosplitter riegeli includes. #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/graph/regularization/simple_delete.h" @@ -178,8 +179,7 @@ absl::StatusOr CreateFingerprintDef( return CreateFingerprintDefPb(export_dir, absl::StrCat(prefix, ".pb")); - return absl::UnimplementedError( - "Chunked proto fingerprinting unimplemented."); + return absl::PermissionDeniedError("Chunked proto format is not available in OSS."); } absl::StatusOr ReadSavedModelFingerprint( diff --git a/tensorflow/core/protobuf/fingerprint.proto b/tensorflow/core/protobuf/fingerprint.proto index 837b9a04d61db0..6ac5307ebacab7 100644 --- a/tensorflow/core/protobuf/fingerprint.proto +++ b/tensorflow/core/protobuf/fingerprint.proto @@ -27,4 +27,5 @@ message FingerprintDef { uint64 checkpoint_hash = 5; // Version specification of the fingerprint. VersionDef version = 6; + // TODO(b/290068219): add USM version when GA } diff --git a/tensorflow/python/saved_model/pywrap_saved_model_fingerprinting_test.py b/tensorflow/python/saved_model/pywrap_saved_model_fingerprinting_test.py index 99a2802f631ad9..ea34b16b77a36c 100644 --- a/tensorflow/python/saved_model/pywrap_saved_model_fingerprinting_test.py +++ b/tensorflow/python/saved_model/pywrap_saved_model_fingerprinting_test.py @@ -89,6 +89,19 @@ def test_read_saved_model_singleprint_from_sm(self): "12074714563970609759", # saved_object_graph_hash ])) + def test_read_chunked_saved_model_fingerprint(self): + if is_oss: + self.skipTest("Experimental image format disabled in OSS.") + export_dir = test.test_src_dir_path( + "cc/saved_model/testdata/chunked_saved_model/chunked_model") + fingerprint = fingerprint_pb2.FingerprintDef().FromString( + pywrap_fingerprinting.CreateFingerprintDef(export_dir)) + self.assertGreater(fingerprint.saved_model_checksum, 0) + self.assertEqual(fingerprint.graph_def_program_hash, 906548630859202535) + self.assertEqual(fingerprint.signature_def_hash, 1043582354059066488) + self.assertEqual(fingerprint.saved_object_graph_hash, 11894619660760763927) + self.assertEqual(fingerprint.checkpoint_hash, 0) + if __name__ == "__main__": test.main() From 26d212e21039ff07cf6f22b0ba6f209d31cab628 Mon Sep 17 00:00:00 2001 From: Swachhand Lokhande Date: Thu, 13 Jul 2023 14:26:11 -0700 Subject: [PATCH 274/376] Delete DeviceCompiler when a new PjRtClient is created for DEVICE_GPU. Also change a python e2e test to use GPU device instead of XLA_GPU device. PiperOrigin-RevId: 547921819 --- tensorflow/compiler/jit/BUILD | 6 ---- tensorflow/core/common_runtime/gpu/BUILD | 3 -- .../core/common_runtime/gpu/gpu_device.cc | 32 ++----------------- .../python/compiler/xla/pjrt_compile_test.py | 8 ++++- 4 files changed, 10 insertions(+), 39 deletions(-) diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 5241b4a3c5e08b..ab84540ec8c683 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -612,7 +612,6 @@ cc_library( hdrs = ["xla_compile_util.h"], visibility = [ ":internal", - "//tensorflow/core/common_runtime/gpu:__pkg__", "//tensorflow/core/common_runtime/next_pluggable_device:__pkg__", ], deps = [ @@ -655,7 +654,6 @@ cc_library( copts = tf_copts(), visibility = [ ":internal", - "//tensorflow/core/common_runtime/gpu:__pkg__", "//tensorflow/core/common_runtime/next_pluggable_device:__pkg__", ], deps = [ @@ -1419,10 +1417,6 @@ cc_library( name = "device_compilation_profiler", srcs = ["device_compilation_profiler.cc"], hdrs = ["device_compilation_profiler.h"], - visibility = [ - ":internal", - "//tensorflow/core/common_runtime/gpu:__pkg__", - ], deps = [ ":xla_activity_listener", ":xla_activity_proto_cc", diff --git a/tensorflow/core/common_runtime/gpu/BUILD b/tensorflow/core/common_runtime/gpu/BUILD index c7d8d4bc121a6c..d343d9a9ad9eb9 100644 --- a/tensorflow/core/common_runtime/gpu/BUILD +++ b/tensorflow/core/common_runtime/gpu/BUILD @@ -201,9 +201,6 @@ tf_cuda_library( "//tensorflow/compiler/tf2xla:layout_util", "//tensorflow/compiler/jit:flags", "//tensorflow/compiler/jit:pjrt_device_context", - "//tensorflow/compiler/jit:device_compilation_profiler", - "//tensorflow/compiler/jit:device_compiler", - "//tensorflow/compiler/jit:xla_compile_util", "//tensorflow/compiler/xla/pjrt/gpu:gpu_helpers", "//tensorflow/compiler/xla/pjrt/gpu:se_gpu_pjrt_client", "//tensorflow/compiler/xla/stream_executor:tf_allocator_adapter", diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc index d13a73102ff260..fe1dbf05f75e72 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_device.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc @@ -78,16 +78,12 @@ limitations under the License. #include "tensorflow/core/platform/rocm.h" #endif #ifdef TF_GPU_USE_PJRT -#include "tensorflow/compiler/jit/device_compilation_profiler.h" -#include "tensorflow/compiler/jit/device_compiler.h" #include "tensorflow/compiler/jit/flags.h" -#include "tensorflow/compiler/jit/xla_compile_util.h" #include "tensorflow/compiler/xla/pjrt/gpu/gpu_helpers.h" #include "tensorflow/compiler/xla/pjrt/gpu/se_gpu_pjrt_client.h" #include "tensorflow/compiler/xla/pjrt/pjrt_client.h" #include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h" #include "tensorflow/compiler/xla/stream_executor/device_host_allocator.h" -#include "tensorflow/core/tfrt/common/global_state.h" #endif // TF_GPU_USE_PJRT #include "tensorflow/compiler/xla/stream_executor/gpu/gpu_stream.h" #include "tensorflow/compiler/xla/stream_executor/platform/dso_loader.h" @@ -116,21 +112,6 @@ limitations under the License. namespace tensorflow { namespace { -#ifdef TF_GPU_USE_PJRT -using PjRtDeviceCompiler = - DeviceCompiler; - -void DeleteDeviceCompiler(const DeviceType& device_type) { - ResourceMgr* rm = tfrt_global::GetTFGlobalResourceMgr(); - rm->Delete(rm->default_container(), - GetPjRtDeviceCompilerResourceName(device_type)) - .IgnoreError(); - rm->Delete( - rm->default_container(), - GetPjRtDeviceCompilationProfilerResourceName(device_type)) - .IgnoreError(); -} -#endif // TF_GPU_USE_PJRT // Returns priority for the given virtual GPU id from the session options. // Returns 0 if no virtual devices are specified. @@ -1774,22 +1755,15 @@ Status BaseGPUDeviceFactory::CreateDevices( /*should_stage_host_to_device_transfers=*/true, /*gpu_run_options=*/std::move(gpu_run_options)); - TF_RETURN_IF_ERROR(SetPjRtClientInTFGlobalResourceManager( - DeviceType(DEVICE_GPU), std::move(pjrt_client))); - // We don't forsee a realistic scenario where the PjRtClient is deleted and - // replaced by a new one, except in unit tests. However, if this does happen, - // the DeviceCompiler that stores the PjRtLoadedExecutables built by the old - // PjRtClient needs to be deleted. A new DeviceCompiler using the current - // PjRtClient will be created on-demand when compilation is requested (if one - // doesn't exist already). - DeleteDeviceCompiler(DeviceType(DEVICE_GPU)); + return SetPjRtClientInTFGlobalResourceManager(DeviceType(DEVICE_GPU), + std::move(pjrt_client)); #else TF_RETURN_IF_ERROR(CreateGPUDevice(options, name_prefix, tf_device_id, /*dev_locality=*/it->second, gpu_allocator, devices)); } -#endif // TF_GPU_USE_PJRT return OkStatus(); +#endif // TF_GPU_USE_PJRT } static string GetShortDeviceDescription( diff --git a/tensorflow/python/compiler/xla/pjrt_compile_test.py b/tensorflow/python/compiler/xla/pjrt_compile_test.py index ddfa81e1b9408b..31ef70b0fd2166 100644 --- a/tensorflow/python/compiler/xla/pjrt_compile_test.py +++ b/tensorflow/python/compiler/xla/pjrt_compile_test.py @@ -61,7 +61,13 @@ def bar(x, y): x.assign(y) y.assign_add([1.0, 1.0]) - with ops.device("/device:GPU:0"): + # Currently PjRt only supports compilation and execution for the XLA_GPU + # device to unblock development. Support for non-XLA devices (CPU/GPU/single + # core TPU) is going to be added soon, after which support for XLA_* devices + # will be dropped. + # TODO(b/255826209): Modify the test as we progress towards supporting + # non-XLA devices. + with ops.device("/device:XLA_GPU:0"): # Function call with scalars self.assertEqual(self.evaluate(foo(1, 2)), 4) From 29b94307a06058a587c91bc8d39b5e0003f0f2be Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 13 Jul 2023 14:27:32 -0700 Subject: [PATCH 275/376] Support custom dataclasses in TensorFlowTestCase.evaluate. PiperOrigin-RevId: 547922304 --- .../python/debug/wrappers/framework_test.py | 4 +-- tensorflow/python/framework/test_util.py | 7 ++-- tensorflow/python/framework/test_util_test.py | 36 +++++++++++++++++++ 3 files changed, 43 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/debug/wrappers/framework_test.py b/tensorflow/python/debug/wrappers/framework_test.py index a1492f3e5e1fd7..e14363c623a6b6 100644 --- a/tensorflow/python/debug/wrappers/framework_test.py +++ b/tensorflow/python/debug/wrappers/framework_test.py @@ -320,7 +320,7 @@ def testUsingWrappedSessionShouldWorkAsContextManager(self): with wrapper as sess: self.assertAllClose([[3.0], [4.0]], self._s) self.assertEqual(1, self._observer["on_run_start_count"]) - self.assertEqual(self._s, self._observer["run_fetches"]) + self.assertEqual([self._s], self._observer["run_fetches"]) self.assertEqual(1, self._observer["on_run_end_count"]) self.assertAllClose( @@ -337,7 +337,7 @@ def testUsingWrappedSessionShouldSupportEvalWithAsDefault(self): with wrapper.as_default(): foo = constant_op.constant(42, name="foo") self.assertEqual(42, self.evaluate(foo)) - self.assertEqual(foo, self._observer["run_fetches"]) + self.assertEqual([foo], self._observer["run_fetches"]) def testWrapperShouldSupportSessionClose(self): wrapper = TestDebugWrapperSession(self._sess, self._dump_root, diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index e9096a925eaf25..874bd544c773a2 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -2700,11 +2700,14 @@ def evaluate(self, tensors): return self._eval_helper(tensors) else: sess = ops.get_default_session() + flattened_tensors = nest.flatten(tensors) if sess is None: with self.test_session() as sess: - return sess.run(tensors) + flattened_results = sess.run(flattened_tensors) else: - return sess.run(tensors) + flattened_results = sess.run(flattened_tensors) + + return nest.pack_sequence_as(tensors, flattened_results) # pylint: disable=g-doc-return-or-yield @contextlib.contextmanager diff --git a/tensorflow/python/framework/test_util_test.py b/tensorflow/python/framework/test_util_test.py index 5ffd054c19e9c2..d882a81bea0221 100644 --- a/tensorflow/python/framework/test_util_test.py +++ b/tensorflow/python/framework/test_util_test.py @@ -16,6 +16,7 @@ import collections import copy +import dataclasses import random import sys import threading @@ -57,6 +58,26 @@ from tensorflow.python.util.protobuf import compare_test_pb2 +@dataclasses.dataclass +class MaskedTensor: + mask: bool + value: ops.Tensor + + def __tf_flatten__(self): + metadata = (self.mask,) + components = (self.value,) + return metadata, components + + @classmethod + def __tf_unflatten__(cls, metadata, components): + mask = metadata[0] + value = components[0] + return MaskedTensor(mask=mask, value=value) + + def __eq__(self, other): + return self.mask == other.mask and self.value == other.value + + class TestUtilTest(test_util.TensorFlowTestCase, parameterized.TestCase): def test_assert_ops_in_graph(self): @@ -910,6 +931,21 @@ def test_nested_tensors_evaluate(self): self.assertEqual(expected, self.evaluate(nested)) + @test_util.run_in_graph_and_eager_modes + def test_custom_dataclass_evaluate(self): + mt = MaskedTensor(mask=True, value=constant_op.constant([1])) + mt_val = self.evaluate(mt) + self.assertEqual(mt_val.mask, True) + self.assertAllEqual(mt_val.value, [1]) + + mt2 = MaskedTensor(mask=True, value=constant_op.constant([1])) + mt2_val = self.evaluate(mt2) + self.assertEqual(mt_val, mt2_val) + + mt3 = MaskedTensor(mask=True, value=constant_op.constant([2])) + mt3_val = self.evaluate(mt3) + self.assertNotEqual(mt_val, mt3_val) + def test_run_in_graph_and_eager_modes(self): l = [] def inc(self, with_brackets): From 06dc7c37abd22aabdb3a6663f687f9b0ab83e941 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Thu, 13 Jul 2023 14:51:12 -0700 Subject: [PATCH 276/376] [xla:gpu] NFC: Add graph exec id to cuda graphs logging This allows to track how many cuda graphs we create over the life time of the process and if we properly release underlying resources PiperOrigin-RevId: 547930117 --- .../xla/stream_executor/cuda/cuda_graph.cc | 17 ++++++++++++----- .../xla/stream_executor/cuda/cuda_graph.h | 16 ++++++++++++---- 2 files changed, 24 insertions(+), 9 deletions(-) diff --git a/tensorflow/compiler/xla/stream_executor/cuda/cuda_graph.cc b/tensorflow/compiler/xla/stream_executor/cuda/cuda_graph.cc index b4e9e3729090bc..c6ea5e27bf1f0b 100644 --- a/tensorflow/compiler/xla/stream_executor/cuda/cuda_graph.cc +++ b/tensorflow/compiler/xla/stream_executor/cuda/cuda_graph.cc @@ -16,7 +16,6 @@ limitations under the License. #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_graph.h" #include -#include #include "absl/strings/str_format.h" #include "third_party/gpus/cuda/include/cuda_runtime_api.h" @@ -45,6 +44,10 @@ std::atomic CudaGraphSupport::alive_cuda_graph_execs_; return allocated_cuda_graph_execs_.fetch_add(1, std::memory_order_relaxed); } +/*static*/ size_t CudaGraphSupport::NotifyGraphExecDestroyed() { + return alive_cuda_graph_execs_.fetch_sub(1, std::memory_order_relaxed) - 1; +} + /*static*/ size_t CudaGraphSupport::allocated_cuda_graph_execs() { return allocated_cuda_graph_execs_.load(std::memory_order_relaxed); } @@ -61,9 +64,6 @@ void CudaGraphSupport::DestroyGraph::operator()(cudaGraph_t graph) { void CudaGraphSupport::DestroyGraphExec::operator()(cudaGraphExec_t instance) { cudaError_t err = cudaGraphExecDestroy(instance); - alive_cuda_graph_execs_.fetch_sub(1, std::memory_order_relaxed); - VLOG(5) << "Destroy CUDA graph exec (remaining alive instances: " - << CudaGraphSupport::alive_cuda_graph_execs() << ")"; CHECK(err == cudaSuccess) << "Failed to destroy CUDA graph instance: " << cudaGetErrorString(err); } @@ -109,6 +109,13 @@ tsl::Status OwnedCudaGraphExec::Launch(stream_executor::Stream* stream) { return tsl::OkStatus(); } +OwnedCudaGraphExec::~OwnedCudaGraphExec() { + if (*this) // do not log for moved-from instances + VLOG(5) << "Destroy CUDA graph exec #" << id_ + << " (remaining alive instances: " + << CudaGraphSupport::NotifyGraphExecDestroyed() << ")"; +} + //===----------------------------------------------------------------------===// // CUDA Graph Helpers. //===----------------------------------------------------------------------===// @@ -196,7 +203,7 @@ tsl::StatusOr InstantiateCudaGraph(OwnedCudaGraph graph) { VLOG(5) << "Instantiated CUDA graph exec instance #" << id << " (alive instances: " << CudaGraphSupport::alive_cuda_graph_execs() << ")"; - return OwnedCudaGraphExec(exec); + return OwnedCudaGraphExec(id, exec); } tsl::StatusOr IsStreamCapturing(stream_executor::Stream* stream) { diff --git a/tensorflow/compiler/xla/stream_executor/cuda/cuda_graph.h b/tensorflow/compiler/xla/stream_executor/cuda/cuda_graph.h index 0b851440126dcc..ad56554c0ad300 100644 --- a/tensorflow/compiler/xla/stream_executor/cuda/cuda_graph.h +++ b/tensorflow/compiler/xla/stream_executor/cuda/cuda_graph.h @@ -18,7 +18,6 @@ limitations under the License. #include #include -#include #include #include "absl/functional/any_invocable.h" @@ -41,6 +40,7 @@ class CudaGraphSupport { }; static size_t NotifyGraphExecCreated(); + static size_t NotifyGraphExecDestroyed(); static size_t allocated_cuda_graph_execs(); static size_t alive_cuda_graph_execs(); @@ -67,11 +67,16 @@ class OwnedCudaGraph class OwnedCudaGraphExec : public std::unique_ptr, CudaGraphSupport::DestroyGraphExec> { - // Bring std::unique_ptr constructors in scope. - using std::unique_ptr, - CudaGraphSupport::DestroyGraphExec>::unique_ptr; + using Base = std::unique_ptr, + CudaGraphSupport::DestroyGraphExec>; public: + OwnedCudaGraphExec(uint64_t id, cudaGraphExec_t exec) : Base(exec), id_(id) {} + ~OwnedCudaGraphExec(); + + OwnedCudaGraphExec(OwnedCudaGraphExec&&) = default; + OwnedCudaGraphExec& operator=(OwnedCudaGraphExec&&) = default; + // Updates executable graph instance with a newly captured graph. Returns an // error if the new graph is not compatible (see `cudaGraphExecUpdate`). tsl::Status Update(OwnedCudaGraph graph); @@ -79,7 +84,10 @@ class OwnedCudaGraphExec // Launches captured graph on a given stream. tsl::Status Launch(stream_executor::Stream* stream); + uint64_t id() const { return id_; } + private: + uint64_t id_; uint64_t num_updates_ = 0; uint64_t num_launches_ = 0; }; From 6f7c908925b280fe96b1529d09b0f4dd5590326d Mon Sep 17 00:00:00 2001 From: Sizhi Tan Date: Thu, 13 Jul 2023 15:04:51 -0700 Subject: [PATCH 277/376] Explicitly disable use of tfrt for failing test. PiperOrigin-RevId: 547934340 --- tensorflow/compiler/tests/BUILD | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 5a887a5d238f35..634baa45b4b4ba 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -747,6 +747,10 @@ tf_xla_py_strict_test( name = "eager_test", size = "medium", srcs = ["eager_test.py"], + # copybara:uncomment_begin + # #TODO(b/287111047): Remove once the bug is fixed. + # disable_tpu_tfrt = True, + # copybara:uncomment_end enable_mlir_bridge = False, python_version = "PY3", tags = [ @@ -1783,6 +1787,10 @@ tf_xla_py_strict_test( name = "while_test", size = "small", srcs = ["while_test.py"], + # copybara:uncomment_begin + # #TODO(b/291130193): Remove once the bug is fixed. + # disable_tpu_tfrt = True, + # copybara:uncomment_end enable_mlir_bridge = False, python_version = "PY3", tags = [ From 15be3fbbf0fd306b2a47f9f53199e8ed12d0479a Mon Sep 17 00:00:00 2001 From: Anlun Xu Date: Thu, 13 Jul 2023 15:09:06 -0700 Subject: [PATCH 278/376] [xla:gpu] Improve stream assignment for concurrent regions Before this change we assigned each kernel in a concurrent region to a different borrowed stream. We can instead assign one of the kernel to the capture stream in order to perform less synchronization. For example, if we have two kernels in the concurrent region, we synchronize two borrowed streams: 1 / \ 2 3 \ / 1 After this change we only synchronize one stream: 1 |\ 1 2 |/ 1 PiperOrigin-RevId: 547935901 --- .../service/gpu/runtime/concurrent_region.cc | 21 ++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/runtime/concurrent_region.cc b/tensorflow/compiler/xla/service/gpu/runtime/concurrent_region.cc index 3fed1c6ed07687..fb11172d7f702f 100644 --- a/tensorflow/compiler/xla/service/gpu/runtime/concurrent_region.cc +++ b/tensorflow/compiler/xla/service/gpu/runtime/concurrent_region.cc @@ -43,14 +43,22 @@ ConcurrentRegionStatus::~ConcurrentRegionStatus() { DCHECK(!IsInConcurrentRegion()); } +// Assign a stream in a round-robin fashion. Either the capture stream or one of +// the borrowed streams is returned. se::Stream* ConcurrentRegionStatus::GetNextStream() { DCHECK(IsInConcurrentRegion()); if (borrowed_streams_.empty()) { return nullptr; } - int index = stream_index_ % borrowed_streams_.size(); + + int index = stream_index_ % (borrowed_streams_.size() + 1); stream_index_++; - return borrowed_streams_[index].get(); + + if (index == 0) { + return capture_stream_; + } + + return borrowed_streams_[index - 1].get(); } absl::Status ConcurrentRegionStatus::StartConcurrentRegion( @@ -68,10 +76,9 @@ absl::Status ConcurrentRegionStatus::StartConcurrentRegion( } } - // Switch borrowed streams into capture mode. If the number of kernel launches - // in the region is less than the number of borrowed streams, only synchronize - // enough streams to run the kernels. - for (int i = 0; i < std::min(size, num_borrowed_streams_); ++i) { + // Switch borrowed streams into capture mode. We only synchronize enough + // streams to run the kernels. + for (int i = 0; i < std::min(size - 1, num_borrowed_streams_); ++i) { borrowed_streams_[i]->ThenWaitFor(capture_stream); } @@ -84,7 +91,7 @@ void ConcurrentRegionStatus::EndConcurrentRegion() { DCHECK(IsInConcurrentRegion()); // Synchronize main capture stream with all borrowed streams in capture mode. - for (int i = 0; i < std::min(region_size_, num_borrowed_streams_); + for (int i = 0; i < std::min(region_size_ - 1, num_borrowed_streams_); ++i) { capture_stream_->ThenWaitFor(borrowed_streams_[i].get()); } From b94792aea42c35c07e0ce759f48aa47a6bf9ed1d Mon Sep 17 00:00:00 2001 From: Chuan He Date: Thu, 13 Jul 2023 15:18:04 -0700 Subject: [PATCH 279/376] Replaces absl::string_view to llvm::StringRef to make it compatible with Android build. PiperOrigin-RevId: 547939062 --- tensorflow/compiler/xla/python/refine_polymorphic_shapes.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/xla/python/refine_polymorphic_shapes.cc b/tensorflow/compiler/xla/python/refine_polymorphic_shapes.cc index f3179e7b8d1f5f..5381caa8d5dffe 100644 --- a/tensorflow/compiler/xla/python/refine_polymorphic_shapes.cc +++ b/tensorflow/compiler/xla/python/refine_polymorphic_shapes.cc @@ -41,8 +41,8 @@ namespace xla { namespace { -constexpr absl::string_view shapeAssertionName = "shape_assertion"; -constexpr absl::string_view errorMessageAttrName = "error_message"; +constexpr llvm::StringRef shapeAssertionName = "shape_assertion"; +constexpr llvm::StringRef errorMessageAttrName = "error_message"; // We bound the number of error_message_inputs for using llvm::formatv constexpr int maxErrorMessageInputs = 4; From 9335670d00c995294f8cb3eeb70ecc93fcae6bfe Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 13 Jul 2023 15:20:56 -0700 Subject: [PATCH 280/376] Do not run pjrt_c_api_gpu_test in debug mode due to failure. PiperOrigin-RevId: 547940227 --- tensorflow/compiler/xla/pjrt/c/BUILD | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/compiler/xla/pjrt/c/BUILD b/tensorflow/compiler/xla/pjrt/c/BUILD index d1531a0bf1c523..e83fce610ba5bf 100644 --- a/tensorflow/compiler/xla/pjrt/c/BUILD +++ b/tensorflow/compiler/xla/pjrt/c/BUILD @@ -109,7 +109,7 @@ cc_library( xla_cc_test( name = "pjrt_c_api_gpu_test", srcs = ["pjrt_c_api_gpu_test.cc"], - tags = tf_cuda_tests_tags(), + tags = tf_cuda_tests_tags() + ["nodebug"], # TODO(b/291073132): Test failing in debug mode. deps = [ ":pjrt_c_api_gpu", ":pjrt_c_api_hdrs", From 58d49ef60a1a0bce844fa996d142e1d5ae08d323 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 13 Jul 2023 15:32:55 -0700 Subject: [PATCH 281/376] Integrate LLVM at llvm/llvm-project@1936bb81aafd Updates LLVM usage to match [1936bb81aafd](https://github.com/llvm/llvm-project/commit/1936bb81aafd) PiperOrigin-RevId: 547944950 --- third_party/llvm/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index b9f904631a41c3..ddb4704d4b456c 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "a69b2e3d1c1a123e66df58116e5ca0e57e808307" - LLVM_SHA256 = "6a613ef7f464231b3b5c953095bd19c2a7a813e9b8df1e474b66ba542f04f87f" + LLVM_COMMIT = "1936bb81aafdbb3b4c9770a24fc77ba07669bd19" + LLVM_SHA256 = "7b519ebd1b17dd59b94e5836b431486ec3e2d020c7799c70ce9ee706d25d3c5d" tf_http_archive( name = name, From f71f31cc2a6ad2daf22b17991a545b5e75d1ddaf Mon Sep 17 00:00:00 2001 From: James Mullenbach Date: Thu, 13 Jul 2023 15:45:30 -0700 Subject: [PATCH 282/376] Make FunctionType.flat_inputs and .flat_captures thread safe. This prevents a race condition in ParameterServerStrategy when the same function is being dispatched in parallel. PiperOrigin-RevId: 547949685 --- tensorflow/core/function/polymorphism/function_type.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tensorflow/core/function/polymorphism/function_type.py b/tensorflow/core/function/polymorphism/function_type.py index 06c7e1047c44a9..033afdef6d9b0b 100644 --- a/tensorflow/core/function/polymorphism/function_type.py +++ b/tensorflow/core/function/polymorphism/function_type.py @@ -355,9 +355,10 @@ def placeholder_arguments( def flat_inputs(self) -> List[trace.TraceType]: """Flat tensor inputs accepted by this FunctionType.""" if not hasattr(self, "_cached_flat_inputs"): - self._cached_flat_inputs = [] + cached_flat_inputs = [] for p in self.parameters.values(): - self._cached_flat_inputs.extend(p.type_constraint._flatten()) # pylint: disable=protected-access + cached_flat_inputs.extend(p.type_constraint._flatten()) # pylint: disable=protected-access + self._cached_flat_inputs = cached_flat_inputs return self._cached_flat_inputs @@ -399,9 +400,10 @@ def unpack_inputs( def flat_captures(self) -> List[trace.TraceType]: """Flat tensor captures needed by this FunctionType.""" if not hasattr(self, "_cached_flat_captures"): - self._cached_flat_captures = [] + cached_flat_captures = [] for t in self.captures.values(): - self._cached_flat_captures.extend(t._flatten()) # pylint: disable=protected-access + cached_flat_captures.extend(t._flatten()) # pylint: disable=protected-access + self._cached_flat_captures = cached_flat_captures return self._cached_flat_captures From a2ab482bf4a14ddbc42494516ab56f4a0c7481d4 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 13 Jul 2023 15:46:44 -0700 Subject: [PATCH 283/376] Change IFRT's kPred DType to have 1-byte width instead of 1-bit. XLA's PRED DType has 1-byte, so making IFRT's kPred match this will prevent dtype-translation bugs. We can introduce a kOneBitPred DType later if needed. PiperOrigin-RevId: 547949971 --- tensorflow/compiler/xla/python/ifrt/dtype.cc | 2 +- tensorflow/compiler/xla/python/ifrt/dtype.h | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/compiler/xla/python/ifrt/dtype.cc b/tensorflow/compiler/xla/python/ifrt/dtype.cc index de04817559cdef..fe11b672449133 100644 --- a/tensorflow/compiler/xla/python/ifrt/dtype.cc +++ b/tensorflow/compiler/xla/python/ifrt/dtype.cc @@ -26,6 +26,7 @@ namespace ifrt { std::optional DType::byte_size() const { switch (kind_) { + case kPred: case kS8: case kU8: return 1; @@ -53,7 +54,6 @@ std::optional DType::byte_size() const { std::optional DType::bit_size() const { switch (kind_) { case kPred: - return 1; case kS8: case kU8: return 8; diff --git a/tensorflow/compiler/xla/python/ifrt/dtype.h b/tensorflow/compiler/xla/python/ifrt/dtype.h index 7888a479b09a48..f98e823c00f82f 100644 --- a/tensorflow/compiler/xla/python/ifrt/dtype.h +++ b/tensorflow/compiler/xla/python/ifrt/dtype.h @@ -97,8 +97,8 @@ class DType { bool operator!=(const DType& other) const { return kind_ != other.kind_; } // Returns the byte size of a single element of this DType. Returns - // std::nullopt if there is no fixed size or not aligned to a byte boundary - // (such as kPred). + // std::nullopt if not aligned to a byte boundary or there is no fixed size + // (such as kString). std::optional byte_size() const; // Returns the bit size of a single element of this DType. Returns From 3a8b4ffe6f1402b967709d2fc7120f808c5c198c Mon Sep 17 00:00:00 2001 From: T Coxon <97948946+tttc3@users.noreply.github.com> Date: Thu, 13 Jul 2023 15:46:59 -0700 Subject: [PATCH 284/376] PR #3980: [python:xla_extension] Handle unbounded recursion in cyclical PyTrees MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Imported from GitHub PR https://github.com/openxla/xla/pull/3980 Provides more graceful handling of an unbounded recursion error that can occur when attempting to flatten a PyTree with cyclical node references. Example from https://github.com/google/jax/issues/15711: ```Python import jax.tree_util as jtu a = [] a.append(a) jtu.tree_flatten(a) # “python” terminated by signal SIGSEGV (Address boundary error) ``` With this pull, the above snippet now returns the error message: ```Python RecursionError: maximum recursion depth exceeded in flatten; PyTree may have cyclical node references. ``` The maximum recursion depth before the error is throw, is controlled by the Python interpreter. Copybara import of the project: -- 73f9a28d3dd69025a9e90815183dc79e0dddcc0a by tttc3 : Handle unbounded recursion in cyclical PyTrees Merging this change closes #3980 PiperOrigin-RevId: 547950031 --- tensorflow/compiler/xla/python/pytree.cc | 10 ++++++++++ tensorflow/compiler/xla/python/xla_client.py | 2 +- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/xla/python/pytree.cc b/tensorflow/compiler/xla/python/pytree.cc index bae3fc0083e294..2ae556a19b53b7 100644 --- a/tensorflow/compiler/xla/python/pytree.cc +++ b/tensorflow/compiler/xla/python/pytree.cc @@ -185,7 +185,12 @@ void PyTreeDef::FlattenIntoImpl( } else { node.kind = GetKind(handle, &node.custom); auto recurse = [this, &leaf_predicate, &leaves](py::handle child) { + if (Py_EnterRecursiveCall( + " in flatten; PyTree may have cyclical node references.")) { + return; + } FlattenInto(child, leaves, leaf_predicate); + Py_LeaveRecursiveCall(); }; switch (node.kind) { case PyTreeKind::kNone: @@ -265,6 +270,11 @@ PyTreeDef::Flatten(py::handle x, std::optional leaf_predicate) { std::vector leaves; auto tree = std::make_unique(); tree->FlattenInto(x, leaves, leaf_predicate); + // Handle the unbounded recursion error for trees with cyclical node + // references. + if (PyErr_Occurred()) { + throw py::error_already_set(); + } return std::make_pair(std::move(leaves), std::move(tree)); } diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index 61a3cd8e0e7fd4..c52fa3c3f9d74f 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -44,7 +44,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. -_version = 166 +_version = 167 # Version number for MLIR:Python components. mlir_api_version = 52 From 8708a23539563326a2a124e8f3557682edf07c60 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 13 Jul 2023 15:49:46 -0700 Subject: [PATCH 285/376] Add dataclass support in tf.data data processing PiperOrigin-RevId: 547950737 --- .../data/kernel_tests/from_tensors_test.py | 25 ++- .../python/data/kernel_tests/map_test.py | 166 ++++++++++++++++++ .../python/data/kernel_tests/zip_test.py | 35 +++- tensorflow/python/data/util/BUILD | 3 + tensorflow/python/data/util/nest_test.py | 104 ++++++++++- tensorflow/python/data/util/structure.py | 7 + tensorflow/python/data/util/structure_test.py | 37 +++- 7 files changed, 373 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/data/kernel_tests/from_tensors_test.py b/tensorflow/python/data/kernel_tests/from_tensors_test.py index 7584d6bd2357c1..aa564d371421b7 100644 --- a/tensorflow/python/data/kernel_tests/from_tensors_test.py +++ b/tensorflow/python/data/kernel_tests/from_tensors_test.py @@ -14,8 +14,9 @@ # ============================================================================== """Tests for `tf.data.Dataset.from_tensors().""" import collections -from absl.testing import parameterized +import dataclasses +from absl.testing import parameterized import numpy as np from tensorflow.core.protobuf import config_pb2 @@ -45,6 +46,22 @@ attr = None +@dataclasses.dataclass +class MaskedTensor: + mask: bool + value: np.ndarray + + def __tf_flatten__(self): + metadata = (self.mask,) + components = (self.value,) + return metadata, components + + def __tf_unflatten__(self, metadata, components): + mask = metadata[0] + value = components[0] + return MaskedTensor(mask=mask, value=value) + + class FromTensorsTest(test_base.DatasetTestBase, parameterized.TestCase): @combinations.generate(test_base.default_test_combinations()) @@ -151,6 +168,12 @@ class Foo: dataset = dataset_ops.Dataset.from_tensors(element) self.assertDatasetProduces(dataset, expected_output=[element]) + @combinations.generate(test_base.default_test_combinations()) + def testFromTensorsDataclass(self): + mt = MaskedTensor(mask=True, value=np.array([1])) + dataset = dataset_ops.Dataset.from_tensors(mt) + self.assertDatasetProduces(dataset, expected_output=[mt]) + @combinations.generate(test_base.default_test_combinations()) def testFromTensorsMixedRagged(self): components = (np.array(1), np.array([1, 2, 3]), np.array(37.0), diff --git a/tensorflow/python/data/kernel_tests/map_test.py b/tensorflow/python/data/kernel_tests/map_test.py index 608764cff7ae6a..a949be7b1893d2 100644 --- a/tensorflow/python/data/kernel_tests/map_test.py +++ b/tensorflow/python/data/kernel_tests/map_test.py @@ -14,6 +14,7 @@ # ============================================================================== """Tests for `tf.data.Dataset.map()`.""" import collections +import dataclasses import functools import threading import time @@ -135,6 +136,59 @@ def __init__(self): pass +@dataclasses.dataclass +class MyDataclass: + value1: ops.Tensor + value2: ops.Tensor + + def __tf_flatten__(self): + metadata = tuple() + components = (self.value1, self.value2) + return metadata, components + + @classmethod + def __tf_unflatten__(cls, metadata, components): + del metadata + return cls(value1=components[0], value2=components[1]) + + +@dataclasses.dataclass +class MaskedTensor: + mask: bool + value: ops.Tensor + + def __tf_flatten__(self): + metadata = (self.mask,) + components = (self.value,) + return metadata, components + + @classmethod + def __tf_unflatten__(cls, metadata, components): + mask = metadata[0] + value = components[0] + return MaskedTensor(mask=mask, value=value) + + +@dataclasses.dataclass +class NestedMaskedTensor: + mask: bool + value: MaskedTensor + + def __tf_flatten__(self): + metadata = (self.mask,) + components = (self.value,) + return metadata, components + + @classmethod + def __tf_unflatten__(cls, metadata, components): + mask = metadata[0] + value = components[0] + return NestedMaskedTensor(mask=mask, value=value) + + def __eq__(self, other): + return self.mask == other.mask and self.value == other.value + + class MapTest(test_base.DatasetTestBase, parameterized.TestCase): def _map_dataset_factory(self, components, apply_map, count): @@ -547,6 +601,118 @@ def testMapDict(self, apply_map): self.assertDatasetProduces( dataset, expected_output=[i * 2 + i**2 for i in range(10)]) + @combinations.generate(_test_combinations()) + def testMapDataclass(self, apply_map): + dataset = dataset_ops.Dataset.range(10) + dataset = apply_map(dataset, lambda x: MyDataclass(value1=x, value2=2 * x)) + dataset = apply_map(dataset, lambda x: x.value1 + x.value2) + self.assertDatasetProduces( + dataset, + expected_output=[3 * x for x in range(10)], + ) + + @combinations.generate(_test_combinations()) + def testMapMaskedTensor(self, apply_map): + dataset = dataset_ops.Dataset.range(10) + dataset = apply_map(dataset, lambda x: MaskedTensor(mask=True, value=x)) + dataset = apply_map(dataset, lambda x: 3 * x.value) + self.assertDatasetProduces( + dataset, + expected_output=[3 * x for x in range(10)], + ) + + @combinations.generate(_test_combinations()) + def testMapDataclassWithInputAndOutput(self, apply_map): + dataset = dataset_ops.Dataset.from_tensors(MyDataclass(value1=1, value2=2)) + dataset = apply_map(dataset, lambda x: (x.value1 * 5, x.value2)) + dataset = apply_map( + dataset, lambda x, y: MaskedTensor(mask=True, value=x + y) + ) + dataset = apply_map( + dataset, lambda m: NestedMaskedTensor(mask=False, value=m) + ) + self.assertDatasetProduces( + dataset, + expected_output=[ + NestedMaskedTensor( + mask=False, value=MaskedTensor(mask=True, value=7) + ) + ], + ) + + @combinations.generate(_test_combinations()) + def testMapListOfDataclassObjects(self, apply_map): + dataset = dataset_ops.Dataset.range(10) + + # Creates a list of dataclass objects. + dataset = apply_map( + dataset, + lambda x: [ # pylint: disable=g-long-lambda + MyDataclass(value1=x, value2=1), + MyDataclass(value1=2, value2=2 * x), + ], + ) + + # Takes a list of dataclass objects as input. + dataset = apply_map(dataset, lambda *x: x[0].value1 + x[1].value2) + + self.assertDatasetProduces( + dataset, + expected_output=[3 * x for x in range(10)], + ) + + @combinations.generate(_test_combinations()) + def testMapDictOfDataclassValues(self, apply_map): + dataset = dataset_ops.Dataset.range(10) + + # Creates a dict of {str -> dataclass}. + dataset = apply_map( + dataset, + lambda x: { # pylint: disable=g-long-lambda + "a": MyDataclass(value1=x, value2=1), + "b": MyDataclass(value1=2, value2=2 * x), + }, + ) + # Takes a dict of dataclass values as input. + dataset = apply_map(dataset, lambda x: x["a"].value1 + x["b"].value2) + self.assertDatasetProduces( + dataset, + expected_output=[3 * x for x in range(10)], + ) + + @combinations.generate(_test_combinations()) + def testMapNestedMaskedTensorWithDataclassInput(self, apply_map): + dataset = dataset_ops.Dataset.range(10) + dataset = apply_map(dataset, lambda x: MaskedTensor(mask=True, value=x)) + dataset = apply_map( + dataset, + # Takes a MaskedTensor as input. + lambda x: NestedMaskedTensor(mask=False, value=x), + ) + dataset = apply_map(dataset, lambda x: 5 * x.value.value) + self.assertDatasetProduces( + dataset, + expected_output=[5 * x for x in range(10)], + ) + + @combinations.generate(_test_combinations()) + def testMapNestedMaskedTensorWithDataclassOutput(self, apply_map): + dataset = dataset_ops.Dataset.range(10) + dataset = apply_map( + dataset, + lambda x: NestedMaskedTensor( # pylint: disable=g-long-lambda + mask=False, value=MaskedTensor(mask=True, value=x) + ), + ) + + # Return a MaskedTensor as the return value. + dataset = apply_map(dataset, lambda x: x.value) + dataset = apply_map(dataset, lambda x: 7 * x.value) + self.assertDatasetProduces( + dataset, + expected_output=[7 * x for x in range(10)], + ) + @combinations.generate(_test_combinations()) def testMapNamedtuple(self, apply_map): # construct dataset of tuples diff --git a/tensorflow/python/data/kernel_tests/zip_test.py b/tensorflow/python/data/kernel_tests/zip_test.py index c4f270f59648b2..c81c70a2cd8bd6 100644 --- a/tensorflow/python/data/kernel_tests/zip_test.py +++ b/tensorflow/python/data/kernel_tests/zip_test.py @@ -14,8 +14,9 @@ # ============================================================================== """Tests for `tf.data.Dataset.zip()`.""" import collections -from absl.testing import parameterized +import dataclasses +from absl.testing import parameterized import numpy as np from tensorflow.python.data.experimental.ops import random_access @@ -42,6 +43,23 @@ def _dataset_factory(components): return dataset_ops.Dataset.zip(datasets) +@dataclasses.dataclass +class MaskedNdarrayPair: + mask: bool + value1: np.ndarray + value2: np.ndarray + + def __tf_flatten__(self): + metadata = (self.mask,) + components = (self.value1, self.value2) + return metadata, components + + def __tf_unflatten__(self, metadata, components): + mask = metadata[0] + value1, value2 = components + return MaskedNdarrayPair(mask=mask, value1=value1, value2=value2) + + class ZipTest(test_base.DatasetTestBase, parameterized.TestCase): @combinations.generate(test_base.default_test_combinations()) @@ -112,6 +130,21 @@ def testNamedTuple(self): expected = [Foo(x=0, y=3), Foo(x=1, y=4), Foo(x=2, y=5)] self.assertDatasetProduces(dataset, expected) + @combinations.generate(test_base.default_test_combinations()) + def testDataclass(self): + mtp = MaskedNdarrayPair( + mask=True, + value1=dataset_ops.Dataset.range(3), + value2=dataset_ops.Dataset.range(3, 6), + ) + dataset = dataset_ops.Dataset.zip(mtp) + expected = [ + MaskedNdarrayPair(mask=True, value1=0, value2=3), + MaskedNdarrayPair(mask=True, value1=1, value2=4), + MaskedNdarrayPair(mask=True, value1=2, value2=5), + ] + self.assertDatasetProduces(dataset, expected) + @combinations.generate(test_base.default_test_combinations()) def testAttrs(self): if attr is None: diff --git a/tensorflow/python/data/util/BUILD b/tensorflow/python/data/util/BUILD index 965106bad275b1..abb9dbcded44f9 100644 --- a/tensorflow/python/data/util/BUILD +++ b/tensorflow/python/data/util/BUILD @@ -26,6 +26,7 @@ py_strict_test( "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/framework:combinations", "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:math_ops", @@ -91,6 +92,7 @@ py_strict_library( "//tensorflow/python/types:internal", "//tensorflow/python/util:compat", "//tensorflow/python/util:deprecation", + "//tensorflow/python/util:nest_util", "//tensorflow/python/util:tf_export", "@wrapt", ], @@ -110,6 +112,7 @@ py_strict_test( "//tensorflow/python/framework:combinations", "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", diff --git a/tensorflow/python/data/util/nest_test.py b/tensorflow/python/data/util/nest_test.py index edc19d3496c8b7..cdd0f01a938cd8 100644 --- a/tensorflow/python/data/util/nest_test.py +++ b/tensorflow/python/data/util/nest_test.py @@ -15,13 +15,16 @@ """Tests for utilities working with arbitrarily nested structures.""" import collections -import numpy as np +import dataclasses + from absl.testing import parameterized +import numpy as np from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.util import nest from tensorflow.python.framework import combinations from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -29,6 +32,22 @@ from tensorflow.python.platform import test +@dataclasses.dataclass +class MaskedTensor: + mask: bool + value: ops.Tensor + + def __tf_flatten__(self): + metadata = (self.mask,) + components = (self.value,) + return metadata, components + + def __tf_unflatten__(self, metadata, components): + mask = metadata[0] + value = components[0] + return MaskedTensor(mask=mask, value=value) + + class NestTest(test_base.DatasetTestBase, parameterized.TestCase): @combinations.generate(test_base.default_test_combinations()) @@ -66,6 +85,89 @@ def testFlattenAndPack(self): with self.assertRaises(ValueError): nest.pack_sequence_as([5, 6, [7, 8]], ["a", "b", "c"]) + @combinations.generate(test_base.default_test_combinations()) + def testDataclassIsNested(self): + mt = MaskedTensor(mask=True, value=constant_op.constant([1])) + self.assertTrue(nest.is_nested(mt)) + + @combinations.generate(test_base.default_test_combinations()) + def testFlattenDataclass(self): + mt = MaskedTensor(mask=True, value=constant_op.constant([1])) + leaves = nest.flatten(mt) + self.assertLen(leaves, 1) + self.assertAllEqual(leaves[0], [1]) + + @combinations.generate(test_base.default_test_combinations()) + def testPackDataclass(self): + mt = MaskedTensor(mask=True, value=constant_op.constant([1])) + leaves = nest.flatten(mt) + reconstructed_mt = nest.pack_sequence_as(mt, leaves) + self.assertIsInstance(reconstructed_mt, MaskedTensor) + self.assertEqual(reconstructed_mt.mask, mt.mask) + self.assertAllEqual(reconstructed_mt.value, mt.value) + + mt2 = MaskedTensor(mask=False, value=constant_op.constant([2])) + reconstructed_mt = nest.pack_sequence_as(mt2, leaves) + self.assertIsInstance(reconstructed_mt, MaskedTensor) + self.assertFalse(reconstructed_mt.mask) + self.assertAllEqual(reconstructed_mt.value, [1]) + + @combinations.generate(test_base.default_test_combinations()) + def testDataclassMapStructure(self): + mt = MaskedTensor(mask=True, value=constant_op.constant([1])) + mt_doubled = nest.map_structure(lambda x: x * 2, mt) + self.assertIsInstance(mt_doubled, MaskedTensor) + self.assertEqual(mt_doubled.mask, True) + self.assertAllEqual(mt_doubled.value, [2]) + + @combinations.generate(test_base.default_test_combinations()) + def testDataclassAssertSameStructure(self): + mt1 = MaskedTensor(mask=True, value=constant_op.constant([1])) + mt2 = MaskedTensor(mask=False, value=constant_op.constant([2])) + nest.assert_same_structure(mt1, mt2) + + mt3 = (1, 2) + + with self.assertRaisesRegex( # pylint: disable=g-error-prone-assert-raises + TypeError, + "don't have the same nested structure", + ): + nest.assert_same_structure(mt1, mt3) + + class SubMaskedTensor(MaskedTensor): + pass + + mt_subclass = SubMaskedTensor(mask=True, value=constant_op.constant([1])) + nest.assert_same_structure(mt1, mt_subclass, check_types=False) + with self.assertRaisesRegex( # pylint: disable=g-error-prone-assert-raises + TypeError, + "don't have the same sequence type", + ): + nest.assert_same_structure(mt1, mt_subclass) + + @combinations.generate(test_base.default_test_combinations()) + def testDataclassAssertShallowStructure(self): + mt = MaskedTensor(mask=True, value=constant_op.constant([1])) + structure1 = ("a", "b") + structure2 = (mt, "c") + nest.assert_shallow_structure(structure1, structure2) + + structure3 = (mt, "d", "e") + + with self.assertRaisesRegex( # pylint: disable=g-error-prone-assert-raises + ValueError, + "don't have the same sequence length", + ): + nest.assert_shallow_structure(structure1, structure3) + + structure4 = {"a": mt, "b": "c"} + nest.assert_shallow_structure(structure1, structure4, check_types=False) + with self.assertRaisesRegex( # pylint: disable=g-error-prone-assert-raises + TypeError, + "don't have the same sequence type", + ): + nest.assert_shallow_structure(structure1, structure4) + @combinations.generate(test_base.default_test_combinations()) def testFlattenDictOrder(self): """`flatten` orders dicts by key, including OrderedDicts.""" diff --git a/tensorflow/python/data/util/structure.py b/tensorflow/python/data/util/structure.py index 56d89095da40bc..43dfb7456c05e7 100644 --- a/tensorflow/python/data/util/structure.py +++ b/tensorflow/python/data/util/structure.py @@ -34,6 +34,7 @@ from tensorflow.python.types import internal from tensorflow.python.util import deprecation from tensorflow.python.util.compat import collections_abc +from tensorflow.python.util.nest_util import CustomNestProtocol from tensorflow.python.util.tf_export import tf_export @@ -493,6 +494,12 @@ def type_spec_from_value(element, use_fallback=True): type_spec_from_value(getattr(element, a.name)) for a in attrs ]) + if isinstance(element, CustomNestProtocol): + # pylint: disable=protected-access + metadata, children = element.__tf_flatten__() + return element.__tf_unflatten__(metadata, type_spec_from_value(children)) + # pylint: enable=protected-access + if use_fallback: # As a fallback try converting the element to a tensor. try: diff --git a/tensorflow/python/data/util/structure_test.py b/tensorflow/python/data/util/structure_test.py index 6832a3b3d7c51b..43fb07ca22fa36 100644 --- a/tensorflow/python/data/util/structure_test.py +++ b/tensorflow/python/data/util/structure_test.py @@ -15,11 +15,12 @@ """Tests for utilities working with arbitrarily nested structures.""" import collections +import dataclasses import functools +from absl.testing import parameterized import numpy as np import wrapt -from absl.testing import parameterized from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops @@ -28,6 +29,7 @@ from tensorflow.python.framework import combinations from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape @@ -492,6 +494,22 @@ def reduce_fn(x, y): return functools.reduce(reduce_fn, cases, []) +@dataclasses.dataclass +class MaskedTensor: + mask: bool + value: ops.Tensor + + def __tf_flatten__(self): + metadata = (self.mask,) + components = (self.value,) + return metadata, components + + def __tf_unflatten__(self, metadata, components): + mask = metadata[0] + value = components[0] + return MaskedTensor(mask=mask, value=value) + + # TODO(jsimsa): Add tests for OptionalStructure and DatasetStructure. class StructureTest(test_base.DatasetTestBase, parameterized.TestCase): @@ -954,6 +972,23 @@ def testTypeSpecNotCompatible(self): self.assertEqual(test_obj, test_obj.most_specific_compatible_shape(test_obj)) + @combinations.generate(test_base.default_test_combinations()) + def testDataclasses(self): + mt = MaskedTensor(mask=True, value=constant_op.constant([1])) + + mt_type_spec = structure.type_spec_from_value(mt) + self.assertEqual(mt_type_spec.mask, mt.mask) + self.assertEqual( + mt_type_spec.value, structure.type_spec_from_value(mt.value) + ) + + mt2 = MaskedTensor(mask=True, value=constant_op.constant([2])) + mt3 = MaskedTensor(mask=False, value=constant_op.constant([1])) + mt2_type_spec = structure.type_spec_from_value(mt2) + mt3_type_spec = structure.type_spec_from_value(mt3) + self.assertEqual(mt_type_spec, mt2_type_spec) + self.assertNotEqual(mt_type_spec, mt3_type_spec) + class CustomMap(collections_abc.Mapping): """Custom, immutable map.""" From a03af28ec341fe8407a7bfd177aefbe7ea31ba05 Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Thu, 13 Jul 2023 15:53:17 -0700 Subject: [PATCH 286/376] [XLA:GPU] Implement the dimension analysis of broadcasts in Triton GEMM rewriter. This will not modify the fusions yet because fusing broadcasts is temporarily blocked. PiperOrigin-RevId: 547951667 --- .../xla/service/gpu/gemm_rewriter_triton.cc | 35 ++++---- .../service/gpu/gemm_rewriter_triton_test.cc | 82 +++++++++++++++++++ 2 files changed, 103 insertions(+), 14 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.cc b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.cc index 9e0bfe08f4d3de..d133928c93879f 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.cc @@ -228,11 +228,17 @@ class DimensionOrder { FusionDecision HandleInstruction(const HloInstruction* hlo, TransformDirection direction) { VLOG(7) << hlo->ToString(); - if (hlo->opcode() == HloOpcode::kParameter) { + if (hlo->opcode() == HloOpcode::kParameter || + hlo_query::IsScalarConstant(hlo)) { return FusionDecision{}; } else if (hlo->opcode() == HloOpcode::kTranspose || hlo->opcode() == HloOpcode::kCopy) { - return HandleCopyOrTranspose(hlo, direction); + return HandleCopyOrTransposeOrBroadcast(hlo, direction); + } else if (hlo->opcode() == HloOpcode::kBroadcast) { + if (direction != TransformDirection::kOutputToInput) { + return "Unsupported broadcast direction."; + } + return HandleCopyOrTransposeOrBroadcast(hlo, direction); } else if (hlo->operand_count() > 0 && IsTritonSupportedElementwise( hlo->opcode(), hlo->operand(0)->shape().element_type())) { @@ -245,11 +251,6 @@ class DimensionOrder { return "Non-bitcast reshape."; } return HandleBitcast(hlo, direction); - } else if (hlo_query::IsScalarConstant(hlo) || - hlo_query::IsBroadcastOfScalarConstant(*hlo)) { - // Dimension order collapses on a scalar, for simplicity leave it equal - // to the output one for now. - return FusionDecision{}; } return "Unimplemented instruction."; } @@ -283,8 +284,8 @@ class DimensionOrder { private: // See HandleInstruction() for the general description of Handle*(). FusionDecision HandleBitcast(const HloInstruction*, TransformDirection); - FusionDecision HandleCopyOrTranspose(const HloInstruction*, - TransformDirection); + FusionDecision HandleCopyOrTransposeOrBroadcast(const HloInstruction*, + TransformDirection); DimOrderVector dim_order_; const int64_t splittable_dimension_index_; @@ -463,8 +464,8 @@ FusionDecision DimensionOrder::HandleBitcast(const HloInstruction* hlo, return FusionDecision{}; } -FusionDecision DimensionOrder::HandleCopyOrTranspose( - const HloInstruction* hlo, TransformDirection direction) { +FusionDecision DimensionOrder::HandleCopyOrTransposeOrBroadcast( + const HloInstruction* hlo, const TransformDirection direction) { // Every HLO dimension can correspond to a group of subdimensions in // dim_order_. For the easier handling of permutations: group dim_order_ by // dimension, apply permutations, then finally remove the grouping. @@ -498,16 +499,22 @@ FusionDecision DimensionOrder::HandleCopyOrTranspose( // Source logical -> destination logical. std::vector dst_logical; if (hlo->opcode() == HloOpcode::kTranspose) { - auto transpose = ::xla::Cast(hlo); + const auto transpose = Cast(hlo); std::vector permutation(transpose->dimensions().cbegin(), transpose->dimensions().cend()); if (direction == TransformDirection::kInputToOutput) { permutation = InversePermutation(permutation); } - dst_logical.resize(src_logical.size()); - for (int i = 0; i < src_logical.size(); ++i) { + dst_logical.resize(permutation.size()); + for (int i = 0; i < permutation.size(); ++i) { dst_logical[permutation[i]] = src_logical[i]; } + } else if (hlo->opcode() == HloOpcode::kBroadcast) { + const auto broadcast = Cast(hlo); + dst_logical.resize(broadcast->dimensions().size()); + for (int i = 0; i < broadcast->dimensions().size(); ++i) { + dst_logical[i] = src_logical[broadcast->dimensions()[i]]; + } } else { // Copy preserves the logical shape, just permutes the layout. CHECK(ShapeUtil::SameDimensions(src->shape(), dst->shape())); diff --git a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton_test.cc b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton_test.cc index 53f5e8c163e3d5..c86e78b14389ec 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton_test.cc @@ -439,6 +439,88 @@ ENTRY e { /*subfragments=*/ElementsAre(3)))); } +TEST_F(TritonDotAnalysisTest, InputBroadcastFromScalarIsHandled) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +HloModule t + +triton_dot { + p0 = bf16[24,4]{1,0} parameter(0) + p1 = bf16[] parameter(1) + p1b = bf16[4,3] broadcast(p1) + ROOT dot = bf16[24,3]{1,0} dot(p0, p1b), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY e { + p0 = bf16[24,4]{1,0} parameter(0) + p1 = bf16[] parameter(1) + ROOT r = bf16[24,3]{1,0} fusion(p0, p1), kind=kCustom, + calls=triton_dot +})")); + const HloComputation* dot_computation = + module->entry_computation()->root_instruction()->called_computations()[0]; + const HloInstruction* scalar = dot_computation->parameter_instruction(1); + const DotFusionAnalysis analysis(dot_computation); + EXPECT_EQ(analysis.IterSpec(DotFusionAnalysis::Scope::RHS, scalar, 0)->size(), + 1); + EXPECT_THAT(*analysis.IterSpec(DotFusionAnalysis::Scope::RHS, scalar, 0), + ElementsAre(FieldsAre(/*stride=*/0, /*count=*/1, + /*subfragments=*/ElementsAre(1)))); +} + +TEST_F(TritonDotAnalysisTest, InputBroadcastFromVectorIsHandled) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +HloModule t + +triton_dot { + p0 = bf16[24,4]{1,0} parameter(0) + p1 = bf16[4] parameter(1) + p1b = bf16[4,3] broadcast(p1), dimensions={0} + ROOT dot = bf16[24,3]{1,0} dot(p0, p1b), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY e { + p0 = bf16[24,4]{1,0} parameter(0) + p1 = bf16[4] parameter(1) + ROOT r = bf16[24,3]{1,0} fusion(p0, p1), kind=kCustom, + calls=triton_dot +})")); + const HloComputation* dot_computation = + module->entry_computation()->root_instruction()->called_computations()[0]; + const HloInstruction* vector = dot_computation->parameter_instruction(1); + const DotFusionAnalysis analysis(dot_computation); + EXPECT_EQ(analysis.IterSpec(DotFusionAnalysis::Scope::RHS, vector, 0)->size(), + 1); + EXPECT_THAT(*analysis.IterSpec(DotFusionAnalysis::Scope::RHS, vector, 0), + ElementsAre(FieldsAre(/*stride=*/1, /*count=*/4, + /*subfragments=*/ElementsAre(4)))); +} + +TEST_F(TritonDotAnalysisTest, OutputBroadcastIsNotAccepted) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +HloModule t + +ENTRY e { + p0 = f16[1,35] parameter(0) + p0c = bf16[1,35] convert(p0) + p1 = bf16[35,1] parameter(1) + dot = bf16[1,1] dot(p0c, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + b = bf16[] bitcast(dot) + ROOT bc = bf16[100] broadcast(b) +})")); + EXPECT_TRUE(GemmRewriterTriton(se::CudaComputeCapability{ + se::CudaComputeCapability::AMPERE, 0}) + .Run(module.get()) + .value()); + EXPECT_EQ(module->entry_computation()->root_instruction()->opcode(), + HloOpcode::kBroadcast); +} + using SplitKTest = HloTestBase; class SplitKTestWithMorePreciseReduction From b8fecf01c57d9a463fd8ccaa43c71d87008b21e3 Mon Sep 17 00:00:00 2001 From: Rahul Joshi Date: Thu, 13 Jul 2023 15:55:35 -0700 Subject: [PATCH 287/376] [NFC] Unify the 2 variant of MatchReduceScatter by using default values. PiperOrigin-RevId: 547952320 --- .../compiler/xla/service/reduce_scatter_utils.cc | 12 ------------ .../compiler/xla/service/reduce_scatter_utils.h | 14 +++++--------- 2 files changed, 5 insertions(+), 21 deletions(-) diff --git a/tensorflow/compiler/xla/service/reduce_scatter_utils.cc b/tensorflow/compiler/xla/service/reduce_scatter_utils.cc index bc179cbfa83550..c9a55b6d1af4ec 100644 --- a/tensorflow/compiler/xla/service/reduce_scatter_utils.cc +++ b/tensorflow/compiler/xla/service/reduce_scatter_utils.cc @@ -262,18 +262,6 @@ bool IsPerIdOffset(const HloInstruction* offset, int64_t shard_size, } // namespace -std::optional MatchReduceScatter( - const HloAllReduceInstruction* ar, int64_t num_partitions, - int64_t num_replicas, bool allow_multiple_split_dims, - bool allow_intervening_reshape, int64_t min_rank) { - HloPredicate match_partition_id = HloPredicateIsOp; - HloPredicate match_replica_id = HloPredicateIsOp; - return MatchReduceScatter(ar, num_partitions, num_replicas, - allow_multiple_split_dims, - allow_intervening_reshape, min_rank, - match_partition_id, match_replica_id); -} - std::optional MatchReduceScatter( const HloAllReduceInstruction* ar, int64_t num_partitions, int64_t num_replicas, bool allow_multiple_split_dims, diff --git a/tensorflow/compiler/xla/service/reduce_scatter_utils.h b/tensorflow/compiler/xla/service/reduce_scatter_utils.h index 5ed64fc864b603..bbebe3e2ac132d 100644 --- a/tensorflow/compiler/xla/service/reduce_scatter_utils.h +++ b/tensorflow/compiler/xla/service/reduce_scatter_utils.h @@ -16,7 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_REDUCE_SCATTER_UTILS_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_REDUCE_SCATTER_UTILS_H_ -#include +#include +#include #include "tensorflow/compiler/xla/hlo/ir/hlo_instructions.h" @@ -35,14 +36,9 @@ struct ReduceScatterSpec { std::optional MatchReduceScatter( const HloAllReduceInstruction* ar, int64_t num_partitions, int64_t num_replicas, bool allow_multiple_split_dims = false, - bool allow_intervening_reshape = false, int64_t min_rank = 1); - -// Matches the given all-reduce operation to a reduce-scatter pattern. -std::optional MatchReduceScatter( - const HloAllReduceInstruction* ar, int64_t num_partitions, - int64_t num_replicas, bool allow_multiple_split_dims, - bool allow_intervening_reshape, int64_t min_rank, - HloPredicate match_partition_id, HloPredicate match_replica_id); + bool allow_intervening_reshape = false, int64_t min_rank = 1, + HloPredicate match_partition_id = HloPredicateIsOp, + HloPredicate match_replica_id = HloPredicateIsOp); } // namespace xla From fc6e00d01a68651bf53dcd99cb2ae95369c5c44e Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Thu, 13 Jul 2023 15:58:12 -0700 Subject: [PATCH 288/376] [xla:gpu] Pass valid pointer for CUDA graph instantiation PiperOrigin-RevId: 547952935 --- .../xla/service/gpu/runtime/executable.cc | 18 ++++++++++++++++-- .../xla/service/gpu/runtime/graph_launch.cc | 13 ++++++++++++- .../xla/service/gpu/runtime/graph_launch.h | 4 ++++ .../xla/stream_executor/cuda/cuda_driver.cc | 6 +++++- 4 files changed, 37 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/runtime/executable.cc b/tensorflow/compiler/xla/service/gpu/runtime/executable.cc index 1c11b72cd5e06e..49fe85dafc9e5f 100644 --- a/tensorflow/compiler/xla/service/gpu/runtime/executable.cc +++ b/tensorflow/compiler/xla/service/gpu/runtime/executable.cc @@ -416,9 +416,23 @@ Status GpuRuntimeExecutable::Execute( #if GOOGLE_CUDA // Instantiate all CUDA graphs before executing the main function. - if (debug_options_.xla_gpu_cuda_graph_num_runs_to_instantiate() < 0) { + if (debug_options_.xla_gpu_cuda_graph_num_runs_to_instantiate() < 0 && + !graph_instances_.InstantiatedAllGraphs(run_options, executable)) { + // To instantiate all Gpu graphs we have to pass a valid device pointer + // because some device operations in XLA (e.g. memcpy) query device + // information from a pointer. We have to find the largest allocation + // available, to guarantee that all memref slices are within bounds, + // otherwise we might get crashes from a Gpu driver. + void* device_ptr = temp_buffer.opaque(); + size_t device_ptr_size = temp_buffer.size(); + + for (unsigned i = 0; i < buffer_allocations.size(); ++i) { + auto mem = buffer_allocations.GetDeviceAddress(i); + if (mem.size() > device_ptr_size) device_ptr = mem.opaque(); + } + if (auto instantiated = graph_instances_.InstantiateAllGraphs( - run_options, executable, user_data, temp_buffer.opaque()); + run_options, executable, user_data, device_ptr); !instantiated.ok()) { return InternalError("Failed to instantiate CUDA graphs: %s", instantiated.message()); diff --git a/tensorflow/compiler/xla/service/gpu/runtime/graph_launch.cc b/tensorflow/compiler/xla/service/gpu/runtime/graph_launch.cc index 359011ab83a32a..4325fe67881d94 100644 --- a/tensorflow/compiler/xla/service/gpu/runtime/graph_launch.cc +++ b/tensorflow/compiler/xla/service/gpu/runtime/graph_launch.cc @@ -85,12 +85,23 @@ CapturedFunctionExecutionCount* CapturedFunctionExecutionCounts::operator()( return &counts_[executor]; } +bool GraphInstances::InstantiatedAllGraphs( + const ServiceExecutableRunOptions* run_options, + const Executable& executable) { + if (executable.num_functions() == 1) return true; + + absl::MutexLock lock(&mutex_); + return instantiated_.contains(run_options->stream()->parent()); +} + Status GraphInstances::InstantiateAllGraphs( const ServiceExecutableRunOptions* run_options, const Executable& executable, const CustomCall::UserData& user_data, void* ptr) { - absl::MutexLock lock(&mutex_); + // We have only "main" function in the executable. + if (executable.num_functions() == 1) return OkStatus(); + absl::MutexLock lock(&mutex_); se::StreamExecutor* executor = run_options->stream()->parent(); // All Gpu graphs are already instantiated for a given executor. diff --git a/tensorflow/compiler/xla/service/gpu/runtime/graph_launch.h b/tensorflow/compiler/xla/service/gpu/runtime/graph_launch.h index 96dfd6a1a3d3ef..5c3bc4b4867450 100644 --- a/tensorflow/compiler/xla/service/gpu/runtime/graph_launch.h +++ b/tensorflow/compiler/xla/service/gpu/runtime/graph_launch.h @@ -93,6 +93,10 @@ class GraphInstances { const runtime::CustomCall::UserData& user_data, void* ptr); + // Returns true if all Gpu graphs were already instantiated. + bool InstantiatedAllGraphs(const ServiceExecutableRunOptions* run_options, + const runtime::Executable& executable); + private: mutable absl::Mutex mutex_; absl::node_hash_map graphs_ diff --git a/tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc b/tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc index 2540b4003cb353..6607a9adbee136 100644 --- a/tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc +++ b/tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc @@ -1234,7 +1234,11 @@ GpuDriver::CreateMemoryHandle(GpuContext* context, uint64_t bytes) { return false; } - if (gpu_dst == 0 || gpu_src == 0) { + // In graph capture mode we never have operations that access peer memory, so + // we can always make a call to cuMemcpyDtoDAsync. + bool is_capturing = stream_capture_status == cudaStreamCaptureStatusActive; + + if ((gpu_dst == 0 || gpu_src == 0) || is_capturing) { // CreatedContexts::GetAnyContext() doesn't works when ptr == 0. // This happens when the size is 0. result = cuMemcpyDtoDAsync(gpu_dst, gpu_src, size, stream); From 306fb282eb9868c92d33f877f81761860b8a410f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 13 Jul 2023 16:14:19 -0700 Subject: [PATCH 289/376] Update TFRT dependency to use revision http://github.com/tensorflow/runtime/commit/08a6b6ecfc7cce7d0c8388fe7a9c73352467091e. PiperOrigin-RevId: 547957091 --- third_party/tf_runtime/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/tf_runtime/workspace.bzl b/third_party/tf_runtime/workspace.bzl index 8d575144823d6c..f456dbaf7cae33 100644 --- a/third_party/tf_runtime/workspace.bzl +++ b/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "b325dedfa3a47df75e06da3640424c1bdb28dd3a" - TFRT_SHA256 = "bc4341e8c6d0deed35b662903a82008b21b88127aec053b1a250b92219f4f0c9" + TFRT_COMMIT = "08a6b6ecfc7cce7d0c8388fe7a9c73352467091e" + TFRT_SHA256 = "bb6f479caeba3b28f033a9a420b23cb00f9d235ac8df312b1a57fda1ef2f8039" tf_http_archive( name = "tf_runtime", From 4e7a7c4ff181ed69c2f325ef26af1ee82d3f1ee3 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 13 Jul 2023 16:21:13 -0700 Subject: [PATCH 290/376] Pattern to fuse/fold TFL_TransposeOp into TFL_BatchMatMulOp PiperOrigin-RevId: 547958746 --- .../compiler/mlir/lite/tests/optimize.mlir | 22 ++++++++++++++++ .../compiler/mlir/lite/transforms/optimize.cc | 11 ++++++++ .../mlir/lite/transforms/optimize_patterns.td | 26 +++++++++++++++++++ 3 files changed, 59 insertions(+) diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir index 2f1561855e9d4e..f780e10f89b97c 100644 --- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir @@ -633,6 +633,8 @@ func.func @FuseReshapeAroundBMMNagativeTest(%arg0: tensor<5x4x1x1024xf32>, %arg1 } // CHECK-LABEL: @FuseReshapeAroundBMMNagativeTest2 +// Checks that the pattern matcher FuseReshapesAroundBatchMatMulLHS does not get +// applied for this case that does not pass the constraint around input rank. func.func @FuseReshapeAroundBMMNagativeTest2(%arg0: tensor<2x1536xf32>) -> tensor<2x768xf32> { %cst = arith.constant dense_resource<__elided__> : tensor<3xi32> %cst_0 = arith.constant dense_resource<__elided__> : tensor<2xi32> @@ -664,6 +666,26 @@ func.func @FuseReshapeAroundBMMRHS(%arg0: tensor<1x3x6x5x1024xf32>) -> tensor<1x // CHECK: return %0 : tensor<1x3x6x5x8192xf32> } +// CHECK-LABEL: @FuseTransposeIntoBMM_RHS +func.func @FuseTransposeIntoBMM_RHS(%arg0: tensor<1x4x1440x256xf32>, %arg1: tensor<1x1440x256xf32>) -> tensor<1x4x1440x1440xf32> { + %cst_1 = arith.constant dense_resource<__elided__> : tensor<3xi32> + %32 = "tfl.transpose"(%arg1, %cst_1) : (tensor<1x1440x256xf32>, tensor<3xi32>) -> tensor<1x256x1440xf32> + %33 = "tfl.batch_matmul"(%arg0, %32) {adj_x = false, adj_y = false} : (tensor<1x4x1440x256xf32>, tensor<1x256x1440xf32>) -> tensor<1x4x1440x1440xf32> + return %33 : tensor<1x4x1440x1440xf32> + // CHECK: %0 = "tfl.batch_matmul"(%arg0, %arg1) {adj_x = false, adj_y = true} : (tensor<1x4x1440x256xf32>, tensor<1x1440x256xf32>) -> tensor<1x4x1440x1440xf32> + // CHECK: return %0 : tensor<1x4x1440x1440xf32> +} + +// CHECK-LABEL: @FuseTransposeIntoBMM_LHS +func.func @FuseTransposeIntoBMM_LHS(%arg0: tensor<1x4x1440x256xf32>, %arg1: tensor<1x1440x256xf32>) -> tensor<1x4x256x256xf32> { + %cst_1 = arith.constant dense_resource<__elided__> : tensor<3xi32> + %32 = "tfl.transpose"(%arg1, %cst_1) : (tensor<1x1440x256xf32>, tensor<3xi32>) -> tensor<1x256x1440xf32> + %33 = "tfl.batch_matmul"(%32, %arg0) {adj_x = false, adj_y = false} : (tensor<1x256x1440xf32>, tensor<1x4x1440x256xf32>) -> tensor<1x4x256x256xf32> + return %33 : tensor<1x4x256x256xf32> + // CHECK: %0 = "tfl.batch_matmul"(%arg1, %arg0) {adj_x = true, adj_y = false} : (tensor<1x1440x256xf32>, tensor<1x4x1440x256xf32>) -> tensor<1x4x256x256xf32> + // CHECK: return %0 : tensor<1x4x256x256xf32> +} + // CHECK-LABEL: @FuseFullyConnectedReshapeAddConst // FOLD-LABEL: @FuseFullyConnectedReshapeAddConst func.func @FuseFullyConnectedReshapeAddConst(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> { diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/transforms/optimize.cc index 03a162f98af533..e0caac4e90490d 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize.cc @@ -139,6 +139,17 @@ bool BroadcastDimsProductEqual(Value input, Value output, return (agg_value == output_shape[agg_start_idx]); } +// Return true if the product of dimension values of a subsection of the tensor +// is equal to the non-contracting dimension after a reshape +bool AreLastTwoDimsTransposed(Value input, Value output) { + ArrayRef input_shape = input.getType().cast().getShape(); + ArrayRef output_shape = + output.getType().cast().getShape(); + + return (input_shape.back() == output_shape[output_shape.size() - 2]) && + (input_shape[input_shape.size() - 2] == output_shape.back()); +} + // Returns whether the given type `a` is broadcast-compatible with `b`. bool IsBroadcastableElementsAttrAndType(Type a, Type b) { return OpTrait::util::getBroadcastedType(a, b) != Type(); diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td index 1a826cd75f3020..b056dc1c977347 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td @@ -57,6 +57,8 @@ class HasRank : Constraint< class FloatValueEquals : Constraint>; +class IsBoolAttrEqual : Constraint>; // Flattens a constant tensor to 1D. def FlattenTo1D : NativeCodeCall<"FlattenTo1D($0)">; @@ -1485,3 +1487,27 @@ def FuseReshapesAroundBatchMatMulLHS1: Pat< (BroadcastDimsProductEqual<1> $input, $initial_shape_change), (BroadcastDimsProductEqual<1> $final_shape_change, $bmm_tmp_output), (AreTensorSubSectionShapesEqual<1, 1> $input, $final_shape_change)]>; + +def AreLastTwoDimsTransposed : Constraint>; + +// Fuse redundant TFL_TransposeOp into TFL_BatchMatMulOp +def FuseTransposeIntoBatchMatMulRHS: Pat< + (TFL_BatchMatMulOp $lhs, + (TFL_TransposeOp:$transposed_value $input, (Arith_ConstantOp $p0)), + $adj_x, $adj_y, $asymmetric_quantize_inputs), + (TFL_BatchMatMulOp $lhs, $input, $adj_x, ConstBoolAttrTrue, $asymmetric_quantize_inputs), + [(AreLastTwoDimsTransposed $input, $transposed_value), + (IsBoolAttrEqual<"false"> $adj_y), + (AreTensorSubSectionShapesEqual<0, 2> $input, $transposed_value)]>; + +// Fuse redundant TFL_TransposeOp into TFL_BatchMatMulOp +def FuseTransposeIntoBatchMatMulLHS: Pat< + (TFL_BatchMatMulOp + (TFL_TransposeOp:$transposed_value $input, (Arith_ConstantOp $p0)), + $rhs, $adj_x, $adj_y, $asymmetric_quantize_inputs), + (TFL_BatchMatMulOp $input, $rhs, ConstBoolAttrTrue, $adj_y, $asymmetric_quantize_inputs), + [(AreLastTwoDimsTransposed $input, $transposed_value), + (IsBoolAttrEqual<"false"> $adj_x), + (AreTensorSubSectionShapesEqual<0, 2> $input, $transposed_value)]>; + From 65d75a7955d324935bbfec274fd9309394d67445 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thomas=20K=C3=B6ppe?= Date: Thu, 13 Jul 2023 16:36:50 -0700 Subject: [PATCH 291/376] Internal change only. PiperOrigin-RevId: 547962474 --- tensorflow/tensorflow.bzl | 16 ++++++++++++++-- third_party/gpus/cuda/build_defs.bzl.tpl | 11 +++++++++-- third_party/tensorrt/build_defs.bzl.tpl | 4 ++++ 3 files changed, 27 insertions(+), 4 deletions(-) diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index c3885d0bd84d18..ca57f9081e4643 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -29,11 +29,13 @@ load( load( "@local_config_tensorrt//:build_defs.bzl", "if_tensorrt", + "if_tensorrt_exec", ) load( "@local_config_cuda//cuda:build_defs.bzl", "cuda_library", "if_cuda", + "if_cuda_exec", ) load( "@local_config_rocm//rocm:build_defs.bzl", @@ -463,6 +465,16 @@ def tf_copts( }) ) +def tf_copts_exec( + android_optimization_level_override = "-O2", + is_external = False, + allow_exceptions = False): + return tf_copts( + android_optimization_level_override, + is_external, + allow_exceptions, + ) + if_cuda_exec(["-DGOOGLE_CUDA=1"]) + if_tensorrt_exec(["-DGOOGLE_TENSORRT=1"]) + def tf_openmp_copts(): # We assume when compiling on Linux gcc/clang will be used and MSVC on Windows return select({ @@ -555,7 +567,7 @@ def tf_gen_op_libs( for n in op_lib_names: cc_library( name = n + "_op_lib", - copts = tf_copts(is_external = is_external), + copts = tf_copts_exec(is_external = is_external), features = features, srcs = [sub_directory + n + ".cc"], deps = deps + [clean_dep("//tensorflow/core:framework")], @@ -1389,7 +1401,7 @@ def tf_gen_op_wrapper_py( deps = [str(Label("//tensorflow/core:" + name + "_op_lib"))] tf_cc_binary( name = tool_name, - copts = copts + tf_copts(), + copts = copts + tf_copts_exec(), linkopts = if_not_windows(["-lm", "-Wl,-ldl"]) + cc_linkopts, linkstatic = 1, # Faster to link this one-time-use binary dynamically visibility = [clean_dep("//tensorflow:internal")], diff --git a/third_party/gpus/cuda/build_defs.bzl.tpl b/third_party/gpus/cuda/build_defs.bzl.tpl index 71acfa7cb7d629..189d3e3e784003 100644 --- a/third_party/gpus/cuda/build_defs.bzl.tpl +++ b/third_party/gpus/cuda/build_defs.bzl.tpl @@ -4,7 +4,6 @@ def if_cuda(if_true, if_false = []): Returns a select statement which evaluates to if_true if we're building with CUDA enabled. Otherwise, the select statement evaluates to if_false. - """ return select({ "@local_config_cuda//:is_cuda_enabled": if_true, @@ -16,13 +15,21 @@ def if_cuda_clang(if_true, if_false = []): Returns a select statement which evaluates to if_true if we're building with cuda-clang. Otherwise, the select statement evaluates to if_false. - """ return select({ "@local_config_cuda//cuda:using_clang": if_true, "//conditions:default": if_false }) +def if_cuda_exec(if_true, if_false = []): + """Synonym for if_cuda. + + Selects if_true both in target and in exec configurations. In principle, + if_cuda would only need to select if_true in a target configuration, but + not in an exec configuration, but this is not currently implemented. + """ + return if_cuda(if_true, if_false) + def cuda_compiler(if_cuda_clang, if_nvcc, neither = []): """Shorthand for select()'ing on wheteher we're building with cuda-clang or nvcc. diff --git a/third_party/tensorrt/build_defs.bzl.tpl b/third_party/tensorrt/build_defs.bzl.tpl index 6d00513827b380..83fcc7d69717b1 100644 --- a/third_party/tensorrt/build_defs.bzl.tpl +++ b/third_party/tensorrt/build_defs.bzl.tpl @@ -3,3 +3,7 @@ def if_tensorrt(if_true, if_false=[]): """Tests whether TensorRT was enabled during the configure process.""" return %{if_tensorrt} + +def if_tensorrt_exec(if_true, if_false=[]): + """Synonym for if_tensorrt.""" + return %{if_tensorrt} From c6dcf8e91a46d60b898dacd2f8e94b6e46a706a4 Mon Sep 17 00:00:00 2001 From: Fiona Lang Date: Thu, 13 Jul 2023 16:45:54 -0700 Subject: [PATCH 292/376] Update ops.Tensor references to //third_party/tensorflow/python/framework/tensor.py. PiperOrigin-RevId: 547964572 --- tensorflow/lite/python/util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/lite/python/util.py b/tensorflow/lite/python/util.py index 376f4d8e168526..c819e3c67df39f 100644 --- a/tensorflow/lite/python/util.py +++ b/tensorflow/lite/python/util.py @@ -173,7 +173,7 @@ def set_tensor_shapes(tensors, shapes): """Sets Tensor shape for each tensor if the shape is defined. Args: - tensors: TensorFlow ops.Tensor. + tensors: TensorFlow tensor.Tensor. shapes: Dict of strings representing input tensor names to list of integers representing input shapes (e.g., {"foo": : [1, 16, 16, 3]}). From 125e1300ce1e3f2676189b37ef220e18a622105c Mon Sep 17 00:00:00 2001 From: Fiona Lang Date: Thu, 13 Jul 2023 16:56:06 -0700 Subject: [PATCH 293/376] Update ops.Tensor references to //third_party/tensorflow/python/framework/tensor.py. PiperOrigin-RevId: 547966938 --- tensorflow/compiler/tests/BUILD | 2 ++ tensorflow/compiler/tests/pooling_ops_test.py | 5 +-- tensorflow/compiler/tests/sort_ops_test.py | 4 +-- tensorflow/python/autograph/utils/BUILD | 3 +- tensorflow/python/autograph/utils/misc.py | 3 +- .../python/autograph/utils/tensor_list.py | 4 +-- tensorflow/python/debug/lib/BUILD | 2 ++ .../python/debug/lib/debug_gradients.py | 3 +- .../python/debug/lib/debug_gradients_test.py | 33 ++++++++++--------- .../python/kernel_tests/control_flow/BUILD | 7 ++-- .../control_flow/control_flow_ops_py_test.py | 16 ++++----- tensorflow/python/kernel_tests/io_ops/BUILD | 5 ++- .../kernel_tests/io_ops/parsing_ops_test.py | 5 ++- tensorflow/python/kernel_tests/nn_ops/BUILD | 6 ++-- .../kernel_tests/nn_ops/rnn_cell_test.py | 6 ++-- .../python/kernel_tests/sparse_ops/BUILD | 5 ++- .../sparse_conditional_accumulator_test.py | 5 +-- tensorflow/python/ops/linalg/sparse/BUILD | 1 + .../linalg/sparse/sparse_csr_matrix_ops.py | 7 ++-- tensorflow/python/ops/signal/BUILD | 1 + tensorflow/python/ops/signal/mel_ops.py | 5 +-- tensorflow/python/training/experimental/BUILD | 1 + .../training/experimental/loss_scale_test.py | 5 +-- tensorflow/python/training/saving/BUILD | 1 + .../training/saving/saveable_object_util.py | 7 ++-- 25 files changed, 88 insertions(+), 54 deletions(-) diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 634baa45b4b4ba..2f0c5273f2aefa 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -1119,6 +1119,7 @@ tf_xla_py_strict_test( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:errors", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:nn_ops", @@ -1940,6 +1941,7 @@ tf_xla_py_strict_test( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:function", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:math_ops", diff --git a/tensorflow/compiler/tests/pooling_ops_test.py b/tensorflow/compiler/tests/pooling_ops_test.py index 3a7e22c02e54c5..bb760360687d9e 100644 --- a/tensorflow/compiler/tests/pooling_ops_test.py +++ b/tensorflow/compiler/tests/pooling_ops_test.py @@ -20,6 +20,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_nn_ops @@ -36,7 +37,7 @@ def NHWCToNCHW(input_tensor): Returns: the converted tensor or a shape array """ - if isinstance(input_tensor, ops.Tensor): + if isinstance(input_tensor, tensor.Tensor): return array_ops.transpose(input_tensor, [0, 3, 1, 2]) else: return [input_tensor[0], input_tensor[3], input_tensor[1], input_tensor[2]] @@ -51,7 +52,7 @@ def NCHWToNHWC(input_tensor): Returns: the converted tensor or a shape array """ - if isinstance(input_tensor, ops.Tensor): + if isinstance(input_tensor, tensor.Tensor): return array_ops.transpose(input_tensor, [0, 2, 3, 1]) else: return [input_tensor[0], input_tensor[2], input_tensor[3], input_tensor[1]] diff --git a/tensorflow/compiler/tests/sort_ops_test.py b/tensorflow/compiler/tests/sort_ops_test.py index 250dcc45fe14de..bbadb955356e0f 100644 --- a/tensorflow/compiler/tests/sort_ops_test.py +++ b/tensorflow/compiler/tests/sort_ops_test.py @@ -23,7 +23,7 @@ from tensorflow.compiler.tf2xla.python import xla from tensorflow.python.framework import dtypes from tensorflow.python.framework import function -from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -48,7 +48,7 @@ def _assertOpOutputMatchesExpected(self, op, args, expected): ] feeds = {placeholders[i]: args[i] for i in range(0, len(args))} output = op(*placeholders) - if isinstance(output, ops.Tensor): + if isinstance(output, tensor.Tensor): output = [output] results = session.run(output, feeds) diff --git a/tensorflow/python/autograph/utils/BUILD b/tensorflow/python/autograph/utils/BUILD index 254881dd92eb8a..d758c28801c315 100644 --- a/tensorflow/python/autograph/utils/BUILD +++ b/tensorflow/python/autograph/utils/BUILD @@ -19,7 +19,7 @@ py_strict_library( srcs = ["tensor_list.py"], visibility = ["//tensorflow:__subpackages__"], deps = [ - "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:list_ops", "//tensorflow/python/ops:tensor_array_ops", ], @@ -51,6 +51,7 @@ py_strict_library( visibility = ["//tensorflow:__subpackages__"], deps = [ "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:math_ops", "//tensorflow/python/ops:math_ops_gen", diff --git a/tensorflow/python/autograph/utils/misc.py b/tensorflow/python/autograph/utils/misc.py index 7404ea5ec75c0a..d14b4758aba03f 100644 --- a/tensorflow/python/autograph/utils/misc.py +++ b/tensorflow/python/autograph/utils/misc.py @@ -15,6 +15,7 @@ """Miscellaneous utilities that don't fit anywhere else.""" from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import math_ops @@ -36,7 +37,7 @@ def alias_tensors(*args): """ def alias_if_tensor(a): - return array_ops.identity(a) if isinstance(a, ops.Tensor) else a + return array_ops.identity(a) if isinstance(a, tensor.Tensor) else a # TODO(mdan): Recurse into containers? # TODO(mdan): Anything we can do about variables? Fake a scope reuse? diff --git a/tensorflow/python/autograph/utils/tensor_list.py b/tensorflow/python/autograph/utils/tensor_list.py index c91b9be2868dac..c8bdf3ae982982 100644 --- a/tensorflow/python/autograph/utils/tensor_list.py +++ b/tensorflow/python/autograph/utils/tensor_list.py @@ -14,7 +14,7 @@ # ============================================================================== """A typed list in Python.""" -from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.ops import list_ops from tensorflow.python.ops import tensor_array_ops @@ -28,7 +28,7 @@ def dynamic_list_append(target, element): # It may be possible to use TensorList alone if the loop body will not # require wrapping it, although we'd have to think about an autoboxing # mechanism for lists received as parameter. - if isinstance(target, ops.Tensor): + if isinstance(target, tensor.Tensor): return list_ops.tensor_list_push_back(target, element) # Python targets (including TensorList): fallback to their original append. diff --git a/tensorflow/python/debug/lib/BUILD b/tensorflow/python/debug/lib/BUILD index c6cdfb38552ccd..37c99b30dd2056 100644 --- a/tensorflow/python/debug/lib/BUILD +++ b/tensorflow/python/debug/lib/BUILD @@ -149,6 +149,7 @@ py_strict_library( ":debug_data", ":debug_graphs", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:array_ops_gen", "//tensorflow/python/ops:variables", ], @@ -418,6 +419,7 @@ cuda_py_strict_test( "//tensorflow/core:protos_all_py", "//tensorflow/python/client:session", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:test_lib", "//tensorflow/python/lib/io:lib", "//tensorflow/python/ops:gradients_impl", diff --git a/tensorflow/python/debug/lib/debug_gradients.py b/tensorflow/python/debug/lib/debug_gradients.py index 8d202c9e1a0e9b..529264b2c5525f 100644 --- a/tensorflow/python/debug/lib/debug_gradients.py +++ b/tensorflow/python/debug/lib/debug_gradients.py @@ -20,6 +20,7 @@ from tensorflow.python.debug.lib import debug_data from tensorflow.python.debug.lib import debug_graphs from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import variables @@ -332,7 +333,7 @@ def gradient_tensors(self): return self._gradient_tensors def _get_tensor_name(self, tensor): - if isinstance(tensor, (ops.Tensor, variables.Variable)): + if isinstance(tensor, (tensor_lib.Tensor, variables.Variable)): return tensor.name elif isinstance(tensor, str): return tensor diff --git a/tensorflow/python/debug/lib/debug_gradients_test.py b/tensorflow/python/debug/lib/debug_gradients_test.py index f84be38e0451e8..a3321be710dd8e 100644 --- a/tensorflow/python/debug/lib/debug_gradients_test.py +++ b/tensorflow/python/debug/lib/debug_gradients_test.py @@ -23,6 +23,7 @@ from tensorflow.python.debug.lib import debug_gradients from tensorflow.python.debug.lib import debug_utils from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import test_util from tensorflow.python.lib.io import file_io from tensorflow.python.ops import gradients_impl @@ -68,17 +69,17 @@ def testIdentifyGradientGivesCorrectTensorObjectWithoutContextManager(self): # Fetch the gradient tensor with the x-tensor object. w_grad = grad_debugger.gradient_tensor(self.w) - self.assertIsInstance(w_grad, ops.Tensor) + self.assertIsInstance(w_grad, tensor.Tensor) self.assertAllClose(1.0, self.sess.run(w_grad)) # Fetch the gradient tensor with the x-tensor's name. w_grad = grad_debugger.gradient_tensor(self.w.name) - self.assertIsInstance(w_grad, ops.Tensor) + self.assertIsInstance(w_grad, tensor.Tensor) self.assertAllClose(1.0, self.sess.run(w_grad)) # Fetch the gradient tensor with the x-tensor name. w_grad = grad_debugger.gradient_tensor(self.w.name) - self.assertIsInstance(w_grad, ops.Tensor) + self.assertIsInstance(w_grad, tensor.Tensor) self.assertAllClose(1.0, self.sess.run(w_grad)) def testIdentifyGradientGivesCorrectTensorObjectWithTfGradients(self): @@ -99,17 +100,17 @@ def testIdentifyGradientGivesCorrectTensorObjectWithTfGradients(self): # Fetch the gradient tensor with the x-tensor object. w_grad = grad_debugger.gradient_tensor(self.w) - self.assertIsInstance(w_grad, ops.Tensor) + self.assertIsInstance(w_grad, tensor.Tensor) self.assertAllClose(1.0, self.sess.run(w_grad)) # Fetch the gradient tensor with the x-tensor's name. w_grad = grad_debugger.gradient_tensor(self.w.name) - self.assertIsInstance(w_grad, ops.Tensor) + self.assertIsInstance(w_grad, tensor.Tensor) self.assertAllClose(1.0, self.sess.run(w_grad)) # Fetch the gradient tensor with the x-tensor name. w_grad = grad_debugger.gradient_tensor(self.w.name) - self.assertIsInstance(w_grad, ops.Tensor) + self.assertIsInstance(w_grad, tensor.Tensor) self.assertAllClose(1.0, self.sess.run(w_grad)) def testCallingIdentifyGradientTwiceWithTheSameGradientsDebuggerErrors(self): @@ -137,8 +138,8 @@ def testIdentifyGradientWorksOnMultipleLosses(self): dz1_dy = grad_debugger_1.gradient_tensor(y) dz2_dy = grad_debugger_2.gradient_tensor(y) - self.assertIsInstance(dz1_dy, ops.Tensor) - self.assertIsInstance(dz2_dy, ops.Tensor) + self.assertIsInstance(dz1_dy, tensor.Tensor) + self.assertIsInstance(dz2_dy, tensor.Tensor) self.assertIsNot(dz1_dy, dz2_dy) self.sess.run(variables.global_variables_initializer()) @@ -187,7 +188,7 @@ def testIdentifyGradientTensorWorksWithGradientDescentOptimizer(self): # Fetch the gradient tensor with the x-tensor object. w_grad = grad_debugger.gradient_tensor(self.w) - self.assertIsInstance(w_grad, ops.Tensor) + self.assertIsInstance(w_grad, tensor.Tensor) self.assertAllClose(1.0, self.sess.run(w_grad)) def testWatchGradientsByXTensorNamesWorks(self): @@ -209,11 +210,11 @@ def testWatchGradientsByXTensorNamesWorks(self): self.assertAllClose(2.0, self.sess.run(v_grad)) w_grad = grad_debugger.gradient_tensor(self.w) - self.assertIsInstance(w_grad, ops.Tensor) + self.assertIsInstance(w_grad, tensor.Tensor) self.assertAllClose(1.0, self.sess.run(w_grad)) w_grad = grad_debugger.gradient_tensor("w:0") - self.assertIsInstance(w_grad, ops.Tensor) + self.assertIsInstance(w_grad, tensor.Tensor) self.assertAllClose(1.0, self.sess.run(w_grad)) def testWatchGradientsByXTensorNamesWorksWithoutContextManager(self): @@ -235,11 +236,11 @@ def testWatchGradientsByXTensorNamesWorksWithoutContextManager(self): self.assertAllClose(2.0, self.sess.run(v_grad)) w_grad = grad_debugger.gradient_tensor(self.w) - self.assertIsInstance(w_grad, ops.Tensor) + self.assertIsInstance(w_grad, tensor.Tensor) self.assertAllClose(1.0, self.sess.run(w_grad)) w_grad = grad_debugger.gradient_tensor("w:0") - self.assertIsInstance(w_grad, ops.Tensor) + self.assertIsInstance(w_grad, tensor.Tensor) self.assertAllClose(1.0, self.sess.run(w_grad)) def testWatchGradientsWorksOnRefTensor(self): @@ -272,7 +273,7 @@ def testWatchGradientsWorksOnMultipleTensors(self): self.assertEqual(2, len(grad_debugger.gradient_tensors())) self.assertIs(u_grad, grad_debugger.gradient_tensor("u:0")) - self.assertIsInstance(grad_debugger.gradient_tensor("w:0"), ops.Tensor) + self.assertIsInstance(grad_debugger.gradient_tensor("w:0"), tensor.Tensor) self.sess.run(variables.global_variables_initializer()) self.assertAllClose(1.0, self.sess.run( @@ -317,8 +318,8 @@ def testWatchGradientsByTensorCanWorkOnMultipleLosses(self): dz1_dy = grad_debugger_1.gradient_tensor(y) dz2_dy = grad_debugger_2.gradient_tensor(y) - self.assertIsInstance(dz1_dy, ops.Tensor) - self.assertIsInstance(dz2_dy, ops.Tensor) + self.assertIsInstance(dz1_dy, tensor.Tensor) + self.assertIsInstance(dz2_dy, tensor.Tensor) self.assertIsNot(dz1_dy, dz2_dy) self.sess.run(variables.global_variables_initializer()) diff --git a/tensorflow/python/kernel_tests/control_flow/BUILD b/tensorflow/python/kernel_tests/control_flow/BUILD index 74691d56c614d6..c8e29030a9dc73 100644 --- a/tensorflow/python/kernel_tests/control_flow/BUILD +++ b/tensorflow/python/kernel_tests/control_flow/BUILD @@ -72,11 +72,14 @@ cuda_py_strict_test( "//tensorflow/python/eager:def_function", "//tensorflow/python/eager:wrap_function", "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:errors", - "//tensorflow/python/framework:for_generated_wrappers", + "//tensorflow/python/framework:function", "//tensorflow/python/framework:indexed_slices", + "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", - "//tensorflow/python/framework:tensor_spec", + "//tensorflow/python/framework:tensor", + "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:array_ops_gen", "//tensorflow/python/ops:array_ops_stack", diff --git a/tensorflow/python/kernel_tests/control_flow/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow/control_flow_ops_py_test.py index d1d79afa66a176..f7e8f846a8076b 100644 --- a/tensorflow/python/kernel_tests/control_flow/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow/control_flow_ops_py_test.py @@ -42,8 +42,8 @@ from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops_stack @@ -177,7 +177,7 @@ def testRefIdentity(self): op = state_ops.assign(v, 9) v2 = control_flow_ops.with_dependencies([op], v) - self.assertTrue(isinstance(v2, ops.Tensor)) + self.assertTrue(isinstance(v2, tensor_lib.Tensor)) self.evaluate(variables.global_variables_initializer()) self.assertEqual(9, self.evaluate(v2)) @@ -2331,8 +2331,8 @@ def testWhileShapeInvariantTensorSpec(self): c = lambda i, _: i < 10 b = lambda i, x: (i + 1, array_ops_stack.stack([x, x])) shape_invariants = [ - tensor_spec.TensorSpec([], dtype=dtypes.int32), - tensor_spec.TensorSpec(None, dtype=dtypes.int32)] + tensor_lib.TensorSpec([], dtype=dtypes.int32), + tensor_lib.TensorSpec(None, dtype=dtypes.int32)] while_loop_tf.while_loop(c, b, [i, x], shape_invariants) # TODO(b/131265085) Remove this decorator when bug is fixed. @@ -2343,7 +2343,7 @@ def testWhileShapeInvariantWrongTypeSpecType(self): i = constant_op.constant(0) x = sparse_tensor.SparseTensor([[0]], [1.0], [10]) shape_invariants = [ - tensor_spec.TensorSpec([], dtype=dtypes.int32), + tensor_lib.TensorSpec([], dtype=dtypes.int32), sparse_tensor.SparseTensorSpec([None])] while_loop_tf.while_loop(c, b, [i, x], shape_invariants) @@ -3489,7 +3489,7 @@ def b(lv0, lv1, lv2): self.assertTrue(isinstance(r, list)) self.assertTrue(isinstance(r[0], named)) self.assertTrue(isinstance(r[1], tuple)) - self.assertTrue(isinstance(r[2], ops.Tensor)) + self.assertTrue(isinstance(r[2], tensor_lib.Tensor)) r_flattened = nest.flatten(r) self.assertEqual([100.0, 1.0, 102.0, 3.0, 4.0 + 100 * 2.0], @@ -4192,7 +4192,7 @@ def testOneValueCond(self): two = ops.convert_to_tensor(2, name="two") p = math_ops.greater_equal(c, 1) i = tf_cond.cond(p, lambda: one, lambda: two) - self.assertTrue(isinstance(i, ops.Tensor)) + self.assertTrue(isinstance(i, tensor_lib.Tensor)) # True case: c = 2 is >= 1 self.assertEqual([1], i.eval(feed_dict={c: 2})) @@ -4328,7 +4328,7 @@ def b(): return state_ops.assign(v, two) i = tf_cond.cond(p, a, b) - self.assertTrue(isinstance(i, ops.Tensor)) + self.assertTrue(isinstance(i, tensor_lib.Tensor)) self.evaluate(variables.global_variables_initializer()) self.assertEqual(0, self.evaluate(v)) diff --git a/tensorflow/python/kernel_tests/io_ops/BUILD b/tensorflow/python/kernel_tests/io_ops/BUILD index 66b71de8933cac..2ff24b665f3df8 100644 --- a/tensorflow/python/kernel_tests/io_ops/BUILD +++ b/tensorflow/python/kernel_tests/io_ops/BUILD @@ -79,9 +79,12 @@ tf_py_strict_test( "//tensorflow/core:protos_all_py", "//tensorflow/python/eager:context", "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:errors", - "//tensorflow/python/framework:for_generated_wrappers", + "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", + "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:array_ops", diff --git a/tensorflow/python/kernel_tests/io_ops/parsing_ops_test.py b/tensorflow/python/kernel_tests/io_ops/parsing_ops_test.py index 87e0e537e22e27..aca17e18280440 100644 --- a/tensorflow/python/kernel_tests/io_ops/parsing_ops_test.py +++ b/tensorflow/python/kernel_tests/io_ops/parsing_ops_test.py @@ -31,6 +31,7 @@ from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.framework import test_util @@ -97,7 +98,9 @@ def _test(self, kwargs, expected_values=None, expected_err=None): serialized = kwargs["serialized"] batch_size = ( self.evaluate(serialized).size - if isinstance(serialized, ops.Tensor) else np.asarray(serialized).size) + if isinstance(serialized, tensor_lib.Tensor) + else np.asarray(serialized).size + ) for k, f in kwargs["features"].items(): if isinstance(f, parsing_ops.FixedLenFeature) and f.shape is not None: self.assertEqual(tuple(out[k].shape.as_list()), (batch_size,) + f.shape) diff --git a/tensorflow/python/kernel_tests/nn_ops/BUILD b/tensorflow/python/kernel_tests/nn_ops/BUILD index 3573c45194d559..466e8dc45ca365 100644 --- a/tensorflow/python/kernel_tests/nn_ops/BUILD +++ b/tensorflow/python/kernel_tests/nn_ops/BUILD @@ -728,10 +728,12 @@ cuda_py_strict_test( "//tensorflow/python/eager:context", "//tensorflow/python/eager:def_function", "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:errors", - "//tensorflow/python/framework:for_generated_wrappers", + "//tensorflow/python/framework:ops", "//tensorflow/python/framework:random_seed", - "//tensorflow/python/framework:tensor_spec", + "//tensorflow/python/framework:tensor", + "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:array_ops_stack", diff --git a/tensorflow/python/kernel_tests/nn_ops/rnn_cell_test.py b/tensorflow/python/kernel_tests/nn_ops/rnn_cell_test.py index 527d33fc6cb732..0786182c1d6885 100644 --- a/tensorflow/python/kernel_tests/nn_ops/rnn_cell_test.py +++ b/tensorflow/python/kernel_tests/nn_ops/rnn_cell_test.py @@ -29,8 +29,8 @@ from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops_stack @@ -650,7 +650,7 @@ def _testStateTupleWithProjAndSequenceLength(self): self.assertEqual(len(outputs_notuple), len(inputs)) self.assertEqual(len(outputs_tuple), len(inputs)) self.assertTrue(isinstance(state_tuple, tuple)) - self.assertTrue(isinstance(state_notuple, ops.Tensor)) + self.assertTrue(isinstance(state_notuple, tensor.Tensor)) variables_lib.global_variables_initializer().run() input_value = np.random.randn(batch_size, input_size) @@ -3211,7 +3211,7 @@ def testSavedModel(self): with self.cached_session(): root = autotrackable.AutoTrackable() root.cell = rnn_cell_impl.LSTMCell(8) - @def_function.function(input_signature=[tensor_spec.TensorSpec([3, 8])]) + @def_function.function(input_signature=[tensor.TensorSpec([3, 8])]) def call(x): state = root.cell.zero_state(3, dtype=x.dtype) y, _ = root.cell(x, state) diff --git a/tensorflow/python/kernel_tests/sparse_ops/BUILD b/tensorflow/python/kernel_tests/sparse_ops/BUILD index 0cf22e4b073458..866d39a1aed6b2 100644 --- a/tensorflow/python/kernel_tests/sparse_ops/BUILD +++ b/tensorflow/python/kernel_tests/sparse_ops/BUILD @@ -51,9 +51,12 @@ tf_py_strict_test( srcs = ["sparse_conditional_accumulator_test.py"], deps = [ "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:errors", - "//tensorflow/python/framework:for_generated_wrappers", "//tensorflow/python/framework:indexed_slices", + "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", + "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:data_flow_ops", diff --git a/tensorflow/python/kernel_tests/sparse_ops/sparse_conditional_accumulator_test.py b/tensorflow/python/kernel_tests/sparse_ops/sparse_conditional_accumulator_test.py index 45f51861695ce5..6b85e5ab719ea4 100644 --- a/tensorflow/python/kernel_tests/sparse_ops/sparse_conditional_accumulator_test.py +++ b/tensorflow/python/kernel_tests/sparse_ops/sparse_conditional_accumulator_test.py @@ -22,6 +22,7 @@ from tensorflow.python.framework import errors_impl from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops @@ -58,7 +59,7 @@ def testConstructor(self): with ops.Graph().as_default(): q = data_flow_ops.SparseConditionalAccumulator( dtypes_lib.float32, name="Q") - self.assertTrue(isinstance(q.accumulator_ref, ops.Tensor)) + self.assertTrue(isinstance(q.accumulator_ref, tensor.Tensor)) self.assertProtoEquals( """ name:'Q' op:'SparseConditionalAccumulator' @@ -81,7 +82,7 @@ def testConstructorWithShape(self): dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1, 5, 2, 8])) - self.assertTrue(isinstance(q.accumulator_ref, ops.Tensor)) + self.assertTrue(isinstance(q.accumulator_ref, tensor.Tensor)) self.assertProtoEquals( """ name:'Q' op:'SparseConditionalAccumulator' diff --git a/tensorflow/python/ops/linalg/sparse/BUILD b/tensorflow/python/ops/linalg/sparse/BUILD index c14f3e9c0e1a83..fa87211b113e3b 100644 --- a/tensorflow/python/ops/linalg/sparse/BUILD +++ b/tensorflow/python/ops/linalg/sparse/BUILD @@ -80,6 +80,7 @@ py_strict_library( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:math_ops", diff --git a/tensorflow/python/ops/linalg/sparse/sparse_csr_matrix_ops.py b/tensorflow/python/ops/linalg/sparse/sparse_csr_matrix_ops.py index 3c50276a00bfb8..9e5ab10faaabd4 100644 --- a/tensorflow/python/ops/linalg/sparse/sparse_csr_matrix_ops.py +++ b/tensorflow/python/ops/linalg/sparse/sparse_csr_matrix_ops.py @@ -23,6 +23,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -96,7 +97,7 @@ def dense_shape_and_type(matrix): ValueError: if `matrix` lacks static handle data containing the dense shape and dtype. """ - if not isinstance(matrix, ops.Tensor): + if not isinstance(matrix, tensor_lib.Tensor): raise TypeError("matrix should be a tensor, but saw: %s" % (matrix,)) if matrix.dtype != dtypes.variant: raise TypeError( @@ -352,7 +353,9 @@ def _matrix(self): return self._csr_matrix def _from_matrix(self, matrix, handle_data=None): - assert isinstance(matrix, ops.Tensor) and matrix.dtype == dtypes.variant + assert ( + isinstance(matrix, tensor_lib.Tensor) and matrix.dtype == dtypes.variant + ) ret = type(self).__new__(type(self)) # pylint: disable=protected-access ret._dtype = self._dtype diff --git a/tensorflow/python/ops/signal/BUILD b/tensorflow/python/ops/signal/BUILD index 683c1e52e38e32..a87dd8bf525b32 100644 --- a/tensorflow/python/ops/signal/BUILD +++ b/tensorflow/python/ops/signal/BUILD @@ -52,6 +52,7 @@ py_strict_library( ":shape_ops", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:math_ops", diff --git a/tensorflow/python/ops/signal/mel_ops.py b/tensorflow/python/ops/signal/mel_ops.py index bcb306f7873495..47d85859ddd9b4 100644 --- a/tensorflow/python/ops/signal/mel_ops.py +++ b/tensorflow/python/ops/signal/mel_ops.py @@ -16,6 +16,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -75,7 +76,7 @@ def _validate_arguments(num_mel_bins, sample_rate, if lower_edge_hertz >= upper_edge_hertz: raise ValueError('lower_edge_hertz %.1f >= upper_edge_hertz %.1f' % (lower_edge_hertz, upper_edge_hertz)) - if not isinstance(sample_rate, ops.Tensor): + if not isinstance(sample_rate, tensor.Tensor): if sample_rate <= 0.0: raise ValueError('sample_rate must be positive. Got: %s' % sample_rate) if upper_edge_hertz > sample_rate / 2: @@ -156,7 +157,7 @@ def linear_to_mel_weight_matrix(num_mel_bins=20, """ with ops.name_scope(name, 'linear_to_mel_weight_matrix') as name: # Convert Tensor `sample_rate` to float, if possible. - if isinstance(sample_rate, ops.Tensor): + if isinstance(sample_rate, tensor.Tensor): maybe_const_val = tensor_util.constant_value(sample_rate) if maybe_const_val is not None: sample_rate = maybe_const_val diff --git a/tensorflow/python/training/experimental/BUILD b/tensorflow/python/training/experimental/BUILD index 02c64add3f36a1..604c97765c0710 100644 --- a/tensorflow/python/training/experimental/BUILD +++ b/tensorflow/python/training/experimental/BUILD @@ -87,6 +87,7 @@ py_strict_test( "//tensorflow/python/eager:context", "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:check_ops", diff --git a/tensorflow/python/training/experimental/loss_scale_test.py b/tensorflow/python/training/experimental/loss_scale_test.py index eaa9a55d5c5a73..e42a95d65c2acb 100644 --- a/tensorflow/python/training/experimental/loss_scale_test.py +++ b/tensorflow/python/training/experimental/loss_scale_test.py @@ -22,6 +22,7 @@ from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops @@ -86,7 +87,7 @@ def test_serialization(self): @test_util.run_in_graph_and_eager_modes def test_call_type(self): scalar = loss_scale_module.FixedLossScale(123) - self.assertIsInstance(scalar(), ops.Tensor) + self.assertIsInstance(scalar(), tensor_lib.Tensor) @test_util.run_in_graph_and_eager_modes def test_repr(self): @@ -301,7 +302,7 @@ def test_get(self): @test_util.run_in_graph_and_eager_modes def test_call_type(self): scalar = loss_scale_module.DynamicLossScale() - self.assertIsInstance(scalar(), ops.Tensor) + self.assertIsInstance(scalar(), tensor_lib.Tensor) @parameterized.named_parameters(*TESTCASES) @test_util.run_in_graph_and_eager_modes diff --git a/tensorflow/python/training/saving/BUILD b/tensorflow/python/training/saving/BUILD index 52a0af997f706c..4538bbdf973775 100644 --- a/tensorflow/python/training/saving/BUILD +++ b/tensorflow/python/training/saving/BUILD @@ -51,6 +51,7 @@ py_strict_library( "//tensorflow/python/framework:device", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:control_flow_ops_gen", diff --git a/tensorflow/python/training/saving/saveable_object_util.py b/tensorflow/python/training/saving/saveable_object_util.py index ecf4d319df58d7..c14361a23f2a14 100644 --- a/tensorflow/python/training/saving/saveable_object_util.py +++ b/tensorflow/python/training/saving/saveable_object_util.py @@ -23,6 +23,7 @@ from tensorflow.python.framework import device as pydev from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops @@ -94,7 +95,7 @@ class ResourceVariableSaveable(saveable_object.SaveableObject): def __init__(self, var, slice_spec, name): self._var_device = var.device self._var_shape = var.shape - if isinstance(var, ops.Tensor): + if isinstance(var, tensor_lib.Tensor): self.handle_op = var.op.inputs[0] tensor = var elif resource_variable_ops.is_resource_variable(var): @@ -145,7 +146,7 @@ def restore(self, restored_tensors, restored_shapes): def _tensor_comes_from_variable(v): - return isinstance(v, ops.Tensor) and v.op.type in _VARIABLE_OPS + return isinstance(v, tensor_lib.Tensor) and v.op.type in _VARIABLE_OPS def saveable_objects_for_op(op, name): @@ -589,7 +590,7 @@ def restore(self, restored_tensors, restored_shapes): if not ops.executing_eagerly_outside_functions() and any([ spec._tensor.op.type in _REF_VARIABLE_OPS for spec in self.specs - if isinstance(spec._tensor, ops.Tensor)]): + if isinstance(spec._tensor, tensor_lib.Tensor)]): return restore_fn(restored_tensor_dict) # pylint: enable=protected-access From 52b3b6877dfff7f7fec040d76ccf8ab9ee1fce6e Mon Sep 17 00:00:00 2001 From: Jieying Luo Date: Thu, 13 Jul 2023 17:02:22 -0700 Subject: [PATCH 294/376] Fix heap use after free for local variable c_options. PiperOrigin-RevId: 547968334 --- tensorflow/compiler/xla/pjrt/c/BUILD | 2 +- tensorflow/compiler/xla/pjrt/c/pjrt_c_api_gpu_test.cc | 11 ++++++----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/tensorflow/compiler/xla/pjrt/c/BUILD b/tensorflow/compiler/xla/pjrt/c/BUILD index e83fce610ba5bf..d1531a0bf1c523 100644 --- a/tensorflow/compiler/xla/pjrt/c/BUILD +++ b/tensorflow/compiler/xla/pjrt/c/BUILD @@ -109,7 +109,7 @@ cc_library( xla_cc_test( name = "pjrt_c_api_gpu_test", srcs = ["pjrt_c_api_gpu_test.cc"], - tags = tf_cuda_tests_tags() + ["nodebug"], # TODO(b/291073132): Test failing in debug mode. + tags = tf_cuda_tests_tags(), deps = [ ":pjrt_c_api_gpu", ":pjrt_c_api_hdrs", diff --git a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_gpu_test.cc b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_gpu_test.cc index f8b6c7ae5351fa..285869bede2d28 100644 --- a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_gpu_test.cc +++ b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_gpu_test.cc @@ -119,9 +119,7 @@ std::unique_ptr<::pjrt::PJRT_KeyValueCallbackData> CreateTestCKVCallback( absl::StatusOr BuildCreateArg( ::pjrt::PJRT_KeyValueCallbackData* kv_callback_data, - const absl::flat_hash_map& options) { - TF_ASSIGN_OR_RETURN(std::vector c_options, - ::pjrt::ConvertToPjRtNamedValueList(options)); + std::vector& c_options) { PJRT_Client_Create_Args args; args.struct_size = PJRT_Client_Create_Args_STRUCT_SIZE; args.priv = nullptr; @@ -153,8 +151,11 @@ TEST(PjrtCApiGpuKVStoreTest, CreateClientWithKVCallback) { absl::flat_hash_map options = { {"num_nodes", static_cast(num_nodes)}, {"node_id", static_cast(i)}}; - TF_ASSERT_OK_AND_ASSIGN(PJRT_Client_Create_Args create_arg, - BuildCreateArg(kv_callback_data.get(), options)); + TF_ASSERT_OK_AND_ASSIGN(std::vector c_options, + ::pjrt::ConvertToPjRtNamedValueList(options)); + TF_ASSERT_OK_AND_ASSIGN( + PJRT_Client_Create_Args create_arg, + BuildCreateArg(kv_callback_data.get(), c_options)); PJRT_Error* error = api->PJRT_Client_Create(&create_arg); EXPECT_EQ(error, nullptr) << error->status.message(); From b27da3b328ee55026348eca8d36b7da6dc6f712b Mon Sep 17 00:00:00 2001 From: Changhui Lin Date: Thu, 13 Jul 2023 17:05:38 -0700 Subject: [PATCH 295/376] Remove GPU specific logic in tfrt_graph_execution_state. PiperOrigin-RevId: 547969073 --- .../tfrt/graph_executor/graph_executor.cc | 3 - .../core/tfrt/saved_model/saved_model.cc | 10 +- .../saved_model/saved_model_import_input.cc | 5 +- .../saved_model/saved_model_import_input.h | 3 +- tensorflow/core/tfrt/utils/BUILD | 13 - .../tfrt/utils/tfrt_graph_execution_state.cc | 224 -------- .../tfrt/utils/tfrt_graph_execution_state.h | 10 - .../utils/tfrt_graph_execution_state_test.cc | 489 ------------------ 8 files changed, 5 insertions(+), 752 deletions(-) diff --git a/tensorflow/core/tfrt/graph_executor/graph_executor.cc b/tensorflow/core/tfrt/graph_executor/graph_executor.cc index 4b33b4750ef4dd..03079e960dd081 100644 --- a/tensorflow/core/tfrt/graph_executor/graph_executor.cc +++ b/tensorflow/core/tfrt/graph_executor/graph_executor.cc @@ -448,9 +448,6 @@ StatusOr> GraphExecutor::Create( TfrtGraphExecutionState::Options graph_execution_state_options; graph_execution_state_options.run_placer_grappler_on_functions = options.run_placer_grappler_on_functions; - graph_execution_state_options.enable_tfrt_gpu = options.enable_tfrt_gpu; - graph_execution_state_options.use_bridge_for_gpu = - options.compile_options.use_bridge_for_gpu; options.compile_options.fuse_get_resource_ops_in_hoisting = !options.enable_mlrt; diff --git a/tensorflow/core/tfrt/saved_model/saved_model.cc b/tensorflow/core/tfrt/saved_model/saved_model.cc index 348d1b3dd903aa..b4c7b32dcf97dd 100644 --- a/tensorflow/core/tfrt/saved_model/saved_model.cc +++ b/tensorflow/core/tfrt/saved_model/saved_model.cc @@ -319,8 +319,7 @@ std::vector FindNamesForValidSignatures( StatusOr> ImportSavedModel( mlir::MLIRContext* context, const tensorflow::MetaGraphDef& meta_graph_def, const FallbackState& fallback_state, std::string saved_model_dir, - bool import_user_signatures, bool run_placer_grappler_on_functions, - bool enable_tfrt_gpu, bool use_bridge_for_gpu) { + bool import_user_signatures, bool run_placer_grappler_on_functions) { std::vector signature_names; if (import_user_signatures) { signature_names = FindNamesForValidSignatures(meta_graph_def); @@ -338,8 +337,7 @@ StatusOr> ImportSavedModel( TF_ASSIGN_OR_RETURN(auto import_input, TfrtSavedModelMLIRImportInput::Create( fallback_state, &meta_graph_def, /*debug_info=*/{}, - run_placer_grappler_on_functions, enable_tfrt_gpu, - use_bridge_for_gpu)); + run_placer_grappler_on_functions)); TF_ASSIGN_OR_RETURN( auto module, @@ -625,9 +623,7 @@ SavedModelImpl::LoadSavedModel(Options options, &context, meta_graph_def, *fallback_state, std::string(saved_model_dir), /*import_user_signatures=*/!options.enable_lazy_loading, - options.graph_execution_options.run_placer_grappler_on_functions, - options.graph_execution_options.enable_tfrt_gpu, - options.graph_execution_options.compile_options.use_bridge_for_gpu)); + options.graph_execution_options.run_placer_grappler_on_functions)); // TODO(b/278143179): Upload module w/o control flow. SymbolUids symbol_uids; symbol_uids.tf_symbol_uid = MaybeUploadMlirToXsymbol(mlir_module.get()); diff --git a/tensorflow/core/tfrt/saved_model/saved_model_import_input.cc b/tensorflow/core/tfrt/saved_model/saved_model_import_input.cc index bfd25ccb7699af..379ed634f98606 100644 --- a/tensorflow/core/tfrt/saved_model/saved_model_import_input.cc +++ b/tensorflow/core/tfrt/saved_model/saved_model_import_input.cc @@ -28,15 +28,12 @@ namespace tfrt_stub { StatusOr TfrtSavedModelMLIRImportInput::Create( const FallbackState& fallback_state, const MetaGraphDef* meta_graph_def, const GraphDebugInfo& debug_info, - bool run_placer_grappler_on_nested_functions, bool enable_tfrt_gpu, - bool use_bridge_for_gpu) { + bool run_placer_grappler_on_nested_functions) { DCHECK(meta_graph_def); TfrtGraphExecutionState::Options options; options.run_placer_grappler_on_functions = run_placer_grappler_on_nested_functions; - options.enable_tfrt_gpu = enable_tfrt_gpu; - options.use_bridge_for_gpu = use_bridge_for_gpu; TF_ASSIGN_OR_RETURN( auto graph_execution_state, TfrtGraphExecutionState::Create(options, meta_graph_def->graph_def(), diff --git a/tensorflow/core/tfrt/saved_model/saved_model_import_input.h b/tensorflow/core/tfrt/saved_model/saved_model_import_input.h index 3c1b9fca053ffb..f1913359935801 100644 --- a/tensorflow/core/tfrt/saved_model/saved_model_import_input.h +++ b/tensorflow/core/tfrt/saved_model/saved_model_import_input.h @@ -33,8 +33,7 @@ class TfrtSavedModelMLIRImportInput : public SavedModelMLIRImportInput { static StatusOr Create( const FallbackState& fallback_state, const MetaGraphDef* meta_graph_def, const GraphDebugInfo& debug_info, - bool run_placer_grappler_on_nested_functions = false, - bool enable_tfrt_gpu = false, bool use_bridge_for_gpu = false); + bool run_placer_grappler_on_nested_functions = false); TfrtSavedModelMLIRImportInput( const MetaGraphDef* meta_graph_def, const GraphDebugInfo& debug_info, diff --git a/tensorflow/core/tfrt/utils/BUILD b/tensorflow/core/tfrt/utils/BUILD index c1c83267233098..44e29ea288284f 100644 --- a/tensorflow/core/tfrt/utils/BUILD +++ b/tensorflow/core/tfrt/utils/BUILD @@ -198,8 +198,6 @@ cc_library( srcs = ["tfrt_graph_execution_state.cc"], hdrs = ["tfrt_graph_execution_state.h"], deps = [ - "//tensorflow/compiler/jit:common", - "//tensorflow/compiler/jit:compilation_passes", "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", "//tensorflow/compiler/mlir/tensorflow:upgrade_graph", "//tensorflow/core:core_cpu_base", @@ -236,25 +234,14 @@ tf_cc_test( "//tensorflow/cc:array_ops", "//tensorflow/cc:cc_ops", "//tensorflow/cc:const_op", - "//tensorflow/cc:function_ops", "//tensorflow/cc:functional_ops", - "//tensorflow/cc:math_ops", - "//tensorflow/cc:resource_variable_ops", - "//tensorflow/cc:scope", - "//tensorflow/cc:sendrecv_ops", "//tensorflow/cc:while_loop", - "//tensorflow/compiler/jit:common", - "//tensorflow/compiler/tf2xla/cc:xla_jit_ops", "//tensorflow/core:framework", - "//tensorflow/core:framework_internal", "//tensorflow/core:test", - "//tensorflow/core/framework:attr_value_proto_cc", "//tensorflow/core/framework:graph_proto_cc", - "//tensorflow/core/framework:node_def_proto_cc", "//tensorflow/core/framework:tensor_testutil", "//tensorflow/core/framework:types_proto_cc", "//tensorflow/core/grappler/utils:grappler_test", - "//tensorflow/core/kernels:resource_variable_ops", "//tensorflow/core/protobuf:for_core_protos_cc", "@com_google_googletest//:gtest_main", ], diff --git a/tensorflow/core/tfrt/utils/tfrt_graph_execution_state.cc b/tensorflow/core/tfrt/utils/tfrt_graph_execution_state.cc index 22ffa9a6b32ee5..8a80444993f4ff 100644 --- a/tensorflow/core/tfrt/utils/tfrt_graph_execution_state.cc +++ b/tensorflow/core/tfrt/utils/tfrt_graph_execution_state.cc @@ -26,16 +26,10 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "absl/time/clock.h" #include "absl/types/span.h" -#include "tensorflow/compiler/jit/defs.h" -#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h" #include "tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.h" #include "tensorflow/core/common_runtime/function_body.h" #include "tensorflow/core/common_runtime/function_def_utils.h" #include "tensorflow/core/common_runtime/graph_constructor.h" -#include "tensorflow/core/common_runtime/lower_functional_ops.h" -#include "tensorflow/core/common_runtime/optimization_registry.h" -#include "tensorflow/core/common_runtime/partitioning_utils.h" -#include "tensorflow/core/common_runtime/placer.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function.pb.h" @@ -46,7 +40,6 @@ limitations under the License. #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/graph/graph.h" -#include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/status.h" @@ -231,171 +224,6 @@ NodeDef CreateNewIdentityNode(const NodeDef& node, return identity; } -// Inlines functions into the top level graph. -Status InlineFunctions(std::unique_ptr* graph, - const DeviceSet* device_set) { - GraphOptimizationPassOptions optimization_options; - SessionOptions session_options; - // We don't lower v2 control flow to v1 for now. - session_options.config.mutable_experimental()->set_use_tfrt(true); - session_options.config.mutable_graph_options() - ->mutable_optimizer_options() - ->set_do_function_inlining(true); - optimization_options.session_options = &session_options; - optimization_options.graph = graph; - optimization_options.flib_def = (*graph)->mutable_flib_def(); - optimization_options.device_set = device_set; - optimization_options.is_function_graph = false; - - LowerFunctionalOpsPass pass; - return pass.Run(optimization_options); -} - -// Assigns input/output nodes to the host. -Status PlaceInputOutputNodesOnHost(const std::vector& inputs, - const std::vector& outputs, - const Device* cpu_device, Graph* graph) { - std::unordered_map name_to_node_map = - graph->BuildNodeNameIndex(); - for (const auto& input : inputs) { - name_to_node_map.at(grappler::NodeName(input)) - ->set_assigned_device_name(cpu_device->name()); - } - - // Collect all output nodes. - absl::flat_hash_set output_nodes; - for (const auto& output : outputs) { - output_nodes.insert(name_to_node_map.at(grappler::NodeName(output))); - } - for (const auto& output_node : output_nodes) { - // Append an IdentityN node to the original output node if it is not - // assigned to the host. - if (!output_node->IsIdentity() && - output_node->type_string() != "IdentityN" && - output_node->assigned_device_name() != cpu_device->name()) { - // Rename the original output node. - std::string output_node_name = output_node->name(); - output_node->set_name(output_node_name + "/tfrt_renamed"); - - // Append an IdentityN node with the original output node name. - std::vector output_tensors; - output_tensors.reserve(output_node->num_outputs()); - for (int i = 0; i < output_node->num_outputs(); i++) { - output_tensors.push_back(NodeBuilder::NodeOut(output_node, i)); - } - TF_RETURN_IF_ERROR(NodeBuilder(output_node_name, "IdentityN") - .AssignedDevice(cpu_device->name()) - .Input(output_tensors) - .Finalize(graph, /*created_node=*/nullptr)); - } else { - output_node->set_assigned_device_name(cpu_device->name()); - } - } - return OkStatus(); -} - -Status AdjustDeviceAssignment(const std::vector& inputs, - const std::vector& outputs, - const std::vector& control_outputs, - const Device* cpu_device, Graph* graph) { - // TODO(b/232299232): We don't inline and partition v2 control flow currently. - // All ops within control flow are placed on CPU for now. Figure out a better - // way to handle v2 control flow. - for (Node* node : graph->op_nodes()) { - if (node->IsWhileNode() || node->IsIfNode()) { - LOG(WARNING) << "The control flow node " << node->name() - << " is placed on CPU."; - node->set_assigned_device_name(cpu_device->name()); - } - } - - TF_RETURN_IF_ERROR( - PlaceInputOutputNodesOnHost(inputs, outputs, cpu_device, graph)); - return OkStatus(); -} - -bool IsTpuGraph(const Graph* graph) { - static const auto* const kTpuOps = new absl::flat_hash_set{ - "TPUPartitionedCall", "TPUCompile", "TPUReplicateMetadata"}; - for (const Node* node : graph->nodes()) { - if (kTpuOps->contains(node->type_string())) { - return true; - } - } - for (const std::string& func_name : graph->flib_def().ListFunctionNames()) { - const FunctionDef* func_def = graph->flib_def().Find(func_name); - for (const NodeDef& node_def : func_def->node_def()) { - if (kTpuOps->contains(node_def.op())) return true; - } - } - return false; -} - -// Adds Send/Recv ops to `graph` for data transfer, if ops are run on different -// devices. Returns a new graph with the added Send/Recv ops. -// This is done by partitioning `graph` and add Send/Recv ops on the edges -// across devices. -StatusOr> BuildXlaOpsAndMaybeInsertTransferOps( - const std::string& graph_func_name, const FallbackState& fallback_state, - const std::vector& inputs, - const std::vector& outputs, - const std::vector& control_outputs, - std::unique_ptr graph) { - // Skip inserting transfer ops if this is a TPU graph. - // Our stack currently cannot run the old bridge on TPU graphs, as it will - // generate ops that are not supported by the subsequent MLIR passes. - // In the case where TPU related ops are not wrapped in TPUPartitionedCall, - // running placer and partitioning on such graphs will fail. So we skip TPU - // graphs for now. - // TODO(b/228510957): In the long term, we will want a unified way for data - // transfer, i.e., using Send/Recv ops for data transfer for TPU as well. - if (IsTpuGraph(graph.get())) { - return graph; - } - - // Inline functions to facilitate partitioning nodes in the functions. - TF_RETURN_IF_ERROR(InlineFunctions(&graph, &fallback_state.device_set())); - if (VLOG_IS_ON(1)) { - DumpGraphToFile("after_inlining", *graph); - } - - // Replace the StatefulPartitionedCall op that should be compiled to an - // XlaLaunch op. - // TODO(b/239089915): Clean this up after the logic is implemented in TFXLA - // bridge. - TF_RETURN_IF_ERROR(BuildXlaLaunchOps(graph.get())); - if (VLOG_IS_ON(1)) { - DumpGraphToFile("after_build_xla_launch", *graph); - } - - // Run placer. - const Device* cpu_device = fallback_state.device_manager().HostCPU(); - if (cpu_device == nullptr) { - return errors::Internal("No CPU device found."); - } - Placer placer(graph.get(), /*function_name=*/"", &graph->flib_def(), - &fallback_state.device_set(), cpu_device, - /*allow_soft_placement=*/true, - /*log_device_placement=*/false); - TF_RETURN_IF_ERROR(placer.Run()); - if (VLOG_IS_ON(1)) { - DumpGraphToFile("after_placer", *graph); - } - - TF_RETURN_IF_ERROR(AdjustDeviceAssignment(inputs, outputs, control_outputs, - cpu_device, graph.get())); - - // Insert send/recv ops to the graph. - TF_ASSIGN_OR_RETURN( - std::unique_ptr new_graph, - InsertTransferOps(fallback_state.device_set(), std::move(graph))); - if (VLOG_IS_ON(1)) { - DumpGraphToFile("after_transfer_ops_insertion", *new_graph); - } - - return new_graph; -} - } // namespace StatusOr @@ -463,22 +291,6 @@ TfrtGraphExecutionState::CreateOptimizedGraph( result.grappler_duration = absl::Now() - grappler_start_time; - if (options_.enable_tfrt_gpu && !options_.use_bridge_for_gpu) { - TF_ASSIGN_OR_RETURN( - result.graph, - BuildXlaOpsAndMaybeInsertTransferOps( - graph_import_config.graph_func_name, fallback_state_, inputs, - graph_import_config.outputs, graph_import_config.control_outputs, - std::move(result.graph))); - - // Update `control_outputs` as there might be newly added Send ops. - for (const Node* node : result.graph->nodes()) { - if (node->IsSend()) { - graph_import_config.control_outputs.push_back(node->name()); - } - } - } - return result; } @@ -865,41 +677,5 @@ TfrtGraphExecutionState::OptimizeGraph( return optimized_graph; } -// TODO(b/239089915): Clean this up after the logic is implemented in TFXLA -// bridge. -Status BuildXlaLaunchOps(Graph* graph) { - const auto is_xla_launch_node = [](const Node& n) -> StatusOr { - if (!n.IsPartitionedCall()) { - return false; - } - bool xla_must_compile = false; - const bool has_attribute = - TryGetNodeAttr(n.attrs(), kXlaMustCompileAttr, &xla_must_compile); - return has_attribute && xla_must_compile; - }; - - const auto get_xla_function_info = [](const Node& launch) - -> StatusOr { - EncapsulateXlaComputationsPass::XlaFunctionInfo result; - std::vector tin_dtypes; - TF_RETURN_IF_ERROR(GetNodeAttr(launch.def(), "Tin", &tin_dtypes)); - int variable_start_index = 0; - for (; variable_start_index < tin_dtypes.size(); ++variable_start_index) { - if (tin_dtypes.at(variable_start_index) == DT_RESOURCE) break; - } - result.variable_start_index = variable_start_index; - - NameAttrList func; - TF_RETURN_IF_ERROR(GetNodeAttr(launch.attrs(), "f", &func)); - result.function_name = func.name(); - - return result; - }; - - return EncapsulateXlaComputationsPass::BuildXlaLaunchOps( - graph, is_xla_launch_node, get_xla_function_info, - /*add_edges_to_output_of_downstream_nodes=*/false); -} - } // namespace tfrt_stub } // namespace tensorflow diff --git a/tensorflow/core/tfrt/utils/tfrt_graph_execution_state.h b/tensorflow/core/tfrt/utils/tfrt_graph_execution_state.h index d592d857fd6769..e347412ec532f6 100644 --- a/tensorflow/core/tfrt/utils/tfrt_graph_execution_state.h +++ b/tensorflow/core/tfrt/utils/tfrt_graph_execution_state.h @@ -52,10 +52,6 @@ class TfrtGraphExecutionState { struct Options { bool run_placer_grappler_on_functions = false; - // TODO(b/262826012): Remove the flag after we switch to using bridge. - bool enable_tfrt_gpu = false; - // TODO(b/260915352): Remove the flag and default to using bridge. - bool use_bridge_for_gpu = false; bool run_placer_on_graph = true; }; @@ -138,12 +134,6 @@ Status EliminateRefVariablesFromV1ControlFlow(GraphDef& graph_def); // Removes the "_input_shapes" attribute of functions in the graph. void RemoveInputShapesInFunctions(tensorflow::GraphDef& graph_def); -// Replaces partitioned calls in the graph that have _XlaMustCompile attribute -// set to true with XlaLaunch op. -// TODO(b/239089915): Clean this up after the logic is implemented in TFXLA -// bridge. -Status BuildXlaLaunchOps(Graph* graph); - } // namespace tfrt_stub } // namespace tensorflow diff --git a/tensorflow/core/tfrt/utils/tfrt_graph_execution_state_test.cc b/tensorflow/core/tfrt/utils/tfrt_graph_execution_state_test.cc index aa99c168ebd1c1..e16b941cc46c4a 100644 --- a/tensorflow/core/tfrt/utils/tfrt_graph_execution_state_test.cc +++ b/tensorflow/core/tfrt/utils/tfrt_graph_execution_state_test.cc @@ -21,32 +21,18 @@ limitations under the License. #include #include -#include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/const_op.h" -#include "tensorflow/cc/ops/function_ops.h" #include "tensorflow/cc/ops/functional_ops.h" -#include "tensorflow/cc/ops/math_ops.h" -#include "tensorflow/cc/ops/resource_variable_ops.h" -#include "tensorflow/cc/ops/sendrecv_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/cc/ops/while_loop.h" -#include "tensorflow/compiler/jit/defs.h" -#include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_ops.h" -#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/device_factory.h" -#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph.pb.h" -#include "tensorflow/core/framework/graph_to_functiondef.h" -#include "tensorflow/core/framework/node_def.pb.h" -#include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/grappler/utils/grappler_test.h" -#include "tensorflow/core/kernels/resource_variable_ops.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/protobuf/rewriter_config.pb.h" -#include "tensorflow/core/util/equal_graph_def.h" namespace tensorflow { namespace tfrt_stub { @@ -761,481 +747,6 @@ TEST_F(ExtendGraphTest, ExtendGraph) { CompareGraphs(expected, *graph_execution_state->original_graph_def()); } -// An auxiliary struct to verify the graph after partitioning and inserting -// transfer ops. -struct GraphInfo { - NodeDef* input_node = nullptr; - NodeDef* output_node = nullptr; - NodeDef* stateful_partitioned_call_node = nullptr; - std::vector partitioned_call_nodes; - std::vector fdefs; -}; - -class InsertTransferOpsTest : public grappler::GrapplerTest { - protected: - void SetUp() override { - SessionOptions options; - auto* device_count = options.config.mutable_device_count(); - device_count->insert({"CPU", 2}); - std::vector> devices; - TF_ASSERT_OK(DeviceFactory::AddDevices(options, "/job:a/replica:0/task:0", - &devices)); - device0_ = devices[0].get(); - device1_ = devices[1].get(); - - fallback_state_ = - std::make_unique(options, std::move(devices), fdef_lib_); - } - - GraphInfo GetGraphInfo(const std::string& input, const std::string& output, - GraphDef& graphdef) { - GraphInfo graph_info; - for (NodeDef& node : *graphdef.mutable_node()) { - if (node.op() == "PartitionedCall") { - graph_info.partitioned_call_nodes.push_back(&node); - } else if (node.op() == "StatefulPartitionedCall") { - graph_info.stateful_partitioned_call_node = &node; - } else if (node.name() == input) { - graph_info.input_node = &node; - } else if (node.name() == output) { - graph_info.output_node = &node; - } - } - - // Find the corresponding function called by the PartitionedCall nodes. - absl::flat_hash_map func_name_to_func; - for (const FunctionDef& fdef : graphdef.library().function()) { - func_name_to_func[fdef.signature().name()] = fdef; - } - for (NodeDef* node : graph_info.partitioned_call_nodes) { - CHECK(node->attr().contains("f")); - CHECK(func_name_to_func.contains(node->attr().at("f").func().name())); - const FunctionDef& fdef = - func_name_to_func.at(node->attr().at("f").func().name()); - graph_info.fdefs.push_back(fdef); - } - return graph_info; - } - - std::unique_ptr fallback_state_; - Device* device0_ = nullptr; // Not owned. - Device* device1_ = nullptr; // Not owned. - tensorflow::FunctionDefLibrary fdef_lib_; -}; - -TEST_F(InsertTransferOpsTest, InsertTransferOps) { - GraphDef graphdef; - { - Scope scope = Scope::NewRootScope(); - Scope scope1 = scope.WithDevice(device0_->name()); - Scope scope2 = scope.WithDevice(device1_->name()); - - // A graph whose nodes are on different devices. - // a(Const, on device0) -> b(Abs, on device1) -> c(Identity, on device0) - Output a = ops::Const(scope1.WithOpName("a"), 2.0, {1, 1}); - Output b = ops::Abs(scope2.WithOpName("b"), a); - Output c = ops::Identity(scope1.WithOpName("c"), b); - - // Before partitioning, there is no send/recv nodes. - int send_count = 0, recv_count = 0; - for (const auto* op : scope.graph()->op_nodes()) { - if (op->IsSend()) - ++send_count; - else if (op->IsRecv()) - ++recv_count; - } - ASSERT_EQ(scope.graph()->num_op_nodes(), 3); - ASSERT_EQ(send_count, 0); - ASSERT_EQ(recv_count, 0); - - TF_ASSERT_OK(scope.ToGraphDef(&graphdef)); - } - - TfrtGraphExecutionState::Options options; - options.run_placer_grappler_on_functions = false; - options.enable_tfrt_gpu = true; - TF_ASSERT_OK_AND_ASSIGN( - auto graph_execution_state, - TfrtGraphExecutionState::Create(options, graphdef, *fallback_state_)); - - tensorflow::GraphImportConfig graph_import_config; - graph_import_config.prune_unused_nodes = true; - graph_import_config.enable_shape_inference = false; - tensorflow::ArrayInfo array_info; - array_info.imported_dtype = DT_FLOAT; - array_info.shape.set_unknown_rank(true); - graph_import_config.inputs["a"] = array_info; - graph_import_config.outputs = {"c"}; - - TF_ASSERT_OK_AND_ASSIGN( - auto optimized_graph, - graph_execution_state->CreateOptimizedGraph(graph_import_config)); - - // Verify that two paris of Send/Recv nodes are added. - int send_count = 0, recv_count = 0; - for (const auto* op : optimized_graph.graph->op_nodes()) { - if (op->IsSend()) - ++send_count; - else if (op->IsRecv()) - ++recv_count; - } - EXPECT_EQ(optimized_graph.graph->num_op_nodes(), 7); - EXPECT_EQ(send_count, 2); - EXPECT_EQ(recv_count, 2); -} - -TEST_F(InsertTransferOpsTest, InsertTransferOpsWithFunctionInlining) { - GraphDef graphdef; - { - Scope scope = Scope::NewRootScope(); - Scope scope1 = scope.WithDevice(device0_->name()); - Scope scope2 = scope.WithDevice(device1_->name()); - - // A graph whose nodes are on different devices. - // a(Const, on device0) -> b(PartitionedCall) -> c(Identity, on device0) - // where PartitionedCall invokes a function with two nodes assigned to - // different devices. - const Tensor kThree = test::AsScalar(3.0); - auto fdef = tensorflow::FunctionDefHelper::Create( - "_Pow3", {"x: float"}, {"y: float"}, {}, - {// The two nodes in the function are assigned to different devices. - {{"three"}, - "Const", - {}, - {{"dtype", DT_FLOAT}, {"value", kThree}}, - /*dep=*/{}, - device0_->name()}, - {{"pow3"}, - "Pow", - {"x", "three:output:0"}, - {{"T", DT_FLOAT}}, - /*dep=*/{}, - device1_->name()}}, - {{"y", "pow3:z:0"}}); - - tensorflow::FunctionDefLibrary fdef_lib; - *fdef_lib.add_function() = fdef; - TF_ASSERT_OK(scope.graph()->AddFunctionLibrary(fdef_lib)); - - Output a = ops::Const(scope1.WithOpName("a"), 2.0, {1, 1}); - - std::vector inputs = {a}; - std::vector output_dtypes = { - fdef.signature().output_arg(0).type()}; - tensorflow::NameAttrList func_attr; - func_attr.set_name(fdef.signature().name()); - auto pcall = ops::PartitionedCall(scope2, inputs, output_dtypes, func_attr); - Output b = pcall.output.front(); - - Output c = ops::Identity(scope1.WithOpName("c"), b); - - TF_ASSERT_OK(scope.ToGraphDef(&graphdef)); - - // Before partitioning, there is no send/recv nodes. - int partitioned_call_count = 0, mul_count = 0, send_count = 0, - recv_count = 0; - for (const auto* op : scope.graph()->op_nodes()) { - if (op->IsPartitionedCall()) - ++partitioned_call_count; - else if (op->IsSend()) - ++send_count; - else if (op->IsRecv()) - ++recv_count; - else if (op->type_string() == "Mul") - ++mul_count; - } - ASSERT_EQ(partitioned_call_count, 1); - ASSERT_EQ(send_count, 0); - ASSERT_EQ(recv_count, 0); - ASSERT_EQ(mul_count, 0); - } - - TfrtGraphExecutionState::Options options; - options.run_placer_grappler_on_functions = false; - options.enable_tfrt_gpu = true; - TF_ASSERT_OK_AND_ASSIGN( - auto graph_execution_state, - TfrtGraphExecutionState::Create(options, graphdef, *fallback_state_)); - - tensorflow::GraphImportConfig graph_import_config; - graph_import_config.prune_unused_nodes = true; - graph_import_config.enable_shape_inference = false; - tensorflow::ArrayInfo array_info; - array_info.imported_dtype = DT_FLOAT; - array_info.shape.set_unknown_rank(true); - graph_import_config.inputs["a"] = array_info; - graph_import_config.outputs = {"c"}; - - TF_ASSERT_OK_AND_ASSIGN( - auto optimized_graph, - graph_execution_state->CreateOptimizedGraph(graph_import_config)); - - // Verify that the resultant graph has no PartitionedCall ops, function body - // is inlined into the main graph, and send/recv ops are added. - int partitioned_call_count = 0, mul_count = 0, send_count = 0, recv_count = 0; - for (const auto* op : optimized_graph.graph->op_nodes()) { - if (op->IsPartitionedCall()) - ++partitioned_call_count; - else if (op->IsSend()) - ++send_count; - else if (op->IsRecv()) - ++recv_count; - else if (op->type_string() == "Mul") - ++mul_count; - } - - EXPECT_EQ(partitioned_call_count, 0); - EXPECT_EQ(send_count, 2); - EXPECT_EQ(recv_count, 2); - EXPECT_EQ(mul_count, 1); -} - -TEST_F(InsertTransferOpsTest, AppendIdentityN) { - GraphDef graphdef; - { - Scope scope = Scope::NewRootScope(); - Scope scope1 = scope.WithDevice(device0_->name()); - Scope scope2 = scope.WithDevice(device1_->name()); - - // A graph with two nodes assigned on different devices. - // a(Const, on device0) -> b(Abs, on device1) - Output a = ops::Const(scope1.WithOpName("a"), 2.0, {1, 1}); - Output b = ops::Abs(scope2.WithOpName("b"), a); - - TF_ASSERT_OK(scope.ToGraphDef(&graphdef)); - - // There is no IdentityN/Send/Recv nodes originally. - int identity_count = 0, abs_count = 0, const_count = 0, send_count = 0, - recv_count = 0; - for (const auto* op : scope.graph()->op_nodes()) { - if (op->type_string() == "IdentityN") - ++identity_count; - else if (op->IsConstant()) - ++const_count; - else if (op->type_string() == "Abs") - ++abs_count; - else if (op->IsSend()) - ++send_count; - else if (op->IsRecv()) - ++recv_count; - } - ASSERT_EQ(scope.graph()->num_op_nodes(), 2); - ASSERT_EQ(identity_count, 0); - ASSERT_EQ(const_count, 1); - ASSERT_EQ(abs_count, 1); - ASSERT_EQ(send_count, 0); - ASSERT_EQ(recv_count, 0); - } - TfrtGraphExecutionState::Options options; - options.run_placer_grappler_on_functions = false; - options.enable_tfrt_gpu = true; - TF_ASSERT_OK_AND_ASSIGN( - auto graph_execution_state, - TfrtGraphExecutionState::Create(options, graphdef, *fallback_state_)); - - tensorflow::GraphImportConfig graph_import_config; - graph_import_config.prune_unused_nodes = true; - graph_import_config.enable_shape_inference = false; - tensorflow::ArrayInfo array_info; - array_info.imported_dtype = DT_FLOAT; - array_info.shape.set_unknown_rank(true); - graph_import_config.inputs["a"] = array_info; - graph_import_config.outputs = {"b"}; - - TF_ASSERT_OK_AND_ASSIGN( - auto optimized_graph, - graph_execution_state->CreateOptimizedGraph(graph_import_config)); - GraphDef optimized_graphdef; - optimized_graph.graph->ToGraphDef(&optimized_graphdef); - - // Verify that IdentityN/Send/Recv nodes are added. - int identity_count = 0, abs_count = 0, const_count = 0, send_count = 0, - recv_count = 0; - for (const auto* op : optimized_graph.graph->op_nodes()) { - if (op->type_string() == "IdentityN") - ++identity_count; - else if (op->IsConstant()) - ++const_count; - else if (op->type_string() == "Abs") - ++abs_count; - else if (op->IsSend()) - ++send_count; - else if (op->IsRecv()) - ++recv_count; - } - EXPECT_EQ(optimized_graph.graph->num_op_nodes(), 7); - EXPECT_EQ(identity_count, 1); - EXPECT_EQ(const_count, 1); - EXPECT_EQ(abs_count, 1); - EXPECT_EQ(send_count, 2); - EXPECT_EQ(recv_count, 2); -} - -std::unique_ptr MakeOuterGraph(const FunctionLibraryDefinition& flib_def, - const std::string& function_name) { - Scope scope = Scope::NewRootScope().ExitOnError(); - TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(flib_def.ToProto())); - - auto a = ops::Placeholder(scope.WithOpName("A"), DT_INT32); - auto b = ops::Placeholder(scope.WithOpName("B"), DT_FLOAT); - auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32); - auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT); - auto u = ops::Placeholder(scope.WithOpName("U"), DT_RESOURCE); - auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE); - auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE); - - std::vector func_inputs; - func_inputs.push_back( - tensorflow::NodeDefBuilder::NodeOut(a.node()->name(), 0, DT_INT32)); - func_inputs.push_back(tensorflow::NodeDefBuilder::NodeOut(b.node()->name(), 0, - b.output.type())); - func_inputs.push_back(tensorflow::NodeDefBuilder::NodeOut(c.node()->name(), 0, - c.output.type())); - func_inputs.push_back(tensorflow::NodeDefBuilder::NodeOut(d.node()->name(), 0, - d.output.type())); - func_inputs.push_back(tensorflow::NodeDefBuilder::NodeOut(u.node()->name(), 0, - u.output.type())); - func_inputs.push_back(tensorflow::NodeDefBuilder::NodeOut(v.node()->name(), 0, - v.output.type())); - func_inputs.push_back(tensorflow::NodeDefBuilder::NodeOut(w.node()->name(), 0, - w.output.type())); - - std::vector input_dtypes; - for (const NodeDefBuilder::NodeOut& func_input : func_inputs) { - input_dtypes.push_back(func_input.data_type); - } - - std::vector output_dtypes = {DT_FLOAT, DT_INT32, DT_FLOAT, - DT_FLOAT}; - - NameAttrList f; - f.set_name(function_name); - - NodeDef def; - TF_CHECK_OK(NodeDefBuilder("xla_call_0", "StatefulPartitionedCall", &flib_def) - .Input(func_inputs) - .Attr("Tin", input_dtypes) - .Attr("Tout", output_dtypes) - .Attr("f", f) - .Device("/gpu:0") - .Attr(kXlaMustCompileAttr, true) - .Finalize(&def)); - - Status status; - Node* launch = scope.graph()->AddNode(def, &status); - TF_CHECK_OK(status); - TF_CHECK_OK(scope.DoShapeInference(launch)); - scope.graph()->AddEdge(a.node(), 0, launch, 0); - scope.graph()->AddEdge(b.node(), 0, launch, 1); - scope.graph()->AddEdge(c.node(), 0, launch, 2); - scope.graph()->AddEdge(d.node(), 0, launch, 3); - scope.graph()->AddEdge(u.node(), 0, launch, 4); - scope.graph()->AddEdge(v.node(), 0, launch, 5); - scope.graph()->AddEdge(w.node(), 0, launch, 6); - - auto consumer0_a = - ops::Identity(scope.WithOpName("consumer0_a"), Output(launch, 0)); - auto consumer0_b = - ops::Identity(scope.WithOpName("consumer0_b"), Output(launch, 0)); - auto consumer0_c = - ops::Identity(scope.WithOpName("consumer0_c"), Output(launch, 0)); - auto consumer1 = - ops::Identity(scope.WithOpName("consumer1"), Output(launch, 1)); - auto consumer2 = - ops::Identity(scope.WithOpName("consumer2"), Output(launch, 2)); - auto consumer3 = - ops::Identity(scope.WithOpName("consumer3"), Output(launch, 3)); - - std::unique_ptr graph(new Graph(OpRegistry::Global())); - TF_CHECK_OK(scope.ToGraph(graph.get())); - return graph; -} - -// Makes an encapsulate body graph for use in tests. -std::unique_ptr MakeBodyGraph() { - Scope scope = Scope::NewRootScope().ExitOnError(); - - auto arg0 = ops::_Arg(scope.WithOpName("a_0_arg"), DT_INT32, 0); - auto arg1 = ops::_Arg(scope.WithOpName("b_0_arg"), DT_FLOAT, 1); - auto arg2 = ops::_Arg(scope.WithOpName("c_0_arg"), DT_INT32, 2); - auto arg3 = ops::_Arg(scope.WithOpName("d_0_arg"), DT_FLOAT, 3); - - auto arg4 = ops::_Arg(scope.WithOpName("u_0_arg"), DT_RESOURCE, 4); - auto arg5 = ops::_Arg(scope.WithOpName("v_0_arg"), DT_RESOURCE, 5); - auto arg6 = ops::_Arg(scope.WithOpName("w_0_arg"), DT_RESOURCE, 6); - - auto b_identity = ops::Identity(scope.WithOpName("B_identity"), arg1); - auto read_u = ops::ReadVariableOp(scope.WithOpName("ReadU"), arg4, DT_FLOAT); - auto read_v = ops::ReadVariableOp(scope.WithOpName("ReadV"), arg5, DT_FLOAT); - auto read_w = ops::ReadVariableOp(scope.WithOpName("ReadW"), arg6, DT_FLOAT); - - auto e = ops::Add(scope.WithOpName("E"), arg0, arg2); - auto f = ops::Add(scope.WithOpName("F"), read_v, read_w); - auto g = ops::Add(scope.WithOpName("G"), f, arg3); - - auto out0 = ops::_Retval(scope.WithOpName("b_identity_0_retval_RetVal"), - b_identity, 0); - auto out1 = ops::_Retval(scope.WithOpName("e_0_retval_RetVal"), e, 1); - auto out2 = ops::_Retval(scope.WithOpName("g_0_retval_RetVal"), g, 2); - auto out3 = - ops::_Retval(scope.WithOpName("readu_0_retval_RetVal"), read_u, 3); - - std::unique_ptr graph(new Graph(OpRegistry::Global())); - TF_CHECK_OK(scope.ToGraph(graph.get())); - return graph; -} - -TEST(BuildXlaOpsTest, BuildXlaLaunchOp) { - std::unique_ptr body_graph = MakeBodyGraph(); - FunctionDefLibrary flib; - TF_ASSERT_OK( - GraphToFunctionDef(*body_graph, "xla_func_0", flib.add_function())); - - FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib); - - std::unique_ptr graph = MakeOuterGraph(flib_def, "xla_func_0"); - TF_ASSERT_OK(BuildXlaLaunchOps(graph.get())); - - Scope scope = Scope::DisabledShapeInferenceScope().ExitOnError(); - TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(flib)); - - auto a = ops::Placeholder(scope.WithOpName("A"), DT_INT32); - auto b = ops::Placeholder(scope.WithOpName("B"), DT_FLOAT); - auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32); - auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT); - auto u = ops::Placeholder(scope.WithOpName("U"), DT_RESOURCE); - auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE); - auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE); - - NameAttrList function; - function.set_name("xla_func_0"); - auto launch = ops::XlaLaunch( - scope.WithOpName("xla_call_0").WithDevice("/gpu:0"), - std::initializer_list{}, std::initializer_list{a, b, c, d}, - std::initializer_list{u, v, w}, - DataTypeVector{DT_FLOAT, DT_INT32, DT_FLOAT, DT_FLOAT}, function); - - auto consumer0_a = - ops::Identity(scope.WithOpName("consumer0_a"), launch.results[0]); - auto consumer0_b = - ops::Identity(scope.WithOpName("consumer0_b"), launch.results[0]); - auto consumer0_c = - ops::Identity(scope.WithOpName("consumer0_c"), launch.results[0]); - auto consumer1 = - ops::Identity(scope.WithOpName("consumer1"), launch.results[1]); - auto consumer2 = - ops::Identity(scope.WithOpName("consumer2"), launch.results[2]); - auto consumer3 = - ops::Identity(scope.WithOpName("consumer3"), launch.results[3]); - - GraphDef expected_def; - TF_ASSERT_OK(scope.ToGraphDef(&expected_def)); - - GraphDef actual_def; - graph->ToGraphDef(&actual_def); - TF_EXPECT_GRAPH_EQ(expected_def, actual_def); -} - } // namespace } // namespace tfrt_stub } // namespace tensorflow From e468d20854020f3a15c81771e29732e05fff9af3 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Thu, 13 Jul 2023 17:22:38 -0700 Subject: [PATCH 296/376] [xla:gpu] Instantiate all cuda graphs ahead of time PiperOrigin-RevId: 547972630 --- tensorflow/compiler/xla/debug_options_flags.cc | 2 +- tensorflow/compiler/xla/service/gpu/runtime/executable.cc | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/xla/debug_options_flags.cc b/tensorflow/compiler/xla/debug_options_flags.cc index eef4cc1ac0a5f0..facc9f0cc482f0 100644 --- a/tensorflow/compiler/xla/debug_options_flags.cc +++ b/tensorflow/compiler/xla/debug_options_flags.cc @@ -105,7 +105,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { // TODO(b/258036887): Enable cuda_graph_level=2. Currently blocked by CUDA 12 // integration. opts.set_xla_gpu_cuda_graph_level(1); - opts.set_xla_gpu_cuda_graph_num_runs_to_instantiate(2); + opts.set_xla_gpu_cuda_graph_num_runs_to_instantiate(-1); opts.set_xla_gpu_enable_persistent_temp_buffers(false); opts.set_xla_gpu_cuda_graph_min_graph_size(5); opts.set_xla_gpu_cuda_graph_enable_concurrent_region(false); diff --git a/tensorflow/compiler/xla/service/gpu/runtime/executable.cc b/tensorflow/compiler/xla/service/gpu/runtime/executable.cc index 49fe85dafc9e5f..96b7309136645b 100644 --- a/tensorflow/compiler/xla/service/gpu/runtime/executable.cc +++ b/tensorflow/compiler/xla/service/gpu/runtime/executable.cc @@ -428,7 +428,10 @@ Status GpuRuntimeExecutable::Execute( for (unsigned i = 0; i < buffer_allocations.size(); ++i) { auto mem = buffer_allocations.GetDeviceAddress(i); - if (mem.size() > device_ptr_size) device_ptr = mem.opaque(); + if (mem.size() > device_ptr_size) { + device_ptr = mem.opaque(); + device_ptr_size = mem.size(); + } } if (auto instantiated = graph_instances_.InstantiateAllGraphs( From be1c8f6b19948645e886fad182ae4a8d8dfd2a85 Mon Sep 17 00:00:00 2001 From: Rahul Joshi Date: Thu, 13 Jul 2023 18:14:08 -0700 Subject: [PATCH 297/376] [NFC] Cleanup unused headers in HLO Rematerialization PiperOrigin-RevId: 547981554 --- tensorflow/compiler/xla/service/BUILD | 20 +++---------------- .../xla/service/hlo_memory_scheduler.h | 8 ++------ .../xla/service/hlo_rematerialization.cc | 12 +++-------- .../xla/service/hlo_rematerialization.h | 2 +- .../xla/service/hlo_rematerialization_test.cc | 4 +--- .../hlo_rematerialization_test_utils.h | 5 ----- .../hlo_rematerialization_test_utils_test.cc | 8 -------- 7 files changed, 10 insertions(+), 49 deletions(-) diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 100727bdd53f42..36924224482207 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -1803,7 +1803,6 @@ cc_library( deps = [ ":heap_simulator", ":hlo_alias_analysis", - ":hlo_ordering", ":hlo_pass", ":logical_buffer", ":tuple_points_to_analysis", @@ -4605,22 +4604,19 @@ cc_library( srcs = ["hlo_rematerialization.cc"], hdrs = ["hlo_rematerialization.h"], deps = [ - ":buffer_value", ":call_graph", - ":flatten_call_graph", + ":hlo_dataflow_analysis", ":hlo_dce", - ":hlo_memory_scheduler", ":hlo_ordering", + ":hlo_pass", ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/hlo/ir:hlo", "//tensorflow/compiler/xla/hlo/utils:hlo_query", - "//tensorflow/tsl/platform:logging", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -4636,16 +4632,11 @@ cc_library( testonly = 1, hdrs = ["hlo_rematerialization_test_utils.h"], deps = [ - ":hlo_ordering", - ":hlo_rematerialization", "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/hlo/ir:hlo", - "//tensorflow/compiler/xla/hlo/utils:hlo_matchers", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/tsl/lib/core:status_test_util", ], ) @@ -4653,16 +4644,10 @@ xla_cc_test( name = "hlo_rematerialization_test_utils_test", srcs = ["hlo_rematerialization_test_utils_test.cc"], deps = [ - ":hlo_ordering", ":hlo_rematerialization_test_utils", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/hlo/ir:hlo", - "//tensorflow/compiler/xla/hlo/utils:hlo_matchers", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/tsl/lib/core:status_test_util", ], ) @@ -4670,6 +4655,7 @@ xla_cc_test( name = "hlo_rematerialization_test", srcs = ["hlo_rematerialization_test.cc"], deps = [ + ":hlo_memory_scheduler", ":hlo_ordering", ":hlo_rematerialization", ":hlo_rematerialization_test_utils", diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler.h b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h index a345920fd1ca33..eabf659879ba87 100644 --- a/tensorflow/compiler/xla/service/hlo_memory_scheduler.h +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h @@ -16,19 +16,15 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MEMORY_SCHEDULER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MEMORY_SCHEDULER_H_ -#include - #include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_module.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_schedule.h" #include "tensorflow/compiler/xla/service/hlo_alias_analysis.h" -#include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/types.h" namespace xla { @@ -142,8 +138,8 @@ class HloMemoryScheduler : public HloModulePass { // size_function is the function returning the number of bytes required for a // LogicalBuffer. algorithm is the memory scheduling algorithm to use. If not // specified, then DefaultMemoryScheduler is used. - HloMemoryScheduler(const LogicalBuffer::SizeFunction& size_function, - const ModuleSchedulerAlgorithm& algorithm = {}); + explicit HloMemoryScheduler(const LogicalBuffer::SizeFunction& size_function, + const ModuleSchedulerAlgorithm& algorithm = {}); ~HloMemoryScheduler() override = default; diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index b2f517962c6e75..d89df1a4419f3f 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -40,18 +40,12 @@ limitations under the License. #include "tensorflow/compiler/xla/hlo/ir/hlo_schedule.h" #include "tensorflow/compiler/xla/hlo/utils/hlo_query.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/primitive_util.h" -#include "tensorflow/compiler/xla/service/buffer_value.h" -#include "tensorflow/compiler/xla/service/flatten_call_graph.h" +#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" -#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" -#include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/tsl/platform/logging.h" namespace xla { namespace { @@ -586,7 +580,7 @@ class MemoryUsageTracker { bool HasUnplacedUsers(Item* item) const; // Returns the list of uses for a specific 'item'. - const UsesList GetItemUses(Item* item) const; + UsesList GetItemUses(Item* item) const; // Returns whether 'item' is currently in progress. bool IsInProgressItem(Item* item) const { return item == in_progress_item_; } @@ -1532,7 +1526,7 @@ bool MemoryUsageTracker::HasUnplacedUsers(Item* item) const { return false; } -const UsesList MemoryUsageTracker::GetItemUses(Item* item) const { +UsesList MemoryUsageTracker::GetItemUses(Item* item) const { UsesList combined_users; for (BufferId buffer_id : item->buffers_defined) { const Buffer& buffer = buffers_.at(buffer_id); diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h index ab85c9764cb5d8..611740c56e83b8 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.h +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h @@ -25,7 +25,7 @@ #include "tensorflow/compiler/xla/hlo/ir/hlo_module.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_schedule.h" #include "tensorflow/compiler/xla/service/call_graph.h" -#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/statusor.h" diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc index 84f470c5da87ab..8c61c6cdd4a55f 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc @@ -23,11 +23,9 @@ limitations under the License. #include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_opcode.h" #include "tensorflow/compiler/xla/hlo/utils/hlo_matchers.h" -#include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" #include "tensorflow/compiler/xla/service/hlo_rematerialization_test_utils.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/types.h" #include "tensorflow/tsl/lib/core/status_test_util.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test_utils.h b/tensorflow/compiler/xla/service/hlo_rematerialization_test_utils.h index aac74366baa7e2..88f637249fa3f8 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test_utils.h +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test_utils.h @@ -24,14 +24,9 @@ limitations under the License. #include "tensorflow/compiler/xla/hlo/ir/hlo_computation.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_opcode.h" -#include "tensorflow/compiler/xla/hlo/utils/hlo_matchers.h" -#include "tensorflow/compiler/xla/service/hlo_ordering.h" -#include "tensorflow/compiler/xla/service/hlo_rematerialization.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/tsl/lib/core/status_test_util.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test_utils_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test_utils_test.cc index b3dc0861f421c1..a448e74b8c9e36 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test_utils_test.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test_utils_test.cc @@ -19,14 +19,6 @@ limitations under the License. #include "tensorflow/compiler/xla/hlo/ir/hlo_computation.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" -#include "tensorflow/compiler/xla/hlo/ir/hlo_opcode.h" -#include "tensorflow/compiler/xla/hlo/utils/hlo_matchers.h" -#include "tensorflow/compiler/xla/iterator_util.h" -#include "tensorflow/compiler/xla/service/hlo_ordering.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/tsl/lib/core/status_test_util.h" namespace xla { namespace { From 910241807f43979b3c1da1dc183f2bb50f561453 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 13 Jul 2023 19:59:46 -0700 Subject: [PATCH 298/376] Pass criticality and queue option for low priority to support priority queue in shared hatch scheduler PiperOrigin-RevId: 547997254 --- tensorflow/core/kernels/batch_kernels.cc | 32 +++++++++++- tensorflow/core/kernels/batching_util/BUILD | 1 + .../batching_util/batch_resource_base.cc | 51 +++++++++++++++++++ .../batching_util/batch_resource_base.h | 13 +++++ .../batching_util/shared_batch_scheduler.h | 29 +++++++++++ 5 files changed, 124 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/kernels/batch_kernels.cc b/tensorflow/core/kernels/batch_kernels.cc index 4657f2c18d3bec..fc98d646c003dd 100644 --- a/tensorflow/core/kernels/batch_kernels.cc +++ b/tensorflow/core/kernels/batch_kernels.cc @@ -157,6 +157,27 @@ class BatchResource : public serving::BatchResourceBase { const std::vector& allowed_batch_sizes, bool enable_large_batch_splitting, std::unique_ptr* resource) { + return Create(has_process_batch_function, num_batch_threads, + max_execution_batch_size, batch_timeout_micros, + max_enqueued_batches, allowed_batch_sizes, + /*low_priority_max_batch_size=*/0, + /*low_priority_batch_timeout_micros=*/0, + /*low_priority_max_enqueued_batches=*/0, + /*low_priority_allowed_batch_sizes=*/{}, + enable_large_batch_splitting, resource); + } + + static Status Create( + bool has_process_batch_function, int32_t num_batch_threads, + int32_t max_execution_batch_size, int32_t batch_timeout_micros, + int32_t max_enqueued_batches, + const std::vector& allowed_batch_sizes, + int32_t low_priority_max_batch_size, + int32_t low_priority_batch_timeout_micros, + int32_t low_priority_max_enqueued_batches, + const std::vector& low_priority_allowed_batch_sizes, + bool enable_large_batch_splitting, + std::unique_ptr* resource) { BatcherT::Options batcher_options; batcher_options.num_batch_threads = num_batch_threads; std::shared_ptr batcher; @@ -167,7 +188,11 @@ class BatchResource : public serving::BatchResourceBase { GetBatcherQueueOptions( num_batch_threads, max_execution_batch_size, batch_timeout_micros, max_enqueued_batches, allowed_batch_sizes, - enable_large_batch_splitting, /*disable_padding=*/false), + enable_large_batch_splitting, + /*disable_padding=*/false, low_priority_max_batch_size, + low_priority_batch_timeout_micros, + low_priority_max_enqueued_batches, + low_priority_allowed_batch_sizes), allowed_batch_sizes)); return OkStatus(); } @@ -393,7 +418,10 @@ void BatchFunctionKernel::ComputeAsync(OpKernelContext* c, DoneCallback done) { TF_RETURN_IF_ERROR(BatchResource::Create( /*has_process_batch_function=*/true, num_batch_threads_, max_batch_size_, batch_timeout_micros_, max_enqueued_batches_, - allowed_batch_sizes_, enable_large_batch_splitting_, &new_resource)); + allowed_batch_sizes_, low_priority_max_batch_size_, + low_priority_batch_timeout_micros_, + low_priority_max_enqueued_batches_, low_priority_allowed_batch_sizes_, + enable_large_batch_splitting_, &new_resource)); if (session_metadata) { new_resource->set_session_metadata(*session_metadata); } diff --git a/tensorflow/core/kernels/batching_util/BUILD b/tensorflow/core/kernels/batching_util/BUILD index b7e6bfd2f6cf45..ad431a0956ec90 100644 --- a/tensorflow/core/kernels/batching_util/BUILD +++ b/tensorflow/core/kernels/batching_util/BUILD @@ -363,6 +363,7 @@ cc_library( "//tensorflow/core/profiler/lib:traceme_encode", "//tensorflow/core/protobuf:for_core_protos_cc", "//tensorflow/core/util:incremental_barrier", + "//tensorflow/tsl/platform:criticality", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", "@com_google_absl//absl/types:optional", diff --git a/tensorflow/core/kernels/batching_util/batch_resource_base.cc b/tensorflow/core/kernels/batching_util/batch_resource_base.cc index fcdf0e46fc8a6a..4f2c00811a7732 100644 --- a/tensorflow/core/kernels/batching_util/batch_resource_base.cc +++ b/tensorflow/core/kernels/batching_util/batch_resource_base.cc @@ -272,6 +272,11 @@ Status BatchResourceBase::RegisterInput( batch_components->start_time = EnvTime::NowNanos(); batch_components->guid = guid; batch_components->propagated_context = Context(ContextKind::kThread); + + if (batcher_queue_options_.enable_priority_queue) { + batch_components->criticality = tsl::criticality::GetCriticality(); + } + OpInputList tensors; TF_RETURN_IF_ERROR(context->input_list("in_tensors", &tensors)); batch_components->inputs.reserve(tensors.size()); @@ -381,10 +386,44 @@ BatchResourceBase::GetBatcherQueueOptions( int32_t batch_timeout_micros, int32_t max_enqueued_batches, const std::vector& allowed_batch_sizes, bool enable_large_batch_splitting, bool disable_padding) { + return GetBatcherQueueOptions( + num_batch_threads, max_batch_size, batch_timeout_micros, + max_enqueued_batches, allowed_batch_sizes, enable_large_batch_splitting, + disable_padding, /*low_priority_max_batch_size=*/0, + /*low_priority_batch_timeout_micros=*/0, + /*low_priority_max_enqueued_batches=*/0, + /*low_priority_allowed_batch_sizes=*/{}); +} + +/*static*/ BatchResourceBase::BatcherT::QueueOptions +BatchResourceBase::GetBatcherQueueOptions( + int32_t num_batch_threads, int32_t max_batch_size, + int32_t batch_timeout_micros, int32_t max_enqueued_batches, + const std::vector& allowed_batch_sizes, + bool enable_large_batch_splitting, bool disable_padding, + int32_t low_priority_max_batch_size, + int32_t low_priority_batch_timeout_micros, + int32_t low_priority_max_enqueued_batches, + const std::vector& low_priority_allowed_batch_sizes) { BatcherT::QueueOptions batcher_queue_options; batcher_queue_options.input_batch_size_limit = max_batch_size; batcher_queue_options.max_enqueued_batches = max_enqueued_batches; batcher_queue_options.batch_timeout_micros = batch_timeout_micros; + if (low_priority_max_batch_size > 0) { + batcher_queue_options.enable_priority_queue = true; + } + batcher_queue_options.high_priority_queue_options.input_batch_size_limit = + max_batch_size; + batcher_queue_options.high_priority_queue_options.max_enqueued_batches = + max_enqueued_batches; + batcher_queue_options.high_priority_queue_options.batch_timeout_micros = + batch_timeout_micros; + batcher_queue_options.low_priority_queue_options.input_batch_size_limit = + low_priority_max_batch_size; + batcher_queue_options.low_priority_queue_options.max_enqueued_batches = + low_priority_max_enqueued_batches; + batcher_queue_options.low_priority_queue_options.batch_timeout_micros = + low_priority_batch_timeout_micros; batcher_queue_options.enable_large_batch_splitting = enable_large_batch_splitting; if (enable_large_batch_splitting) { @@ -398,9 +437,21 @@ BatchResourceBase::GetBatcherQueueOptions( if (allowed_batch_sizes.empty()) { batcher_queue_options.max_execution_batch_size = max_batch_size; + batcher_queue_options.high_priority_queue_options + .max_execution_batch_size = max_batch_size; } else { batcher_queue_options.max_execution_batch_size = *allowed_batch_sizes.rbegin(); + batcher_queue_options.high_priority_queue_options + .max_execution_batch_size = *allowed_batch_sizes.rbegin(); + } + if (low_priority_allowed_batch_sizes.empty()) { + batcher_queue_options.low_priority_queue_options + .max_execution_batch_size = low_priority_max_batch_size; + } else { + batcher_queue_options.low_priority_queue_options + .max_execution_batch_size = + *low_priority_allowed_batch_sizes.rbegin(); } } batcher_queue_options.disable_padding = disable_padding; diff --git a/tensorflow/core/kernels/batching_util/batch_resource_base.h b/tensorflow/core/kernels/batching_util/batch_resource_base.h index 446186472f5a62..a2fbfbcb7d9756 100644 --- a/tensorflow/core/kernels/batching_util/batch_resource_base.h +++ b/tensorflow/core/kernels/batching_util/batch_resource_base.h @@ -35,6 +35,7 @@ limitations under the License. #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/tsl/platform/criticality.h" namespace tensorflow { namespace serving { @@ -107,6 +108,8 @@ class BatchResourceBase : public ResourceBase { // this task's processing costs. RequestCost* request_cost = nullptr; + tsl::criticality::Criticality criticality; + protected: virtual std::unique_ptr CreateDerivedTask() { return std::make_unique(); @@ -166,6 +169,16 @@ class BatchResourceBase : public ResourceBase { const std::vector& allowed_batch_sizes, bool enable_large_batch_splitting, bool disable_padding); + static BatcherT::QueueOptions GetBatcherQueueOptions( + int32_t num_batch_threads, int32_t max_batch_size, + int32_t batch_timeout_micros, int32_t max_enqueued_batches, + const std::vector& allowed_batch_sizes, + bool enable_large_batch_splitting, bool disable_padding, + int32_t low_priority_max_batch_size, + int32_t low_priority_batch_timeout_micros, + int32_t low_priority_max_enqueued_batches, + const std::vector& low_priority_allowed_batch_sizes); + static AdaptiveBatcherT::QueueOptions GetAdaptiveBatcherQueueOptions( int32_t max_batch_size, int32_t batch_timeout_micros, int32_t max_enqueued_batches, bool enable_large_batch_splitting, diff --git a/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h b/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h index 0e5c0b2f210709..df7866c5f3c473 100644 --- a/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h +++ b/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h @@ -221,6 +221,27 @@ class SharedBatchScheduler // If true, the padding will not be appended. bool disable_padding = false; + + // If true, queue implementation would split high priority and low priority + // inputs into two sub queues. + bool enable_priority_queue = false; + + // A separate set of queue options for different priority inputs. + // Use iff `enable_priority_queue` is true. + struct PriorityQueueOptions { + // See QueueOptions.max_execution_batch_size + size_t max_execution_batch_size = 0; + // See QueueOptions.batch_timeout_micros + int64_t batch_timeout_micros = 0; + // See QueueOptions.input_batch_size_limit + size_t input_batch_size_limit = 0; + // See QueueOptions.max_enqueued_batches + size_t max_enqueued_batches = 0; + }; + // A subset of queue options for high priority input. + PriorityQueueOptions high_priority_queue_options; + // A subset of queue options for low priority input. + PriorityQueueOptions low_priority_queue_options; }; Status AddQueue(const QueueOptions& options, std::function>)> @@ -465,6 +486,14 @@ class Queue { std::deque>>> task_handle_batches_ TF_GUARDED_BY(mu_); + // The enqueued batches for low priority input + std::deque>> low_priority_batches_ + TF_GUARDED_BY(mu_); + + // The enqueued batches for high priority input + std::deque>> high_priority_batches_ + TF_GUARDED_BY(mu_); + // The counter of the TraceMe context ids. uint64 traceme_context_id_counter_ TF_GUARDED_BY(mu_) = 0; From 4499c968316453d6abacd7b1dde77e1fa97c4499 Mon Sep 17 00:00:00 2001 From: Marcello Maggioni Date: Thu, 13 Jul 2023 20:42:49 -0700 Subject: [PATCH 299/376] [XLA] Fix masking for pad uneven sharding. We were looking at the wrong shape for the condition. PiperOrigin-RevId: 548005156 --- .../xla/service/spmd/spmd_partitioner.cc | 9 ++++----- .../xla/service/spmd/spmd_partitioner_test.cc | 19 +++++++++++++++++++ .../xla/service/spmd/spmd_partitioner_util.cc | 13 ++++++------- .../xla/service/spmd/spmd_partitioner_util.h | 3 +-- 4 files changed, 30 insertions(+), 14 deletions(-) diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc index b6cfb614d66849..dff4a2e5188be1 100644 --- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc @@ -1666,8 +1666,8 @@ PartitionedHlo PartitionedHlo::ReshardWithAllToAll( HloInstruction* zero = CreateZero( ShapeUtil::MakeShape(hlo_->shape().element_type(), {}), state_.b); HloSharding sharding_copy = sharding(); - auto padded_phlo = ReshardDataForPad(zero, pc, p_hlo, padded_base_shape, - sharding_copy, state_.b); + auto padded_phlo = + ReshardDataForPad(zero, pc, p_hlo, sharding_copy, state_.b); CHECK(padded_phlo.has_value()); VLOG(5) << "Resharded: " << padded_phlo->sharded_input->ToString(); VLOG(5) << "Padded Window: " << padded_phlo->shard_window.DebugString(); @@ -3870,9 +3870,8 @@ Status SpmdPartitioningVisitor::HandlePad(HloInstruction* hlo) { auto replicated_rhs = GetPartitionedHlo(hlo->operand(1)) .Reshard(HloSharding::Replicate()) .hlo(); - auto reshard_operand = - ReshardDataForPad(replicated_rhs, hlo->padding_config(), lhs, - hlo->shape(), hlo->sharding(), &b_); + auto reshard_operand = ReshardDataForPad( + replicated_rhs, hlo->padding_config(), lhs, hlo->sharding(), &b_); if (!reshard_operand.has_value()) { return DefaultAction(hlo); } diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc index 3fa98496e3091c..5a6b1dee645e4b 100644 --- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc @@ -13539,6 +13539,25 @@ ENTRY %entry (p0: f32[8], p1: f32[1]) -> (f32[1], token[]) { EXPECT_THAT(outfeed->operand(0), op::Shape("(u32[2]{0})")); } +TEST_F(SpmdPartitioningTest, PadUneven) { + absl::string_view hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[128,13,257] parameter(0), sharding={devices=[1,2,1]0,1} + %const = f32[] constant(0) + ROOT %pad = f32[128,14,257] pad(%param0, %const), padding=0_0x0_1x0_0, + sharding={devices=[1,2,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + const auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Select(), op::Shape("f32[128,7,257]"))); +} + } // namespace } // namespace spmd } // namespace xla diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc index 740bc394e2fd7a..8a98ec6d76f138 100644 --- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc @@ -2350,8 +2350,7 @@ HloInstruction* SliceDataFromWindowReshard( std::optional ReshardDataForPad( HloInstruction* pad_value, PaddingConfig pc, PartitionedHlo to_reshard, - const Shape& target_shape, const HloSharding& target_sharding, - SpmdBuilder* b) { + const HloSharding& target_sharding, SpmdBuilder* b) { // Create a window config to represent the pad. Window window; bool needs_masking = false; @@ -2371,11 +2370,11 @@ std::optional ReshardDataForPad( // Need masking only if there is non-zero padding value or the operand is // unevenly partitioned. Halo exchange fills 0 in collective permute result // for non-destination cores. - needs_masking |= - shard_count > 1 && - (pd.edge_padding_low() > 0 || pd.edge_padding_high() > 0 || - pd.interior_padding() > 0) && - (!pad_value_is_zero || target_shape.dimensions(i) % shard_count != 0); + needs_masking |= shard_count > 1 && + (pd.edge_padding_low() > 0 || pd.edge_padding_high() > 0 || + pd.interior_padding() > 0) && + (!pad_value_is_zero || + to_reshard.base_shape().dimensions(i) % shard_count != 0); } // In compact halo exchange, we can't skip masking. return to_reshard.ReshardAsWindowedInput( diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h index 57413f865693f0..51448b8f2f2036 100644 --- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h @@ -555,8 +555,7 @@ HloInstruction* SliceDataFromWindowReshard( // parameters. std::optional ReshardDataForPad( HloInstruction* pad_value, PaddingConfig pc, PartitionedHlo to_reshard, - const Shape& target_shape, const HloSharding& target_sharding, - SpmdBuilder* b); + const HloSharding& target_sharding, SpmdBuilder* b); // Performs padding of data based on the windowed sharding passed as input. HloInstruction* PadDataFromWindowReshard( From cd35adffd39a87cc35ac9e29beeaf6efd08bb51a Mon Sep 17 00:00:00 2001 From: Hyeontaek Lim Date: Thu, 13 Jul 2023 22:24:31 -0700 Subject: [PATCH 300/376] [IFRT] Roll forward with fix: Add serialization/deserialization for shardings This change adds serialization/deserialization for the following IFRT sharding types: * `SingleDeviceSharding` * `OpaqueSharding` * `ConcreteSharding` * `ConcreteEvenSharding` * `HloSharding` `ShardingParamSharding` serialization/deserialization is not yet supported. Fix: Sharding serialization and deserialization are seperate library targets that are always linked. This makes its registration mechanism to always run even when the whole binary target is built using full static linking. PiperOrigin-RevId: 548020248 --- tensorflow/compiler/xla/python/ifrt/BUILD | 42 +++ tensorflow/compiler/xla/python/ifrt/device.cc | 24 ++ tensorflow/compiler/xla/python/ifrt/device.h | 11 + tensorflow/compiler/xla/python/ifrt/shape.cc | 25 ++ tensorflow/compiler/xla/python/ifrt/shape.h | 8 + .../compiler/xla/python/ifrt/sharding.cc | 26 +- .../compiler/xla/python/ifrt/sharding.h | 15 +- .../compiler/xla/python/ifrt/sharding.proto | 46 ++++ .../xla/python/ifrt/sharding_serdes.cc | 240 ++++++++++++++++++ .../xla/python/ifrt/sharding_serdes.h | 48 ++++ .../xla/python/ifrt/sharding_serdes_test.cc | 157 ++++++++++++ .../compiler/xla/python/ifrt/types.proto | 31 +++ .../compiler/xla/python/pjrt_ifrt/BUILD | 36 +++ .../xla/python/pjrt_ifrt/xla_sharding.proto | 27 ++ .../python/pjrt_ifrt/xla_sharding_serdes.cc | 79 ++++++ .../pjrt_ifrt/xla_sharding_serdes_test.cc | 95 +++++++ 16 files changed, 891 insertions(+), 19 deletions(-) create mode 100644 tensorflow/compiler/xla/python/ifrt/sharding.proto create mode 100644 tensorflow/compiler/xla/python/ifrt/sharding_serdes.cc create mode 100644 tensorflow/compiler/xla/python/ifrt/sharding_serdes.h create mode 100644 tensorflow/compiler/xla/python/ifrt/sharding_serdes_test.cc create mode 100644 tensorflow/compiler/xla/python/ifrt/types.proto create mode 100644 tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding.proto create mode 100644 tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_serdes.cc create mode 100644 tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_serdes_test.cc diff --git a/tensorflow/compiler/xla/python/ifrt/BUILD b/tensorflow/compiler/xla/python/ifrt/BUILD index ec8774cd5ef187..a16497218b4744 100644 --- a/tensorflow/compiler/xla/python/ifrt/BUILD +++ b/tensorflow/compiler/xla/python/ifrt/BUILD @@ -66,12 +66,15 @@ cc_library( ], deps = [ ":serdes", + ":sharding_proto_cc", + ":types_proto_cc", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/pjrt:pjrt_client", "//tensorflow/compiler/xla/python/ifrt/ir", "//tensorflow/tsl/platform:logging", + "//tensorflow/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", @@ -309,3 +312,42 @@ tf_proto_library( name = "serdes_proto", srcs = ["serdes.proto"], ) + +cc_library( + name = "sharding_serdes", + srcs = ["sharding_serdes.cc"], + hdrs = ["sharding_serdes.h"], + deps = [ + ":ifrt", + ":serdes", + ":sharding_proto_cc", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/tsl/platform:statusor", + "@llvm-project//llvm:Support", + ], + alwayslink = 1, +) + +xla_cc_test( + name = "sharding_serdes_test", + srcs = ["sharding_serdes_test.cc"], + deps = [ + ":ifrt", + ":mock", + ":serdes", + ":sharding_serdes", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_googletest//:gtest_main", + ], +) + +tf_proto_library( + name = "types_proto", + srcs = ["types.proto"], +) + +tf_proto_library( + name = "sharding_proto", + srcs = ["sharding.proto"], + protodeps = [":types_proto"], +) diff --git a/tensorflow/compiler/xla/python/ifrt/device.cc b/tensorflow/compiler/xla/python/ifrt/device.cc index 0f02149ae48a64..a549a811de6de6 100644 --- a/tensorflow/compiler/xla/python/ifrt/device.cc +++ b/tensorflow/compiler/xla/python/ifrt/device.cc @@ -15,11 +15,35 @@ limitations under the License. #include "tensorflow/compiler/xla/python/ifrt/device.h" +#include #include +#include "tensorflow/compiler/xla/python/ifrt/client.h" +#include "tensorflow/compiler/xla/python/ifrt/types.pb.h" + namespace xla { namespace ifrt { +StatusOr DeviceList::FromProto(Client* client, + const DeviceListProto& proto) { + DeviceList::Devices devices; + devices.reserve(proto.device_ids_size()); + for (int device_id : proto.device_ids()) { + TF_ASSIGN_OR_RETURN(Device * device, client->LookupDevice(device_id)); + devices.push_back(device); + } + return DeviceList(std::move(devices)); +} + +DeviceListProto DeviceList::ToProto() const { + DeviceListProto proto; + proto.mutable_device_ids()->Reserve(devices().size()); + for (Device* device : devices()) { + proto.mutable_device_ids()->AddAlreadyReserved(device->id()); + } + return proto; +} + std::vector GetDeviceIds(DeviceList device_list) { std::vector ids; ids.reserve(device_list.devices().size()); diff --git a/tensorflow/compiler/xla/python/ifrt/device.h b/tensorflow/compiler/xla/python/ifrt/device.h index a2d5f61dd35c2a..d54afa190deaa9 100644 --- a/tensorflow/compiler/xla/python/ifrt/device.h +++ b/tensorflow/compiler/xla/python/ifrt/device.h @@ -21,10 +21,13 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "tensorflow/compiler/xla/pjrt/pjrt_client.h" +#include "tensorflow/compiler/xla/python/ifrt/types.pb.h" namespace xla { namespace ifrt { +class Client; + // Short-term alias to reuse `xla::PjRtDevice` without a separate abstract type. using Device = ::xla::PjRtDevice; @@ -42,6 +45,14 @@ class DeviceList { explicit DeviceList(Devices devices) : devices_(std::move(devices)) {} + // Constructs `DeviceList` from `DeviceListProto`. Device ids in the proto + // must be consistent with the devices owned by `client'. + static StatusOr FromProto(Client* client, + const DeviceListProto& proto); + + // Returns a `DeviceListProto` representation. + DeviceListProto ToProto() const; + absl::Span devices() const { return devices_; } int size() const { return devices_.size(); } diff --git a/tensorflow/compiler/xla/python/ifrt/shape.cc b/tensorflow/compiler/xla/python/ifrt/shape.cc index bd3ff1fc8e08b6..07e8e2b81494a5 100644 --- a/tensorflow/compiler/xla/python/ifrt/shape.cc +++ b/tensorflow/compiler/xla/python/ifrt/shape.cc @@ -17,12 +17,37 @@ limitations under the License. #include #include +#include #include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/python/ifrt/types.pb.h" +#include "tensorflow/compiler/xla/util.h" namespace xla { namespace ifrt { +StatusOr Shape::FromProto(const ShapeProto& proto) { + Shape::Dimensions dims; + dims.reserve(proto.dims_size()); + for (int64_t dim : proto.dims()) { + if (dim < 0) { + return InvalidArgument( + "Shape expects non-negative dimension sizes, but got %d", dim); + } + dims.push_back(dim); + } + return Shape(std::move(dims)); +} + +ShapeProto Shape::ToProto() const { + ShapeProto proto; + proto.mutable_dims()->Reserve(dims().size()); + for (int64_t dim : dims()) { + proto.mutable_dims()->AddAlreadyReserved(dim); + } + return proto; +} + int64_t Shape::num_elements() const { int64_t count = 1; for (int64_t d : dims_) { diff --git a/tensorflow/compiler/xla/python/ifrt/shape.h b/tensorflow/compiler/xla/python/ifrt/shape.h index 3558e3518ed84d..f3ce028789d5ef 100644 --- a/tensorflow/compiler/xla/python/ifrt/shape.h +++ b/tensorflow/compiler/xla/python/ifrt/shape.h @@ -22,6 +22,8 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "absl/types/span.h" +#include "tensorflow/compiler/xla/python/ifrt/types.pb.h" +#include "tensorflow/compiler/xla/statusor.h" namespace xla { namespace ifrt { @@ -42,6 +44,12 @@ class Shape { Shape& operator=(const Shape&) = default; Shape& operator=(Shape&&) = default; + // Constructs `Shape` from `ShapeProto`. + static StatusOr FromProto(const ShapeProto& proto); + + // Returns a `ShapeProto` representation. + ShapeProto ToProto() const; + absl::Span dims() const { return dims_; } bool operator==(const Shape& other) const { return dims_ == other.dims_; } diff --git a/tensorflow/compiler/xla/python/ifrt/sharding.cc b/tensorflow/compiler/xla/python/ifrt/sharding.cc index f057ad53fccf83..8caaf9f12a83e4 100644 --- a/tensorflow/compiler/xla/python/ifrt/sharding.cc +++ b/tensorflow/compiler/xla/python/ifrt/sharding.cc @@ -159,8 +159,10 @@ std::ostream& operator<<(std::ostream& os, const Sharding& sharding) { return os << sharding.DebugString(); } -std::unique_ptr SingleDeviceSharding::Create(Device* device) { - return std::unique_ptr(new SingleDeviceSharding(device)); +std::unique_ptr SingleDeviceSharding::Create( + Device* device) { + return std::unique_ptr( + new SingleDeviceSharding(device)); } StatusOr>>> @@ -187,8 +189,9 @@ std::string SingleDeviceSharding::DebugString() const { devices_.front()->ToString()); } -std::unique_ptr OpaqueSharding::Create(DeviceList devices) { - return std::unique_ptr(new OpaqueSharding(std::move(devices))); +std::unique_ptr OpaqueSharding::Create(DeviceList devices) { + return std::unique_ptr( + new OpaqueSharding(std::move(devices))); } OpaqueSharding::OpaqueSharding(DeviceList devices) @@ -217,10 +220,10 @@ std::string OpaqueSharding::DebugString() const { })); } -std::unique_ptr ConcreteSharding::Create( +std::unique_ptr ConcreteSharding::Create( DeviceList devices, Shape shape, std::vector shard_shapes) { CHECK_EQ(devices.size(), shard_shapes.size()); - return std::unique_ptr(new ConcreteSharding( + return std::unique_ptr(new ConcreteSharding( std::move(devices), std::move(shape), std::move(shard_shapes))); } @@ -270,10 +273,9 @@ std::string ConcreteSharding::DebugString() const { })); } -std::unique_ptr ConcreteEvenSharding::Create(DeviceList devices, - Shape shape, - Shape shard_shape) { - return std::unique_ptr(new ConcreteEvenSharding( +std::unique_ptr ConcreteEvenSharding::Create( + DeviceList devices, Shape shape, Shape shard_shape) { + return std::unique_ptr(new ConcreteEvenSharding( std::move(devices), std::move(shape), std::move(shard_shape))); } @@ -318,7 +320,7 @@ std::string ConcreteEvenSharding::DebugString() const { shape_.DebugString(), shard_shape_.DebugString()); } -StatusOr> ShardingParamSharding::Create( +StatusOr> ShardingParamSharding::Create( ShardingParam sharding_param, DeviceList devices) { int64_t device_count = absl::c_accumulate(sharding_param.minor_to_major().axis_sizes, 1, @@ -329,7 +331,7 @@ StatusOr> ShardingParamSharding::Create( "%d", device_count, devices.size()); } - return std::unique_ptr( + return std::unique_ptr( new ShardingParamSharding(std::move(sharding_param), std::move(devices))); } diff --git a/tensorflow/compiler/xla/python/ifrt/sharding.h b/tensorflow/compiler/xla/python/ifrt/sharding.h index 375cedc16a0a68..6e3d30e99d2584 100644 --- a/tensorflow/compiler/xla/python/ifrt/sharding.h +++ b/tensorflow/compiler/xla/python/ifrt/sharding.h @@ -83,7 +83,7 @@ class SingleDeviceSharding final : public llvm::RTTIExtends { public: // Creates a single-device sharding. - static std::unique_ptr Create(Device* device); + static std::unique_ptr Create(Device* device); // Sharding implementation. @@ -110,7 +110,7 @@ class SingleDeviceSharding final class OpaqueSharding : public llvm::RTTIExtends { public: // Creates an opaque sharding. `Disassemble()` will fail. - static std::unique_ptr Create(DeviceList devices); + static std::unique_ptr Create(DeviceList devices); // Sharding implementation. @@ -138,8 +138,8 @@ class ConcreteSharding : public llvm::RTTIExtends { public: // Creates a concrete sharding that may contain non-identical shard shapes. // REQUIRES: devices.size() == shard_shapes.size() - static std::unique_ptr Create(DeviceList devices, Shape shape, - std::vector shard_shapes); + static std::unique_ptr Create( + DeviceList devices, Shape shape, std::vector shard_shapes); Shape shape() const { DCHECK(this); @@ -179,8 +179,9 @@ class ConcreteEvenSharding : public llvm::RTTIExtends { public: // Creates a concrete even sharding. - static std::unique_ptr Create(DeviceList devices, Shape shape, - Shape shard_shape); + static std::unique_ptr Create(DeviceList devices, + Shape shape, + Shape shard_shape); Shape shape() const { DCHECK(this); @@ -216,7 +217,7 @@ class ConcreteEvenSharding class ShardingParamSharding : public llvm::RTTIExtends { public: - static StatusOr> Create( + static StatusOr> Create( ShardingParam sharding_param, DeviceList devices); StatusOr>>> diff --git a/tensorflow/compiler/xla/python/ifrt/sharding.proto b/tensorflow/compiler/xla/python/ifrt/sharding.proto new file mode 100644 index 00000000000000..066bce11413998 --- /dev/null +++ b/tensorflow/compiler/xla/python/ifrt/sharding.proto @@ -0,0 +1,46 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package xla.ifrt; + +import "tensorflow/compiler/xla/python/ifrt/types.proto"; + +// Wire format for `SingleDeviceSharding`. +message SingleDeviceShardingProto { + // Serialization and deserialization are expected to ensure that device ids + // are stable across proto construction and consumption. + int32 device_id = 1; +} + +// Wire format for `OpaqueSharding`. +message OpaqueShardingProto { + DeviceListProto devices = 1; +} + +// Wire format for `ConcreteSharding`. +message ConcreteShardingProto { + DeviceListProto devices = 1; + ShapeProto shape = 2; + repeated ShapeProto shard_shapes = 3; +} + +// Wire format for `ConcreteEvenSharding`. +message ConcreteEvenShardingProto { + DeviceListProto devices = 1; + ShapeProto shape = 2; + ShapeProto shard_shape = 3; +} diff --git a/tensorflow/compiler/xla/python/ifrt/sharding_serdes.cc b/tensorflow/compiler/xla/python/ifrt/sharding_serdes.cc new file mode 100644 index 00000000000000..d9ade8d6a62b96 --- /dev/null +++ b/tensorflow/compiler/xla/python/ifrt/sharding_serdes.cc @@ -0,0 +1,240 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/python/ifrt/sharding_serdes.h" + +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/python/ifrt/client.h" +#include "tensorflow/compiler/xla/python/ifrt/device.h" +#include "tensorflow/compiler/xla/python/ifrt/serdes.h" +#include "tensorflow/compiler/xla/python/ifrt/shape.h" +#include "tensorflow/compiler/xla/python/ifrt/sharding.h" +#include "tensorflow/compiler/xla/python/ifrt/sharding.pb.h" +#include "tensorflow/tsl/platform/statusor.h" + +namespace xla { +namespace ifrt { + +char DeserializeShardingOptions::ID = 0; + +namespace { + +// Serialization/deserialization for `SingleDeviceSharding`. +class SingleDeviceShardingSerDes + : public llvm::RTTIExtends { + public: + absl::string_view type_name() const override { + return "xla::ifrt::SingleDeviceSharding"; + } + + absl::StatusOr Serialize(Serializable& serializable) override { + const SingleDeviceSharding& sharding = + llvm::cast(serializable); + SingleDeviceShardingProto proto; + proto.set_device_id(sharding.devices().front()->id()); + return proto.SerializeAsString(); + } + + absl::StatusOr> Deserialize( + const std::string& serialized, + std::unique_ptr options) override { + TF_ASSIGN_OR_RETURN(auto deserialize_sharding_options, + GetDeserializeShardingOptions(std::move(options))); + SingleDeviceShardingProto proto; + if (!proto.ParseFromString(serialized)) { + return absl::InvalidArgumentError( + "Failed to parse serialized SimpleDeviceSharding"); + } + TF_ASSIGN_OR_RETURN( + Device * device, + deserialize_sharding_options->client->LookupDevice(proto.device_id())); + return SingleDeviceSharding::Create(device); + } + + static char ID; // NOLINT +}; + +// Serialization/deserialization for `OpaqueSharding`. +class OpaqueShardingSerDes + : public llvm::RTTIExtends { + public: + absl::string_view type_name() const override { + return "xla::ifrt::OpaqueSharding"; + } + + absl::StatusOr Serialize(Serializable& serializable) override { + const OpaqueSharding& sharding = llvm::cast(serializable); + OpaqueShardingProto proto; + *proto.mutable_devices() = sharding.devices().ToProto(); + return proto.SerializeAsString(); + } + + absl::StatusOr> Deserialize( + const std::string& serialized, + std::unique_ptr options) override { + TF_ASSIGN_OR_RETURN(auto deserialize_sharding_options, + GetDeserializeShardingOptions(std::move(options))); + + OpaqueShardingProto proto; + if (!proto.ParseFromString(serialized)) { + return absl::InvalidArgumentError( + "Failed to parse serialized OpaqueSharding"); + } + TF_ASSIGN_OR_RETURN(auto devices, DeviceList::FromProto( + deserialize_sharding_options->client, + proto.devices())); + return OpaqueSharding::Create(std::move(devices)); + } + + static char ID; // NOLINT +}; + +// Serialization/deserialization for `ConcreteSharding`. +class ConcreteShardingSerDes + : public llvm::RTTIExtends { + public: + absl::string_view type_name() const override { + return "xla::ifrt::ConcreteSharding"; + } + + absl::StatusOr Serialize(Serializable& serializable) override { + const ConcreteSharding& sharding = + llvm::cast(serializable); + ConcreteShardingProto proto; + *proto.mutable_devices() = sharding.devices().ToProto(); + *proto.mutable_shape() = sharding.shape().ToProto(); + for (const Shape& shape : sharding.shard_shapes()) { + *proto.add_shard_shapes() = shape.ToProto(); + } + return proto.SerializeAsString(); + } + + absl::StatusOr> Deserialize( + const std::string& serialized, + std::unique_ptr options) override { + TF_ASSIGN_OR_RETURN(auto deserialize_sharding_options, + GetDeserializeShardingOptions(std::move(options))); + + ConcreteShardingProto proto; + if (!proto.ParseFromString(serialized)) { + return absl::InvalidArgumentError( + "Failed to parse serialized ConcreteSharding"); + } + TF_ASSIGN_OR_RETURN(auto devices, DeviceList::FromProto( + deserialize_sharding_options->client, + proto.devices())); + TF_ASSIGN_OR_RETURN(auto shape, Shape::FromProto(proto.shape())); + std::vector shard_shapes; + shard_shapes.reserve(proto.shard_shapes_size()); + for (const auto& shard_shape_proto : proto.shard_shapes()) { + TF_ASSIGN_OR_RETURN(auto shard_shape, + Shape::FromProto(shard_shape_proto)); + shard_shapes.push_back(std::move(shard_shape)); + } + return ConcreteSharding::Create(std::move(devices), std::move(shape), + std::move(shard_shapes)); + } + + static char ID; // NOLINT +}; + +// Serialization/deserialization for `ConcreteEvenSharding`. +class ConcreteEvenShardingSerDes + : public llvm::RTTIExtends { + public: + absl::string_view type_name() const override { + return "xla::ifrt::ConcreteEvenSharding"; + } + + absl::StatusOr Serialize(Serializable& serializable) override { + const ConcreteEvenSharding& sharding = + llvm::cast(serializable); + ConcreteEvenShardingProto proto; + *proto.mutable_devices() = sharding.devices().ToProto(); + *proto.mutable_shape() = sharding.shape().ToProto(); + *proto.mutable_shard_shape() = sharding.shard_shape().ToProto(); + return proto.SerializeAsString(); + } + + absl::StatusOr> Deserialize( + const std::string& serialized, + std::unique_ptr options) override { + TF_ASSIGN_OR_RETURN(auto deserialize_sharding_options, + GetDeserializeShardingOptions(std::move(options))); + + ConcreteEvenShardingProto proto; + if (!proto.ParseFromString(serialized)) { + return absl::InvalidArgumentError( + "Failed to parse serialized ConcreteEvenSharding"); + } + TF_ASSIGN_OR_RETURN(auto devices, DeviceList::FromProto( + deserialize_sharding_options->client, + proto.devices())); + TF_ASSIGN_OR_RETURN(auto shape, Shape::FromProto(proto.shape())); + TF_ASSIGN_OR_RETURN(auto shard_shape, + Shape::FromProto(proto.shard_shape())); + return ConcreteEvenSharding::Create(std::move(devices), std::move(shape), + std::move(shard_shape)); + } + + static char ID; // NOLINT +}; + +// TODO(hyeontaek): Implement `ShardingParamShardingSerDes`. + +[[maybe_unused]] char SingleDeviceShardingSerDes::ID = 0; // NOLINT +[[maybe_unused]] char OpaqueShardingSerDes::ID = 0; // NOLINT +[[maybe_unused]] char ConcreteShardingSerDes::ID = 0; // NOLINT +[[maybe_unused]] char ConcreteEvenShardingSerDes::ID = 0; // NOLINT + +// clang-format off +bool register_single_device_sharding_serdes = ([]{ + RegisterSerDes( + std::make_unique()); +}(), true); + +bool register_opaque_sharding_serdes = ([]{ + RegisterSerDes( + std::make_unique()); +}(), true); + +bool register_concrete_sharding_serdes = ([]{ + RegisterSerDes( + std::make_unique()); +}(), true); + +bool register_concrete_even_sharding_serdes = ([]{ + RegisterSerDes( + std::make_unique()); +}(), true); +// clang-format on + +} // namespace + +StatusOr> +GetDeserializeShardingOptions(std::unique_ptr options) { + if (!llvm::isa(options.get())) { + return xla::InvalidArgument("options must be DeserializeShardingOptions"); + } + return std::unique_ptr( + static_cast(options.release())); +} + +} // namespace ifrt +} // namespace xla diff --git a/tensorflow/compiler/xla/python/ifrt/sharding_serdes.h b/tensorflow/compiler/xla/python/ifrt/sharding_serdes.h new file mode 100644 index 00000000000000..965670bcbc3401 --- /dev/null +++ b/tensorflow/compiler/xla/python/ifrt/sharding_serdes.h @@ -0,0 +1,48 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_IFRT_SHARDING_SERDES_H_ +#define TENSORFLOW_COMPILER_XLA_PYTHON_IFRT_SHARDING_SERDES_H_ + +#include + +#include "llvm/Support/ExtensibleRTTI.h" +#include "tensorflow/compiler/xla/python/ifrt/serdes.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { +namespace ifrt { + +class Client; + +// Options for deserializing shardings. +struct DeserializeShardingOptions + : llvm::RTTIExtends { + explicit DeserializeShardingOptions(Client* client) : client(client) {} + + static char ID; // NOLINT + + // The client whose devices will be used by deserialized shardings. + Client* client; +}; + +// Casts `DeserializeOptions` into `DeserializeShardingOptions`. +StatusOr> +GetDeserializeShardingOptions(std::unique_ptr options); + +} // namespace ifrt +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_PYTHON_IFRT_SHARDING_SERDES_H_ diff --git a/tensorflow/compiler/xla/python/ifrt/sharding_serdes_test.cc b/tensorflow/compiler/xla/python/ifrt/sharding_serdes_test.cc new file mode 100644 index 00000000000000..90efc6d9667167 --- /dev/null +++ b/tensorflow/compiler/xla/python/ifrt/sharding_serdes_test.cc @@ -0,0 +1,157 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/python/ifrt/sharding_serdes.h" + +#include +#include +#include + +#include +#include +#include "absl/container/flat_hash_map.h" +#include "tensorflow/compiler/xla/python/ifrt/mock.h" +#include "tensorflow/compiler/xla/python/ifrt/serdes.h" +#include "tensorflow/compiler/xla/python/ifrt/sharding.h" + +namespace xla { +namespace ifrt { +namespace { + +using ::testing::ElementsAreArray; + +// Test fixture for sharding serialization and deserialization. It makes a mock +// client with a number of fake devices. Client implements `devices()` and +// `LookupDevice()`, and Device implements `id()`, with an arbitrary device ids +// assigned. +class ShardingSerDesTest : public ::testing::TestWithParam { + public: + void SetUp() override { + const int num_devices = GetParam(); + device_map_.reserve(num_devices); + devices_.reserve(num_devices); + for (int i = 0; i < num_devices; ++i) { + auto device = std::make_unique(); + ON_CALL(*device, id).WillByDefault([i]() { return i + 10; }); + devices_.push_back(device.get()); + device_map_.insert({i + 10, std::move(device)}); + } + client_ = std::make_unique(); + ON_CALL(*client_, devices) + .WillByDefault( + [this]() -> absl::Span { return devices_; }); + ON_CALL(*client_, LookupDevice) + .WillByDefault([this](int device_id) -> StatusOr { + auto it = device_map_.find(device_id); + if (it == device_map_.end()) { + return InvalidArgument("Unexpected device id: %d", device_id); + } + return it->second.get(); + }); + } + Client* client() { return client_.get(); } + + private: + std::unique_ptr client_; + absl::flat_hash_map> device_map_; + std::vector devices_; +}; + +TEST_P(ShardingSerDesTest, SingleDeviceShardingRoundTrip) { + auto sharding = SingleDeviceSharding::Create(client()->devices().front()); + + TF_ASSERT_OK_AND_ASSIGN(auto serialized, Serialize(*sharding)); + + auto deserialized_options = + std::make_unique(client()); + TF_ASSERT_OK_AND_ASSIGN( + auto deserialized, + Deserialize(serialized, std::move(deserialized_options))); + + const auto* out_sharding = + llvm::dyn_cast(deserialized.get()); + ASSERT_NE(out_sharding, nullptr); + EXPECT_THAT(out_sharding->devices(), ElementsAreArray(sharding->devices())); +} + +TEST_P(ShardingSerDesTest, OpaqueShardingRoundTrip) { + auto sharding = OpaqueSharding::Create(DeviceList(DeviceList::Devices( + client()->devices().begin(), client()->devices().end()))); + + TF_ASSERT_OK_AND_ASSIGN(auto serialized, Serialize(*sharding)); + + auto deserialized_options = + std::make_unique(client()); + TF_ASSERT_OK_AND_ASSIGN( + auto deserialized, + Deserialize(serialized, std::move(deserialized_options))); + + const auto* out_sharding = llvm::dyn_cast(deserialized.get()); + ASSERT_NE(out_sharding, nullptr); + EXPECT_THAT(out_sharding->devices(), ElementsAreArray(sharding->devices())); +} + +TEST_P(ShardingSerDesTest, ConcreteShardingRoundTrip) { + auto sharding = ConcreteSharding::Create( + DeviceList(DeviceList::Devices(client()->devices().begin(), + client()->devices().end())), + /*shape=*/Shape({10, 20}), + /*shard_shapes=*/{Shape({3, 20}), Shape({7, 20})}); + + TF_ASSERT_OK_AND_ASSIGN(auto serialized, Serialize(*sharding)); + + auto deserialized_options = + std::make_unique(client()); + TF_ASSERT_OK_AND_ASSIGN( + auto deserialized, + Deserialize(serialized, std::move(deserialized_options))); + + const auto* out_sharding = + llvm::dyn_cast(deserialized.get()); + ASSERT_NE(out_sharding, nullptr); + EXPECT_THAT(out_sharding->devices(), ElementsAreArray(sharding->devices())); + EXPECT_THAT(out_sharding->shape(), sharding->shape()); + EXPECT_THAT(out_sharding->shard_shapes(), + ElementsAreArray(sharding->shard_shapes())); +} + +TEST_P(ShardingSerDesTest, ConcreteEvenShardingRoundTrip) { + auto sharding = ConcreteEvenSharding::Create( + DeviceList(DeviceList::Devices(client()->devices().begin(), + client()->devices().end())), + /*shape=*/Shape({10, 20}), + /*shard_shape=*/Shape({5, 20})); + + TF_ASSERT_OK_AND_ASSIGN(auto serialized, Serialize(*sharding)); + + auto deserialized_options = + std::make_unique(client()); + TF_ASSERT_OK_AND_ASSIGN( + auto deserialized, + Deserialize(serialized, std::move(deserialized_options))); + + const auto* out_sharding = + llvm::dyn_cast(deserialized.get()); + ASSERT_NE(out_sharding, nullptr); + EXPECT_THAT(out_sharding->devices(), ElementsAreArray(sharding->devices())); + EXPECT_THAT(out_sharding->shape(), sharding->shape()); + EXPECT_THAT(out_sharding->shard_shape(), sharding->shard_shape()); +} + +INSTANTIATE_TEST_SUITE_P(NumDevices, ShardingSerDesTest, testing::Values(2)); + +} // namespace +} // namespace ifrt +} // namespace xla diff --git a/tensorflow/compiler/xla/python/ifrt/types.proto b/tensorflow/compiler/xla/python/ifrt/types.proto new file mode 100644 index 00000000000000..e9c799bcc1ed6c --- /dev/null +++ b/tensorflow/compiler/xla/python/ifrt/types.proto @@ -0,0 +1,31 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package xla.ifrt; + +// Wire format for `DeviceList`. +message DeviceListProto { + // Serialization and deserialization are expected to ensure that device ids + // are stable across proto construction and consumption. + repeated int32 device_ids = 1; +} + +// Wire format for `Shape`. Currently support static shapes with all dimension +// sizes greater than or equal to 0. +message ShapeProto { + repeated int64 dims = 1; +} diff --git a/tensorflow/compiler/xla/python/pjrt_ifrt/BUILD b/tensorflow/compiler/xla/python/pjrt_ifrt/BUILD index e1c1a36bb4ef62..060be5ab05d5c9 100644 --- a/tensorflow/compiler/xla/python/pjrt_ifrt/BUILD +++ b/tensorflow/compiler/xla/python/pjrt_ifrt/BUILD @@ -110,6 +110,42 @@ xla_cc_test( ], ) +tf_proto_library( + name = "xla_sharding_proto", + srcs = ["xla_sharding.proto"], + protodeps = [ + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/python/ifrt:types_proto", + ], +) + +cc_library( + name = "xla_sharding_serdes", + srcs = ["xla_sharding_serdes.cc"], + deps = [ + ":xla_ifrt", + ":xla_sharding_proto_cc", + "//tensorflow/compiler/xla/hlo/ir:hlo", + "//tensorflow/compiler/xla/python/ifrt", + "//tensorflow/compiler/xla/python/ifrt:serdes", + "//tensorflow/compiler/xla/python/ifrt:sharding_serdes", + ], + alwayslink = 1, +) + +xla_cc_test( + name = "xla_sharding_serdes_test", + srcs = ["xla_sharding_serdes_test.cc"], + deps = [ + ":xla_ifrt", + ":xla_sharding_serdes", + "//tensorflow/compiler/xla/hlo/ir:hlo", + "//tensorflow/compiler/xla/python/ifrt:mock", + "//tensorflow/compiler/xla/python/ifrt:sharding_serdes", + "@com_google_googletest//:gtest_main", + ], +) + # TODO(hyeontaek): Move this target out of pjrt_ifrt. cc_library( name = "xla_executable_impl_test_lib", diff --git a/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding.proto b/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding.proto new file mode 100644 index 00000000000000..0ff8040b66233e --- /dev/null +++ b/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding.proto @@ -0,0 +1,27 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package xla.ifrt; + +import "tensorflow/compiler/xla/python/ifrt/types.proto"; +import "tensorflow/compiler/xla/xla_data.proto"; + +// Wire format for `HloSharding`. +message HloShardingProto { + DeviceListProto devices = 1; + xla.OpSharding xla_op_sharding = 2; +} diff --git a/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_serdes.cc b/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_serdes.cc new file mode 100644 index 00000000000000..c3d8d2470600b9 --- /dev/null +++ b/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_serdes.cc @@ -0,0 +1,79 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "tensorflow/compiler/xla/hlo/ir/hlo_sharding.h" +#include "tensorflow/compiler/xla/python/ifrt/serdes.h" +#include "tensorflow/compiler/xla/python/ifrt/sharding_serdes.h" +#include "tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding.h" +#include "tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding.pb.h" + +namespace xla { +namespace ifrt { + +namespace { + +// Serialization/deserialization for `HloSharding`. +class HloShardingSerDes : public llvm::RTTIExtends { + public: + absl::string_view type_name() const override { + return "xla::ifrt::HloSharding"; + } + + absl::StatusOr Serialize(Serializable& serializable) override { + const HloSharding& sharding = llvm::cast(serializable); + HloShardingProto proto; + *proto.mutable_devices() = sharding.devices().ToProto(); + *proto.mutable_xla_op_sharding() = sharding.xla_hlo_sharding().ToProto(); + return proto.SerializeAsString(); + } + + absl::StatusOr> Deserialize( + const std::string& serialized, + std::unique_ptr options) override { + TF_ASSIGN_OR_RETURN(auto deserialize_sharding_options, + GetDeserializeShardingOptions(std::move(options))); + + HloShardingProto proto; + if (!proto.ParseFromString(serialized)) { + return absl::InvalidArgumentError( + "Failed to parse serialized HloSharding"); + } + TF_ASSIGN_OR_RETURN(auto devices, DeviceList::FromProto( + deserialize_sharding_options->client, + proto.devices())); + TF_ASSIGN_OR_RETURN(auto xla_hlo_sharding, + xla::HloSharding::FromProto(proto.xla_op_sharding())); + return HloSharding::Create(std::move(devices), std::move(xla_hlo_sharding)); + } + + static char ID; // NOLINT +}; + +[[maybe_unused]] char HloShardingSerDes::ID = 0; // NOLINT + +// clang-format off +bool register_hlo_sharding_serdes = ([] { + RegisterSerDes( + std::make_unique()); +}(), true); +// clang-format on + +} // namespace +} // namespace ifrt +} // namespace xla diff --git a/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_serdes_test.cc b/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_serdes_test.cc new file mode 100644 index 00000000000000..e043fb7e575f2b --- /dev/null +++ b/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_serdes_test.cc @@ -0,0 +1,95 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include +#include +#include "tensorflow/compiler/xla/hlo/ir/hlo_sharding.h" +#include "tensorflow/compiler/xla/python/ifrt/mock.h" +#include "tensorflow/compiler/xla/python/ifrt/sharding_serdes.h" +#include "tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding.h" + +namespace xla { +namespace ifrt { +namespace { + +using ::testing::ElementsAreArray; + +// Test fixture for sharding serialization and deserialization. It makes a mock +// client with a number of fake devices. Client implements `devices()` and +// `LookupDevice()`, and Device implements `id()`, with an arbitrary device ids +// assigned. +class XlaShardingSerDesTest : public ::testing::TestWithParam { + public: + void SetUp() override { + const int num_devices = GetParam(); + device_map_.reserve(num_devices); + devices_.reserve(num_devices); + for (int i = 0; i < num_devices; ++i) { + auto device = std::make_unique(); + ON_CALL(*device, id).WillByDefault([i]() { return i + 10; }); + devices_.push_back(device.get()); + device_map_.insert({i + 10, std::move(device)}); + } + client_ = std::make_unique(); + ON_CALL(*client_, devices) + .WillByDefault( + [this]() -> absl::Span { return devices_; }); + ON_CALL(*client_, LookupDevice) + .WillByDefault([this](int device_id) -> StatusOr { + auto it = device_map_.find(device_id); + if (it == device_map_.end()) { + return InvalidArgument("Unexpected device id: %d", device_id); + } + return it->second.get(); + }); + } + Client* client() { return client_.get(); } + + private: + std::unique_ptr client_; + absl::flat_hash_map> device_map_; + std::vector devices_; +}; + +TEST_P(XlaShardingSerDesTest, HloShardingRoundTrip) { + auto xla_hlo_sharding = xla::HloSharding::Tile(xla::TileAssignment({2, 1})); + auto sharding = HloSharding::Create( + DeviceList(DeviceList::Devices(client()->devices().begin(), + client()->devices().end())), + /*xla_hlo_sharding=*/xla_hlo_sharding); + + TF_ASSERT_OK_AND_ASSIGN(auto serialized, Serialize(*sharding)); + + auto deserialized_options = + std::make_unique(client()); + TF_ASSERT_OK_AND_ASSIGN( + auto deserialized, + Deserialize(serialized, std::move(deserialized_options))); + + const auto* out_sharding = llvm::dyn_cast(deserialized.get()); + ASSERT_NE(out_sharding, nullptr); + EXPECT_THAT(out_sharding->devices(), ElementsAreArray(sharding->devices())); + EXPECT_EQ(out_sharding->xla_hlo_sharding(), sharding->xla_hlo_sharding()); +} + +INSTANTIATE_TEST_SUITE_P(NumDevices, XlaShardingSerDesTest, testing::Values(2)); + +} // namespace +} // namespace ifrt +} // namespace xla From 4d911a7d82fb8940cd4c7c9355cb4f6feda6a59c Mon Sep 17 00:00:00 2001 From: Ashish Shenoy Date: Thu, 13 Jul 2023 22:39:24 -0700 Subject: [PATCH 301/376] Merge `GetWindowedOutputSizeVerboseV2` with `GetWindowedOutputSizeVerbose`. PiperOrigin-RevId: 548022384 --- tensorflow/compiler/mlir/lite/ir/tfl_ops.cc | 2 +- .../compiler/mlir/tensorflow/ir/tf_ops_a_m.cc | 2 +- .../tensorflow/transforms/legalize_hlo.cc | 2 +- .../mlir/tf2xla/transforms/legalize_tf.cc | 10 +++-- .../compiler/mlir/tosa/g3doc/legalization.md | 2 +- .../mlir/tosa/transforms/legalize_utils.cc | 2 +- .../tf2xla/kernels/conv_op_helpers.cc | 2 +- .../kernels/extract_image_patches_op.cc | 2 +- .../xla/service/dynamic_window_utils.h | 2 +- .../core/framework/kernel_shape_util.cc | 33 ++++++--------- tensorflow/core/framework/kernel_shape_util.h | 8 ---- .../core/kernels/conv_grad_filter_ops.cc | 4 +- .../core/kernels/conv_grad_filter_ops_3d.cc | 42 +++++++++++-------- .../kernels/conv_grad_filter_ops_launcher.cc | 8 ++-- .../core/kernels/conv_grad_input_ops.cc | 4 +- tensorflow/core/kernels/conv_grad_input_ops.h | 8 ++-- .../core/kernels/conv_grad_input_ops_3d.cc | 42 +++++++++++-------- .../core/kernels/conv_grad_shape_utils.cc | 2 +- tensorflow/core/kernels/conv_ops.cc | 4 +- tensorflow/core/kernels/conv_ops_impl.h | 10 ++--- .../core/kernels/depthwise_conv_grad_op.cc | 14 ++++--- tensorflow/core/kernels/depthwise_conv_op.cc | 14 ++++--- tensorflow/core/kernels/mkl/mkl_conv_ops.h | 10 ++--- .../kernels/mkl/mkl_pooling_ops_common.cc | 22 +++++----- tensorflow/core/kernels/ops_util_test.cc | 8 ++-- tensorflow/core/kernels/pooling_ops_common.cc | 28 +++++++------ tensorflow/core/ops/array_ops.cc | 20 ++++----- tensorflow/core/ops/nn_ops.cc | 8 ++-- 28 files changed, 162 insertions(+), 153 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc index 230e3565ad1865..9dc61d0d907981 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc @@ -1305,7 +1305,7 @@ static LogicalResult ComputeConvWindowedOutputSize( int64_t pad_low; int64_t pad_high; - tensorflow::Status status = tensorflow::GetWindowedOutputSizeVerboseV2( + tensorflow::Status status = tensorflow::GetWindowedOutputSizeVerbose( input_size, filter_size, dilation_rate, stride, padding, output_size, &pad_low, &pad_high); // Return failure if expected_output_size could not be calculated. diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc index 7f066b3f327eb5..f66c996f32a888 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc @@ -1922,7 +1922,7 @@ static LogicalResult inferConvReturnTypeComponents( // Skip if input or filter size is dynamic. if (input_ty.isDynamicDim(dim) || filter_ty.isDynamicDim(i)) continue; // Calculate the expected_output_size. - tensorflow::Status status = tensorflow::GetWindowedOutputSizeVerboseV2( + tensorflow::Status status = tensorflow::GetWindowedOutputSizeVerbose( input_ty.getDimSize(dim), filter_ty.getDimSize(i), get_int(dilations[dim]), stride, padding, &expected_output_size, &pad_low, &pad_high); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc index 3ffa9a120ca573..87586a57bbef2e 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc @@ -502,7 +502,7 @@ class Convert2DConvOp : public OpConversionPattern, int64_t output_size; int64_t pad_low_int64; int64_t pad_high_int64; - tensorflow::Status status = tensorflow::GetWindowedOutputSizeVerboseV2( + tensorflow::Status status = tensorflow::GetWindowedOutputSizeVerbose( conv_op.getLhs().getType().cast().getDimSize( input_spatial_dim[i]), conv_op.getRhs().getType().cast().getDimSize( diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc index cce28c45d704ec..be7fceb6632773 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc @@ -14,14 +14,18 @@ limitations under the License. ==============================================================================*/ // This file implements logic for lowering TensorFlow dialect to XLA dialect. - +#include #include +#include #include #include #include #include #include #include +#include +#include +#include #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" @@ -1222,7 +1226,7 @@ class ConvertConvOp : public OpRewritePattern { int64_t pad_high_int64; int64_t input_size = input_ty.getDimSize(dim); if (input_size == ShapedType::kDynamic) return failure(); - tsl::Status status = tensorflow::GetWindowedOutputSizeVerboseV2( + tsl::Status status = tensorflow::GetWindowedOutputSizeVerbose( input_size, filter_ty.getDimSize(i), dilation, stride, padding, &output_size, &pad_low_int64, &pad_high_int64); if (!status.ok()) return failure(); @@ -4886,7 +4890,7 @@ class ConvertConvBackpropInputOp : public OpRewritePattern { padding_after = explicit_paddings[2 * spatial_dim + 1]; } int64_t expected_output_size = 0; - auto status = GetWindowedOutputSizeVerboseV2( + auto status = GetWindowedOutputSizeVerbose( input_size, filter_size, dilation, stride, padding, &expected_output_size, &padding_before, &padding_after); if (!status.ok()) return failure(); diff --git a/tensorflow/compiler/mlir/tosa/g3doc/legalization.md b/tensorflow/compiler/mlir/tosa/g3doc/legalization.md index d09a96f178bbaf..8da389f17e0b84 100644 --- a/tensorflow/compiler/mlir/tosa/g3doc/legalization.md +++ b/tensorflow/compiler/mlir/tosa/g3doc/legalization.md @@ -140,7 +140,7 @@ vector get_padding_values_from_pad_type(tensorflow::Padding padding, tens int64 op_size, pad_before_tf, pad_after_tf; - tensorflow::GetWindowedOutputSizeVerboseV2(input_type.shape[ifm_dim], filter_type.shape[filter_dim], + tensorflow::GetWindowedOutputSizeVerbose(input_type.shape[ifm_dim], filter_type.shape[filter_dim], dim_dilation, dim_stride, padding, // Outputs &op_size, &pad_before_tf, &pad_after_tf); diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc index eb4edcd9c47080..39c3ec9b5a5eb5 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc @@ -528,7 +528,7 @@ bool getPaddingValuesFromPadType(tensorflow::Padding tf_pad, ip_size = ip_size < 0 ? f_size * dim_dilation : ip_size; int64_t op_size, pad_before_tf, pad_after_tf; // Complains if using int64_T - tensorflow::Status status = tensorflow::GetWindowedOutputSizeVerboseV2( + tensorflow::Status status = tensorflow::GetWindowedOutputSizeVerbose( ip_size, f_size, dim_dilation, dim_stride, tf_pad, &op_size, &pad_before_tf, &pad_after_tf); if (!status.ok()) return false; diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc index 833efb34649950..242c022c892faf 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc @@ -292,7 +292,7 @@ StatusOr MakeXlaForwardConvOp(StringPiece /*type_string*/, } int64_t unused_output_size; - TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerboseV2( + TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose( input_shape.dimensions(dim), filter_shape.dimensions(i), rhs_dilation[i], window_strides[i], attrs.padding, &unused_output_size, &padding[i].first, &padding[i].second)); diff --git a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc index 49e80226786355..01e1d57e732b6a 100644 --- a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc @@ -149,7 +149,7 @@ class ExtractImagePatchesOp : public XlaOpKernel { int64_t unused_output_size; OP_REQUIRES_OK( - ctx, GetWindowedOutputSizeVerboseV2( + ctx, GetWindowedOutputSizeVerbose( input_shape.dim_size(dim), ksizes_[dim], rhs_dilation[i], window_strides[i], padding_, &unused_output_size, &padding[i].first, &padding[i].second)); diff --git a/tensorflow/compiler/xla/service/dynamic_window_utils.h b/tensorflow/compiler/xla/service/dynamic_window_utils.h index 40e891fad4b59e..11392ea33884e4 100644 --- a/tensorflow/compiler/xla/service/dynamic_window_utils.h +++ b/tensorflow/compiler/xla/service/dynamic_window_utils.h @@ -31,7 +31,7 @@ struct DynamicWindowDims { HloInstruction* output_size; }; -// This mirrors the logic in GetWindowedOutputSizeVerboseV2 but with HLOs as +// This mirrors the logic in GetWindowedOutputSizeVerbose but with HLOs as // inputs and outputs. DynamicWindowDims GetWindowedOutputSize(HloInstruction* input_size, int64_t window_size, diff --git a/tensorflow/core/framework/kernel_shape_util.cc b/tensorflow/core/framework/kernel_shape_util.cc index 2d52a5022bd486..071821ce4a56d6 100644 --- a/tensorflow/core/framework/kernel_shape_util.cc +++ b/tensorflow/core/framework/kernel_shape_util.cc @@ -14,15 +14,17 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/framework/kernel_shape_util.h" +#include +#include + #include "tensorflow/core/lib/core/errors.h" namespace tensorflow { -Status GetWindowedOutputSizeVerboseV2(int64_t input_size, int64_t filter_size, - int64_t dilation_rate, int64_t stride, - Padding padding_type, - int64_t* output_size, - int64_t* padding_before, - int64_t* padding_after) { +Status GetWindowedOutputSizeVerbose(int64_t input_size, int64_t filter_size, + int64_t dilation_rate, int64_t stride, + Padding padding_type, int64_t* output_size, + int64_t* padding_before, + int64_t* padding_after) { if (stride <= 0) { return errors::InvalidArgument("Stride must be > 0, but got ", stride); } @@ -64,17 +66,6 @@ Status GetWindowedOutputSizeVerboseV2(int64_t input_size, int64_t filter_size, return OkStatus(); } -Status GetWindowedOutputSizeVerbose(int64_t input_size, int64_t filter_size, - int64_t stride, Padding padding_type, - int64_t* output_size, - int64_t* padding_before, - int64_t* padding_after) { - return GetWindowedOutputSizeVerboseV2(input_size, filter_size, - /*dilation_rate=*/1, stride, - padding_type, output_size, - padding_before, padding_after); -} - Status GetWindowedOutputSize(int64_t input_size, int64_t filter_size, int dilation_rate, int64_t stride, Padding padding_type, int64_t* output_size, @@ -82,12 +73,12 @@ Status GetWindowedOutputSize(int64_t input_size, int64_t filter_size, if (padding_type == Padding::EXPLICIT) { return errors::Internal( "GetWindowedOutputSize does not handle EXPLICIT padding; call " - "GetWindowedOutputSizeVerboseV2 instead"); + "GetWindowedOutputSizeVerbose instead"); } int64_t padding_after_unused; - return GetWindowedOutputSizeVerboseV2(input_size, filter_size, dilation_rate, - stride, padding_type, output_size, - padding_size, &padding_after_unused); + return GetWindowedOutputSizeVerbose(input_size, filter_size, dilation_rate, + stride, padding_type, output_size, + padding_size, &padding_after_unused); } Status Get3dOutputSizeV2(const std::array& input, diff --git a/tensorflow/core/framework/kernel_shape_util.h b/tensorflow/core/framework/kernel_shape_util.h index 6ffda766ca449c..551a863e3d38e5 100644 --- a/tensorflow/core/framework/kernel_shape_util.h +++ b/tensorflow/core/framework/kernel_shape_util.h @@ -85,14 +85,6 @@ Status GetWindowedOutputSize(int64_t input_size, int64_t filter_size, // excess padding (caused by an odd padding size value) is added to the // 'padding_after' dimension. Status GetWindowedOutputSizeVerbose(int64_t input_size, int64_t filter_size, - int64_t stride, Padding padding_type, - int64_t* output_size, - int64_t* padding_before, - int64_t* padding_after); - -// The V2 version computes the same outputs with arbitrary dilation_rate. For -// detailed equations, refer to the comments for GetWindowedOutputSize(). -Status GetWindowedOutputSizeVerboseV2(int64_t input_size, int64_t filter_size, int64_t dilation_rate, int64_t stride, Padding padding_type, int64_t* output_size, diff --git a/tensorflow/core/kernels/conv_grad_filter_ops.cc b/tensorflow/core/kernels/conv_grad_filter_ops.cc index 7509bed522be28..555d5ae640cddf 100644 --- a/tensorflow/core/kernels/conv_grad_filter_ops.cc +++ b/tensorflow/core/kernels/conv_grad_filter_ops.cc @@ -314,13 +314,13 @@ class Conv2DCustomBackpropFilterOp : public OpKernel { context, GetWindowedOutputSizeVerbose( dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size, - dims.spatial_dims[0].stride, padding_, + /*dilation_rate=*/1, dims.spatial_dims[0].stride, padding_, &dims.spatial_dims[0].output_size, &pad_top, &pad_bottom)); OP_REQUIRES_OK( context, GetWindowedOutputSizeVerbose( dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size, - dims.spatial_dims[1].stride, padding_, + /*dilation_rate=*/1, dims.spatial_dims[1].stride, padding_, &dims.spatial_dims[1].output_size, &pad_left, &pad_right)); // The total dimension size of each kernel. diff --git a/tensorflow/core/kernels/conv_grad_filter_ops_3d.cc b/tensorflow/core/kernels/conv_grad_filter_ops_3d.cc index 0cb52b8419da8c..69fc08bb3d364f 100644 --- a/tensorflow/core/kernels/conv_grad_filter_ops_3d.cc +++ b/tensorflow/core/kernels/conv_grad_filter_ops_3d.cc @@ -16,7 +16,10 @@ limitations under the License. #define USE_EIGEN_TENSOR #define EIGEN_USE_THREADS +#include +#include #include +#include #include "tensorflow/core/framework/kernel_shape_util.h" #include "tensorflow/core/framework/numeric_op.h" @@ -375,24 +378,27 @@ class Conv3DCustomBackpropFilterOp : public OpKernel { int64_t top_pad_rows, bottom_pad_rows; int64_t left_pad_cols, right_pad_cols; - OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( - dims.spatial_dims[0].input_size, - dims.spatial_dims[0].filter_size, - dims.spatial_dims[0].stride, padding_, - &dims.spatial_dims[0].output_size, - &top_pad_planes, &bottom_pad_planes)); - OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( - dims.spatial_dims[1].input_size, - dims.spatial_dims[1].filter_size, - dims.spatial_dims[1].stride, padding_, - &dims.spatial_dims[1].output_size, - &top_pad_rows, &bottom_pad_rows)); - OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( - dims.spatial_dims[2].input_size, - dims.spatial_dims[2].filter_size, - dims.spatial_dims[2].stride, padding_, - &dims.spatial_dims[2].output_size, - &left_pad_cols, &right_pad_cols)); + OP_REQUIRES_OK( + context, + GetWindowedOutputSizeVerbose( + dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size, + /*dilation_rate=*/1, dims.spatial_dims[0].stride, padding_, + &dims.spatial_dims[0].output_size, &top_pad_planes, + &bottom_pad_planes)); + OP_REQUIRES_OK( + context, + GetWindowedOutputSizeVerbose( + dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size, + /*dilation_rate=*/1, dims.spatial_dims[1].stride, padding_, + &dims.spatial_dims[1].output_size, &top_pad_rows, + &bottom_pad_rows)); + OP_REQUIRES_OK( + context, + GetWindowedOutputSizeVerbose( + dims.spatial_dims[2].input_size, dims.spatial_dims[2].filter_size, + /*dilation_rate=*/1, dims.spatial_dims[2].stride, padding_, + &dims.spatial_dims[2].output_size, &left_pad_cols, + &right_pad_cols)); // TODO(ezhulenev): Extract work size and shard estimation to shared // functions in conv_grad_ops, and update 2d convolution backprop. diff --git a/tensorflow/core/kernels/conv_grad_filter_ops_launcher.cc b/tensorflow/core/kernels/conv_grad_filter_ops_launcher.cc index 1738c9413c9d84..403a6122d7f273 100644 --- a/tensorflow/core/kernels/conv_grad_filter_ops_launcher.cc +++ b/tensorflow/core/kernels/conv_grad_filter_ops_launcher.cc @@ -100,12 +100,12 @@ struct LaunchConv2DBackpropFilterOp { int64_t expected_out_rows, expected_out_cols; // The function is guaranteed to succeed because we checked the output and // padding was valid earlier. - TF_CHECK_OK(GetWindowedOutputSizeVerboseV2( + TF_CHECK_OK(GetWindowedOutputSizeVerbose( dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size, row_dilation, row_stride, padding, &expected_out_rows, &padding_top, &padding_bottom)); DCHECK_EQ(dims.spatial_dims[0].output_size, expected_out_rows); - TF_CHECK_OK(GetWindowedOutputSizeVerboseV2( + TF_CHECK_OK(GetWindowedOutputSizeVerbose( dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size, col_dilation, col_stride, padding, &expected_out_cols, &padding_left, &padding_right)); @@ -206,12 +206,12 @@ void LaunchConv2DBackpropFilterOpImpl( int64_t expected_out_rows, expected_out_cols; // The function is guaranteed to succeed because we checked the output and // padding was valid earlier. - TF_CHECK_OK(GetWindowedOutputSizeVerboseV2( + TF_CHECK_OK(GetWindowedOutputSizeVerbose( dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size, row_dilation, row_stride, padding, &expected_out_rows, &padding_top, &padding_bottom)); DCHECK_EQ(dims.spatial_dims[0].output_size, expected_out_rows); - TF_CHECK_OK(GetWindowedOutputSizeVerboseV2( + TF_CHECK_OK(GetWindowedOutputSizeVerbose( dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size, col_dilation, col_stride, padding, &expected_out_cols, &padding_left, &padding_right)); diff --git a/tensorflow/core/kernels/conv_grad_input_ops.cc b/tensorflow/core/kernels/conv_grad_input_ops.cc index a222fd3e89c167..cbf56bc71c7504 100644 --- a/tensorflow/core/kernels/conv_grad_input_ops.cc +++ b/tensorflow/core/kernels/conv_grad_input_ops.cc @@ -111,12 +111,12 @@ void LaunchConv2DBackpropInputOpGpuImpl( int64_t expected_out_rows, expected_out_cols; // The function is guaranteed to succeed because we checked the output and // padding was valid earlier. - TF_CHECK_OK(GetWindowedOutputSizeVerboseV2( + TF_CHECK_OK(GetWindowedOutputSizeVerbose( dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size, row_dilation, row_stride, padding, &expected_out_rows, &padding_top, &padding_bottom)); DCHECK_EQ(dims.spatial_dims[0].output_size, expected_out_rows); - TF_CHECK_OK(GetWindowedOutputSizeVerboseV2( + TF_CHECK_OK(GetWindowedOutputSizeVerbose( dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size, col_dilation, col_stride, padding, &expected_out_cols, &padding_left, &padding_right)); diff --git a/tensorflow/core/kernels/conv_grad_input_ops.h b/tensorflow/core/kernels/conv_grad_input_ops.h index f330ee672a66b2..b7e4e9c6837c41 100644 --- a/tensorflow/core/kernels/conv_grad_input_ops.h +++ b/tensorflow/core/kernels/conv_grad_input_ops.h @@ -144,13 +144,13 @@ struct LaunchConv2DBackpropInputOpImpl { int64_t expected_out_rows, expected_out_cols; // The function is guaranteed to succeed because we checked the output and // padding was valid earlier. - TF_CHECK_OK(GetWindowedOutputSizeVerboseV2( + TF_CHECK_OK(GetWindowedOutputSizeVerbose( dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size, row_dilation, row_stride, padding, &expected_out_rows, &padding_top, &padding_bottom)); DCHECK_EQ(dims.spatial_dims[0].output_size, expected_out_rows); - TF_CHECK_OK(GetWindowedOutputSizeVerboseV2( + TF_CHECK_OK(GetWindowedOutputSizeVerbose( dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size, col_dilation, col_stride, padding, &expected_out_cols, &padding_left, &padding_right)); @@ -525,13 +525,13 @@ class Conv2DCustomBackpropInputOp : public OpKernel { context, GetWindowedOutputSizeVerbose( dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size, - dims.spatial_dims[0].stride, padding_, + /*dilation_rate=*/1, dims.spatial_dims[0].stride, padding_, &dims.spatial_dims[0].output_size, &pad_top, &pad_bottom)); OP_REQUIRES_OK( context, GetWindowedOutputSizeVerbose( dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size, - dims.spatial_dims[1].stride, padding_, + /*dilation_rate=*/1, dims.spatial_dims[1].stride, padding_, &dims.spatial_dims[1].output_size, &pad_left, &pad_right)); // The total dimension size of each kernel. diff --git a/tensorflow/core/kernels/conv_grad_input_ops_3d.cc b/tensorflow/core/kernels/conv_grad_input_ops_3d.cc index 3e35e82bb9a200..9bb28116202740 100644 --- a/tensorflow/core/kernels/conv_grad_input_ops_3d.cc +++ b/tensorflow/core/kernels/conv_grad_input_ops_3d.cc @@ -16,7 +16,10 @@ limitations under the License. #define USE_EIGEN_TENSOR #define EIGEN_USE_THREADS +#include +#include #include +#include #include "tensorflow/core/framework/kernel_shape_util.h" #include "tensorflow/core/framework/numeric_op.h" @@ -360,24 +363,27 @@ class Conv3DCustomBackpropInputOp : public OpKernel { int64_t top_pad_rows, bottom_pad_rows; int64_t left_pad_cols, right_pad_cols; - OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( - dims.spatial_dims[0].input_size, - dims.spatial_dims[0].filter_size, - dims.spatial_dims[0].stride, padding_, - &dims.spatial_dims[0].output_size, - &top_pad_planes, &bottom_pad_planes)); - OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( - dims.spatial_dims[1].input_size, - dims.spatial_dims[1].filter_size, - dims.spatial_dims[1].stride, padding_, - &dims.spatial_dims[1].output_size, - &top_pad_rows, &bottom_pad_rows)); - OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( - dims.spatial_dims[2].input_size, - dims.spatial_dims[2].filter_size, - dims.spatial_dims[2].stride, padding_, - &dims.spatial_dims[2].output_size, - &left_pad_cols, &right_pad_cols)); + OP_REQUIRES_OK( + context, + GetWindowedOutputSizeVerbose( + dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size, + /*dilation_rate=*/1, dims.spatial_dims[0].stride, padding_, + &dims.spatial_dims[0].output_size, &top_pad_planes, + &bottom_pad_planes)); + OP_REQUIRES_OK( + context, + GetWindowedOutputSizeVerbose( + dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size, + /*dilation_rate=*/1, dims.spatial_dims[1].stride, padding_, + &dims.spatial_dims[1].output_size, &top_pad_rows, + &bottom_pad_rows)); + OP_REQUIRES_OK( + context, + GetWindowedOutputSizeVerbose( + dims.spatial_dims[2].input_size, dims.spatial_dims[2].filter_size, + /*dilation_rate=*/1, dims.spatial_dims[2].stride, padding_, + &dims.spatial_dims[2].output_size, &left_pad_cols, + &right_pad_cols)); // TODO(ezhulenev): Extract work size and shard estimation to shared // functions in conv_grad_ops, and update 2d convolution backprop. diff --git a/tensorflow/core/kernels/conv_grad_shape_utils.cc b/tensorflow/core/kernels/conv_grad_shape_utils.cc index f2686e1dd6cc60..9560a37fd6eea6 100644 --- a/tensorflow/core/kernels/conv_grad_shape_utils.cc +++ b/tensorflow/core/kernels/conv_grad_shape_utils.cc @@ -63,7 +63,7 @@ Status ConvBackpropExtractAndVerifyDimension( dim->stride = strides[spatial_dim]; dim->dilation = dilations[spatial_dim]; int64_t out_size = 0; - TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerboseV2( + TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose( dim->input_size, dim->filter_size, dim->dilation, dim->stride, padding, &out_size, &padding_before, &padding_after)); if (dim->output_size != out_size) { diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc index 55399555357823..063899c6b4a3b7 100644 --- a/tensorflow/core/kernels/conv_ops.cc +++ b/tensorflow/core/kernels/conv_ops.cc @@ -184,10 +184,10 @@ Status ComputeConv2DDimension(const Conv2DParameters& params, // Compute windowed output sizes for rows and columns. int64_t out_rows = 0, out_cols = 0; - TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerboseV2( + TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose( input_rows, filter_rows, dilation_rows, stride_rows, params.padding, &out_rows, &pad_rows_before, &pad_rows_after)); - TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerboseV2( + TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose( input_cols, filter_cols, dilation_cols, stride_cols, params.padding, &out_cols, &pad_cols_before, &pad_cols_after)); diff --git a/tensorflow/core/kernels/conv_ops_impl.h b/tensorflow/core/kernels/conv_ops_impl.h index be916e42b48d93..22b88454435f8c 100644 --- a/tensorflow/core/kernels/conv_ops_impl.h +++ b/tensorflow/core/kernels/conv_ops_impl.h @@ -479,7 +479,7 @@ class ConvOp : public BinaryOp { // Compute windowed output sizes for spatial dimensions. std::vector out_dims(spatial_dims); for (int i = 0; i < spatial_dims; ++i) { - OP_REQUIRES_OK(context, GetWindowedOutputSizeVerboseV2( + OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( input_dims[i], filter_dims[i], dilation_dims[i], stride_dims[i], padding_, &out_dims[i], &pad_before[i], &pad_after[i])); @@ -843,16 +843,16 @@ void LaunchConv2DOpImpl(OpKernelContext* ctx, bool use_cudnn, &padding_right); } int64_t out_rows_check, out_cols_check; - Status status = GetWindowedOutputSizeVerboseV2( + Status status = GetWindowedOutputSizeVerbose( in_rows, patch_rows, row_dilation, row_stride, padding, &out_rows_check, &padding_top, &padding_bottom); // The status is guaranteed to be OK because we checked the output and padding // was valid earlier. TF_CHECK_OK(status); DCHECK_EQ(out_rows, out_rows_check); - status = GetWindowedOutputSizeVerboseV2(in_cols, patch_cols, col_dilation, - col_stride, padding, &out_cols_check, - &padding_left, &padding_right); + status = GetWindowedOutputSizeVerbose(in_cols, patch_cols, col_dilation, + col_stride, padding, &out_cols_check, + &padding_left, &padding_right); TF_CHECK_OK(status); DCHECK_EQ(out_cols, out_cols_check); diff --git a/tensorflow/core/kernels/depthwise_conv_grad_op.cc b/tensorflow/core/kernels/depthwise_conv_grad_op.cc index 1e97ef38b5b7a1..b16458aa8052ca 100644 --- a/tensorflow/core/kernels/depthwise_conv_grad_op.cc +++ b/tensorflow/core/kernels/depthwise_conv_grad_op.cc @@ -121,12 +121,14 @@ typedef Eigen::GpuDevice GPUDevice; GetExplicitPaddingForDim(explicit_paddings_, data_format_, 'W', &pad_left, \ &pad_right); \ } \ - OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( \ - input_rows, filter_rows, stride_, padding_, \ - &out_rows, &pad_top, &pad_bottom)); \ - OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( \ - input_cols, filter_cols, stride_, padding_, \ - &out_cols, &pad_left, &pad_right)); \ + OP_REQUIRES_OK(context, \ + GetWindowedOutputSizeVerbose( \ + input_rows, filter_rows, /*dilation_rate=*/1, stride_, \ + padding_, &out_rows, &pad_top, &pad_bottom)); \ + OP_REQUIRES_OK(context, \ + GetWindowedOutputSizeVerbose( \ + input_cols, filter_cols, /*dilation_rate=*/1, stride_, \ + padding_, &out_cols, &pad_left, &pad_right)); \ OP_REQUIRES( \ context, output_rows == out_rows, \ errors::InvalidArgument( \ diff --git a/tensorflow/core/kernels/depthwise_conv_op.cc b/tensorflow/core/kernels/depthwise_conv_op.cc index b282855666b4ae..a636759602d92d 100644 --- a/tensorflow/core/kernels/depthwise_conv_op.cc +++ b/tensorflow/core/kernels/depthwise_conv_op.cc @@ -411,12 +411,14 @@ class DepthwiseConv2dNativeOp : public BinaryOp { GetExplicitPaddingForDim(explicit_paddings_, data_format_, 'W', &pad_left, &pad_right); } - OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( - input_rows, filter_rows, stride_, padding_, - &out_rows, &pad_top, &pad_bottom)); - OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( - input_cols, filter_cols, stride_, padding_, - &out_cols, &pad_left, &pad_right)); + OP_REQUIRES_OK(context, + GetWindowedOutputSizeVerbose( + input_rows, filter_rows, /*dilation_rate=*/1, stride_, + padding_, &out_rows, &pad_top, &pad_bottom)); + OP_REQUIRES_OK(context, + GetWindowedOutputSizeVerbose( + input_cols, filter_cols, /*dilation_rate=*/1, stride_, + padding_, &out_cols, &pad_left, &pad_right)); TensorShape out_shape; OP_REQUIRES_OK(context, ShapeFromFormatWithStatus(data_format_, batch, out_rows, diff --git a/tensorflow/core/kernels/mkl/mkl_conv_ops.h b/tensorflow/core/kernels/mkl/mkl_conv_ops.h index 2f35decb3548f0..0384df4b309285 100644 --- a/tensorflow/core/kernels/mkl/mkl_conv_ops.h +++ b/tensorflow/core/kernels/mkl/mkl_conv_ops.h @@ -446,11 +446,11 @@ class MklDnnConvUtil { padding_type = padding_; } OP_REQUIRES_OK(context_, - GetWindowedOutputSizeVerboseV2( + GetWindowedOutputSizeVerbose( input_rows, filter_rows, dilation_rows, stride_rows, padding_type, &out_rows, &pad_top, &pad_bottom)); OP_REQUIRES_OK(context_, - GetWindowedOutputSizeVerboseV2( + GetWindowedOutputSizeVerbose( input_cols, filter_cols, dilation_cols, stride_cols, padding_type, &out_cols, &pad_left, &pad_right)); } else { @@ -466,16 +466,16 @@ class MklDnnConvUtil { } else { padding_type = padding_; } - OP_REQUIRES_OK(context_, GetWindowedOutputSizeVerboseV2( + OP_REQUIRES_OK(context_, GetWindowedOutputSizeVerbose( input_planes, filter_planes, dilation_planes, stride_planes, padding_type, &out_planes, &pad_front, &pad_back)); OP_REQUIRES_OK(context_, - GetWindowedOutputSizeVerboseV2( + GetWindowedOutputSizeVerbose( input_rows, filter_rows, dilation_rows, stride_rows, padding_type, &out_rows, &pad_top, &pad_bottom)); OP_REQUIRES_OK(context_, - GetWindowedOutputSizeVerboseV2( + GetWindowedOutputSizeVerbose( input_cols, filter_cols, dilation_cols, stride_cols, padding_type, &out_cols, &pad_left, &pad_right)); } diff --git a/tensorflow/core/kernels/mkl/mkl_pooling_ops_common.cc b/tensorflow/core/kernels/mkl/mkl_pooling_ops_common.cc index aab190f5ac618d..c73233cab8dd26 100644 --- a/tensorflow/core/kernels/mkl/mkl_pooling_ops_common.cc +++ b/tensorflow/core/kernels/mkl/mkl_pooling_ops_common.cc @@ -375,19 +375,21 @@ void MklPoolParameters::Init(OpKernelContext* context, if (depth_window == 1) { // We are pooling in the D (Pool3D only), H and W. if (!is_pool2d) { - OP_REQUIRES_OK( - context, GetWindowedOutputSizeVerbose(tensor_in_planes, window_planes, - planes_stride, padding, - &out_planes, &pad_P1, &pad_P2)); + OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( + tensor_in_planes, window_planes, + /*dilation_rate=*/1, planes_stride, padding, + &out_planes, &pad_P1, &pad_P2)); } - OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( - tensor_in_rows, window_rows, row_stride, - padding, &out_height, &pad_top, &pad_bottom)); + OP_REQUIRES_OK( + context, GetWindowedOutputSizeVerbose( + tensor_in_rows, window_rows, /*dilation_rate=*/1, + row_stride, padding, &out_height, &pad_top, &pad_bottom)); - OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( - tensor_in_cols, window_cols, col_stride, - padding, &out_width, &pad_left, &pad_right)); + OP_REQUIRES_OK(context, + GetWindowedOutputSizeVerbose( + tensor_in_cols, window_cols, /*dilation_rate=*/1, + col_stride, padding, &out_width, &pad_left, &pad_right)); // TF can work with int64, but oneDNN only supports int32. // Fail if the depth, height or width are greater than MAX_INT. diff --git a/tensorflow/core/kernels/ops_util_test.cc b/tensorflow/core/kernels/ops_util_test.cc index e449947e16def3..a3848b4dca6db6 100644 --- a/tensorflow/core/kernels/ops_util_test.cc +++ b/tensorflow/core/kernels/ops_util_test.cc @@ -107,13 +107,13 @@ class OpsUtilTest : public ::testing::Test { int64_t new_height, new_width, pad_top, pad_bottom, pad_left, pad_right; Status status = GetWindowedOutputSizeVerbose( pad_struct.input.in_height, pad_struct.input.filter_height, - pad_struct.input.row_stride, pad_struct.input.padding, &new_height, - &pad_top, &pad_bottom); + /*dilation_rate=*/1, pad_struct.input.row_stride, + pad_struct.input.padding, &new_height, &pad_top, &pad_bottom); EXPECT_EQ(status.code(), code) << status; status = GetWindowedOutputSizeVerbose( pad_struct.input.in_width, pad_struct.input.filter_width, - pad_struct.input.col_stride, pad_struct.input.padding, &new_width, - &pad_left, &pad_right); + /*dilation_rate=*/1, pad_struct.input.col_stride, + pad_struct.input.padding, &new_width, &pad_left, &pad_right); EXPECT_EQ(status.code(), code) << status; EXPECT_EQ(pad_struct.output.new_height, new_height); EXPECT_EQ(pad_struct.output.new_width, new_width); diff --git a/tensorflow/core/kernels/pooling_ops_common.cc b/tensorflow/core/kernels/pooling_ops_common.cc index b48287ae1442a4..407d6991608c7e 100644 --- a/tensorflow/core/kernels/pooling_ops_common.cc +++ b/tensorflow/core/kernels/pooling_ops_common.cc @@ -164,12 +164,14 @@ PoolParameters::PoolParameters(OpKernelContext* context, } if (depth_window == 1) { - OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( - tensor_in_rows, window_rows, row_stride, - padding, &out_height, &pad_top, &pad_bottom)); - OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( - tensor_in_cols, window_cols, col_stride, - padding, &out_width, &pad_left, &pad_right)); + OP_REQUIRES_OK( + context, GetWindowedOutputSizeVerbose( + tensor_in_rows, window_rows, /*dilation_rate=*/1, + row_stride, padding, &out_height, &pad_top, &pad_bottom)); + OP_REQUIRES_OK(context, + GetWindowedOutputSizeVerbose( + tensor_in_cols, window_cols, /*dilation_rate=*/1, + col_stride, padding, &out_width, &pad_left, &pad_right)); pad_depth = 0; out_depth = depth; } else { @@ -195,12 +197,14 @@ PoolParameters::PoolParameters(OpKernelContext* context, errors::Unimplemented("Depthwise max pooling is currently " "only implemented for CPU devices.")); - OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( - tensor_in_rows, window_rows, row_stride, - padding, &out_height, &pad_top, &pad_bottom)); - OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( - tensor_in_cols, window_cols, col_stride, - padding, &out_width, &pad_left, &pad_right)); + OP_REQUIRES_OK( + context, GetWindowedOutputSizeVerbose( + tensor_in_rows, window_rows, /*dilation_rate=*/1, + row_stride, padding, &out_height, &pad_top, &pad_bottom)); + OP_REQUIRES_OK(context, + GetWindowedOutputSizeVerbose( + tensor_in_cols, window_cols, /*dilation_rate=*/1, + col_stride, padding, &out_width, &pad_left, &pad_right)); pad_depth = 0; out_depth = depth / depth_window; } diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index 42eff63cab03b1..331cfa84728eea 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -2702,11 +2702,11 @@ REGISTER_OP("ExtractImagePatches") int64_t output_rows, output_cols; int64_t padding_before, padding_after; TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose( - in_rows, ksize_rows_eff, stride_rows, padding, &output_rows, - &padding_before, &padding_after)); + in_rows, ksize_rows_eff, /*dilation_rate=*/1, stride_rows, padding, + &output_rows, &padding_before, &padding_after)); TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose( - in_cols, ksize_cols_eff, stride_cols, padding, &output_cols, - &padding_before, &padding_after)); + in_cols, ksize_cols_eff, /*dilation_rate=*/1, stride_cols, padding, + &output_cols, &padding_before, &padding_after)); ShapeHandle output_shape = c->MakeShape( {batch_size_dim, output_rows, output_cols, output_depth_dim}); c->set_output(0, output_shape); @@ -2808,14 +2808,14 @@ REGISTER_OP("ExtractVolumePatches") int64_t output_planes, output_rows, output_cols; int64_t padding_before, padding_after; TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose( - in_planes, ksize_planes, stride_planes, padding, &output_planes, - &padding_before, &padding_after)); + in_planes, ksize_planes, /*dilation_rate=*/1, stride_planes, padding, + &output_planes, &padding_before, &padding_after)); TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose( - in_rows, ksize_rows, stride_rows, padding, &output_rows, - &padding_before, &padding_after)); + in_rows, ksize_rows, /*dilation_rate=*/1, stride_rows, padding, + &output_rows, &padding_before, &padding_after)); TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose( - in_cols, ksize_cols, stride_cols, padding, &output_cols, - &padding_before, &padding_after)); + in_cols, ksize_cols, /*dilation_rate=*/1, stride_cols, padding, + &output_cols, &padding_before, &padding_after)); ShapeHandle output_shape = c->MakeShape({batch_size_dim, output_planes, output_rows, output_cols, output_depth_dim}); diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc index 5fcc0a4738c74d..e7fd05a9c34e89 100644 --- a/tensorflow/core/ops/nn_ops.cc +++ b/tensorflow/core/ops/nn_ops.cc @@ -1147,11 +1147,11 @@ REGISTER_OP("Dilation2D") int64_t output_rows, output_cols; int64_t padding_before, padding_after; TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose( - in_rows, filter_rows_eff, stride_rows, padding, &output_rows, - &padding_before, &padding_after)); + in_rows, filter_rows_eff, /*dilation_rate=*/1, stride_rows, padding, + &output_rows, &padding_before, &padding_after)); TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose( - in_cols, filter_cols_eff, stride_cols, padding, &output_cols, - &padding_before, &padding_after)); + in_cols, filter_cols_eff, /*dilation_rate=*/1, stride_cols, padding, + &output_cols, &padding_before, &padding_after)); ShapeHandle output_shape = c->MakeShape( {batch_size_dim, output_rows, output_cols, output_depth_dim}); From a76f25506ef687cdfda2d68f4430ca69708e4a8a Mon Sep 17 00:00:00 2001 From: Yimei Sun Date: Thu, 13 Jul 2023 23:26:04 -0700 Subject: [PATCH 302/376] Fix typo --- tensorflow/core/kernels/mkl/mkl_fused_instance_norm_op.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/kernels/mkl/mkl_fused_instance_norm_op.cc b/tensorflow/core/kernels/mkl/mkl_fused_instance_norm_op.cc index df5a7e80239734..a08a34584f564b 100644 --- a/tensorflow/core/kernels/mkl/mkl_fused_instance_norm_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_fused_instance_norm_op.cc @@ -273,7 +273,7 @@ class MklFusedInstanceNormOp : public OpKernel { // Helper function to prepare scale and shift data in float type as // required by oneDNN library. Prior to oneDNN 3.x version, the library // requires the final scale and shift data to be passed in the same buffer - // wherase the 3.x version requires separate buffers for scale and shift + // whereas the 3.x version requires separate buffers for scale and shift // data. void SetupScaleShiftBuffer(OpKernelContext* ctx, const Tensor& scale_tensor, const Tensor& shift_tensor, From f63aefed5e6b174936ccbfff23e9e4547ba07ab9 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Fri, 14 Jul 2023 01:04:09 -0700 Subject: [PATCH 303/376] [XLA:GPU] Do not enable debug info manager for recursive compilation during autotuning. PiperOrigin-RevId: 548050621 --- tensorflow/compiler/xla/service/compiler.h | 2 + tensorflow/compiler/xla/service/gpu/BUILD | 1 + .../xla/service/gpu/autotuner_compile_util.cc | 5 +- .../compiler/xla/service/gpu/gpu_compiler.cc | 2 +- .../xla/service/gpu/gpu_compiler_test.cc | 52 +++++++++++++++++-- .../xla/service/gpu/gpu_executable.cc | 10 ++-- .../compiler/xla/service/gpu/gpu_executable.h | 3 ++ .../xla/service/xla_debug_info_manager.cc | 6 +++ .../xla/service/xla_debug_info_manager.h | 5 ++ 9 files changed, 76 insertions(+), 10 deletions(-) diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h index de568a260c64eb..d75cab95226ce0 100644 --- a/tensorflow/compiler/xla/service/compiler.h +++ b/tensorflow/compiler/xla/service/compiler.h @@ -244,6 +244,8 @@ class Compiler { std::function, Shape>>( const HloModule& module)> layout_canonicalization_callback = {}; + + bool enable_debug_info_manager = true; }; virtual ~Compiler() = default; diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 50e052bb133fb7..bc1af2d1fc567b 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -2514,6 +2514,7 @@ xla_cc_test( "//tensorflow/compiler/xla:autotune_results_proto_cc", "//tensorflow/compiler/xla/hlo/utils:hlo_matchers", "//tensorflow/compiler/xla/service:gpu_plugin", + "//tensorflow/compiler/xla/service:xla_debug_info_manager", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/tsl/lib/core:status_test_util", diff --git a/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.cc b/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.cc index 85859f88d64cb2..87fed6db01161e 100644 --- a/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.cc +++ b/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.cc @@ -190,7 +190,10 @@ StatusOr> AutotunerCompileUtil::CompileNoCache( (*new_hlo_module)->config().set_debug_options(opts_); StatusOr> out = compiler_->RunBackend( - std::move(*new_hlo_module), &stream_executor_, &allocator_); + std::move(*new_hlo_module), &stream_executor_, + Compiler::CompileOptions{&allocator_, /*thread_pool=*/nullptr, + /*layout_canonicalization_callback=*/{}, + /*enable_debug_info_manager=*/false}); if (out.status().code() == absl::StatusCode::kResourceExhausted) { // Being out of shared memory budget is an expected failure. return std::unique_ptr(); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index 9670d2dc954216..3058f2c60f4846 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -1473,7 +1473,7 @@ StatusOr> GpuCompiler::RunBackend( .xla_gpu_enable_persistent_temp_buffers(), std::move(buffer_assignment_proto), [buffer_assignment] { return buffer_assignment->ToVerboseString(); }, - std::move(module)})); + std::move(module), options.enable_debug_info_manager})); if (embed_ir_in_executable) { DCHECK_NE("", ir_module_string_before_opt); gpu_executable->set_ir_module_string(ir_module_string_before_opt); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler_test.cc index 05655d609ba2d4..cadc76d4e9bbe8 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include #include +#include #include #include @@ -24,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/autotune_results.pb.h" #include "tensorflow/compiler/xla/hlo/utils/hlo_matchers.h" #include "tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.h" +#include "tensorflow/compiler/xla/service/xla_debug_info_manager.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/tsl/lib/core/status_test_util.h" @@ -33,16 +35,58 @@ namespace { namespace op = xla::testing::opcode_matchers; -using ::absl::LogSeverity; -using ::absl::ScopedMockLog; -using ::testing::EndsWith; using ::testing::IsEmpty; using ::testing::Not; -using ::testing::StartsWith; using ::testing::TempDir; using GpuCompilerTest = HloTestBase; +TEST_F(GpuCompilerTest, DebugInfoManagerEnabled) { + const char* hlo_text = R"( +HloModule test + +ENTRY main { + p = f32[10]{0} parameter(0) + ROOT neg = f32[10]{0} negate(p) +} +)"; + auto module = ParseAndReturnVerifiedModule(hlo_text).value(); + std::unique_ptr executable = + backend() + .compiler() + ->RunBackend(std::move(module), backend().default_stream_executor(), + {/*device_allocator=*/nullptr, + /*thread_pool=*/nullptr, + /*layout_canonicalization_callback=*/{}, + /*enable_debug_info_manager=*/true}) + .value(); + EXPECT_TRUE(XlaDebugInfoManager::Get()->TracksModule( + executable->module().unique_id())); +} + +TEST_F(GpuCompilerTest, DebugInfoManagerDisabled) { + const char* hlo_text = R"( +HloModule test + +ENTRY main { + p = f32[10]{0} parameter(0) + ROOT neg = f32[10]{0} negate(p) +} +)"; + auto module = ParseAndReturnVerifiedModule(hlo_text).value(); + std::unique_ptr executable = + backend() + .compiler() + ->RunBackend(std::move(module), backend().default_stream_executor(), + {/*device_allocator=*/nullptr, + /*thread_pool=*/nullptr, + /*layout_canonicalization_callback=*/{}, + /*enable_debug_info_manager=*/false}) + .value(); + EXPECT_FALSE(XlaDebugInfoManager::Get()->TracksModule( + executable->module().unique_id())); +} + TEST_F(GpuCompilerTest, CopyInsertionFusion) { const char* hlo_text = R"( HloModule cluster diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index 6e008a43807950..2b69d869020a5e 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -126,7 +126,8 @@ GpuExecutable::GpuExecutable(GpuExecutable::Params params) verbose_buffer_assignment_string_dumper_( params.verbose_buffer_assignment_string_dumper), constants_(std::move(params.constants)), - output_info_(std::move(params.output_info)) { + output_info_(std::move(params.output_info)), + enable_debug_info_manager_(params.enable_debug_info_manager) { #if TENSORFLOW_USE_ROCM // ROCm uses hsaco hashes to distinguish between modules. // Bad things happen if multiple modules with identical code are loaded. @@ -134,14 +135,14 @@ GpuExecutable::GpuExecutable(GpuExecutable::Params params) *(uint64_t*)(&binary_[binary_.size() - 16]) = tsl::EnvTime::NowNanos(); *(uint64_t*)(&binary_[binary_.size() - 8]) = tsl::random::New64(); #endif - if (has_module()) { + if (has_module() && enable_debug_info_manager_) { XlaDebugInfoManager::Get()->RegisterModule( module().unique_id(), shared_module(), debug_buffer_assignment_); } } GpuExecutable::~GpuExecutable() { - if (has_module()) { + if (has_module() && enable_debug_info_manager_) { XlaDebugInfoManager::Get()->UnregisterModule(module().unique_id()); } @@ -925,7 +926,8 @@ GpuExecutable::GpuExecutable( output_shape_(xla_output_shape), allocations_(std::move(allocations)), constants_(std::move(constants)), - output_info_(std::move(output_info)) { + output_info_(std::move(output_info)), + enable_debug_info_manager_(true) { XlaDebugInfoManager::Get()->RegisterModule( module().unique_id(), shared_module(), debug_buffer_assignment_); } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h index 51010a8fa3124e..830cc20914e517 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -103,6 +104,7 @@ class GpuExecutable : public Executable { }; std::unique_ptr debug_module = nullptr; + bool enable_debug_info_manager = true; }; // Analyze the entry function to construct buffer allocation and other output @@ -314,6 +316,7 @@ class GpuExecutable : public Executable { // Retains shared ownership of on-device constants that are managed by XLA and // potentially shared with other executables. std::vector> shared_constants_; + bool enable_debug_info_manager_; GpuExecutable(const GpuExecutable&) = delete; GpuExecutable& operator=(const GpuExecutable&) = delete; diff --git a/tensorflow/compiler/xla/service/xla_debug_info_manager.cc b/tensorflow/compiler/xla/service/xla_debug_info_manager.cc index 33411199568f60..b0df104a74284b 100644 --- a/tensorflow/compiler/xla/service/xla_debug_info_manager.cc +++ b/tensorflow/compiler/xla/service/xla_debug_info_manager.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include "tensorflow/compiler/xla/service/hlo_proto_util.h" @@ -94,4 +95,9 @@ void XlaDebugInfoManager::StopTracing( } } +bool XlaDebugInfoManager::TracksModule(ModuleIdentifier module_id) const { + absl::MutexLock lock(&mutex_); + return modules_.find(module_id) != modules_.end(); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/xla_debug_info_manager.h b/tensorflow/compiler/xla/service/xla_debug_info_manager.h index d18a7cf35dabb4..6c16a374b1c962 100644 --- a/tensorflow/compiler/xla/service/xla_debug_info_manager.h +++ b/tensorflow/compiler/xla/service/xla_debug_info_manager.h @@ -16,8 +16,10 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_XLA_DEBUG_INFO_MANAGER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_XLA_DEBUG_INFO_MANAGER_H_ +#include #include #include +#include #include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_module.h" @@ -58,6 +60,9 @@ class XlaDebugInfoManager { void StopTracing( std::vector>* module_debug_info = nullptr); + // Returns whether 'module_id' is tracked by XlaDebugInfoManager. + bool TracksModule(ModuleIdentifier module_id) const; + friend class XlaDebugInfoManagerTestPeer; private: From 85844e73ce5f93634529c659eb2c755871cfa191 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 14 Jul 2023 01:39:09 -0700 Subject: [PATCH 304/376] Update auto assignment list Add : Varsha Anjanappa PiperOrigin-RevId: 548057030 --- .github/bot_config.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/bot_config.yml b/.github/bot_config.yml index b90b4f52c56d0f..9ddb1c272bbf1e 100644 --- a/.github/bot_config.yml +++ b/.github/bot_config.yml @@ -18,6 +18,7 @@ assignees: - sushreebarsa - SuryanarayanaY - tilakrayal + - Varsha-anjanappa # A list of assignees for compiler folder compiler_assignees: - joker-eph From ceb2910f130dfc05692154af669e945429029e63 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 14 Jul 2023 02:02:05 -0700 Subject: [PATCH 305/376] Update GraphDef version to 1557. PiperOrigin-RevId: 548061277 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index e6f885c504bda6..2a64ed66331723 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1556 // Updated: 2023/7/13 +#define TF_GRAPH_DEF_VERSION 1557 // Updated: 2023/7/14 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From 562db363736e4fc50336a401b9f142cdc783504a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 14 Jul 2023 02:02:13 -0700 Subject: [PATCH 306/376] compat: Update forward compatibility horizon to 2023-07-14 PiperOrigin-RevId: 548061310 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 1bc69b19356702..764155e5591961 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 7, 13) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 7, 14) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From 4c6b4dc41e71f042adef1b209f8ea69150db54d4 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Fri, 14 Jul 2023 02:32:41 -0700 Subject: [PATCH 307/376] [XLA:GPU] Add row length filter for matching normalization diamond in Triton Softmax rewriter. PiperOrigin-RevId: 548066828 --- .../xla/service/gpu/ir_emitter_triton_test.cc | 37 ------------------- .../service/gpu/softmax_rewriter_triton.cc | 6 +++ .../gpu/softmax_rewriter_triton_test.cc | 35 ++++++++++++++---- 3 files changed, 34 insertions(+), 44 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_triton_test.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_triton_test.cc index 05f8619c08fce3..9f97c1c0bc5597 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_triton_test.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_triton_test.cc @@ -2011,43 +2011,6 @@ ENTRY main { EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec(1e-6, 1e-6))); } -TEST_F(TritonSoftmaxTest, CanFuseAndEmitExactSoftmaxF32WithShortRows) { - const std::string hlo_text = R"( -HloModule softmax -max_computation { - arg_0 = f32[] parameter(0) - arg_1 = f32[] parameter(1) - ROOT maximum = f32[] maximum(arg_0, arg_1) -} -add_computation { - arg_0.1 = f32[] parameter(0) - arg_1.1 = f32[] parameter(1) - ROOT add = f32[] add(arg_0.1, arg_1.1) -} -ENTRY main { - param_0 = f32[127,5]{1,0} parameter(0) - constant_neg_inf = f32[] constant(-inf) - reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation - broadcast = f32[127,5]{1,0} broadcast(reduce), dimensions={0} - subtract = f32[127,5]{1,0} subtract(param_0, broadcast) - exponential = f32[127,5]{1,0} exponential(subtract) - constant_zero = f32[] constant(0) - second_reduce = f32[127]{0} reduce(exponential, constant_zero), dimensions={1}, to_apply=add_computation - second_broadcast = f32[127,5]{1,0} broadcast(second_reduce), dimensions={0} - ROOT divide = f32[127,5]{1,0} divide(exponential, second_broadcast) -} -)"; - - MatchOptimizedHlo(hlo_text, R"( -; CHECK: ENTRY -; CHECK: %[[P0:.*]] = f32[127,5]{1,0} parameter(0) -; CHECK: ROOT -; CHECK-SAME: fusion(%[[P0]]) -; CHECK-SAME: kind=kCustom -; CHECK-SAME: __triton_softmax -)"); - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec(1e-6, 1e-6))); -} TEST_F(TritonSoftmaxTest, CanFuseAndEmitFirstSoftmaxDiamondF16) { const std::string hlo_text = R"( diff --git a/tensorflow/compiler/xla/service/gpu/softmax_rewriter_triton.cc b/tensorflow/compiler/xla/service/gpu/softmax_rewriter_triton.cc index a5e59e1be3d16b..efaa9c4a811b30 100644 --- a/tensorflow/compiler/xla/service/gpu/softmax_rewriter_triton.cc +++ b/tensorflow/compiler/xla/service/gpu/softmax_rewriter_triton.cc @@ -216,6 +216,12 @@ std::optional MatchesTritonCompatibleClosedReductionDiamond( return match_failure; } + // TODO(b/291204753): remove this filter. This heuristic enables flipping the + // default flag while filtering out cases that could result in regressions. + if (reduce->operand(0)->shape().dimensions().back() < 64) { + return match_failure; + } + while (IsTriviallyFusible(producer)) { producer = producer->mutable_operand(0); } diff --git a/tensorflow/compiler/xla/service/gpu/softmax_rewriter_triton_test.cc b/tensorflow/compiler/xla/service/gpu/softmax_rewriter_triton_test.cc index 00d77002725f0e..ab4022d0bd038d 100644 --- a/tensorflow/compiler/xla/service/gpu/softmax_rewriter_triton_test.cc +++ b/tensorflow/compiler/xla/service/gpu/softmax_rewriter_triton_test.cc @@ -577,17 +577,17 @@ add_computation { ROOT add = f32[] add(arg_0.1, arg_1.1) } ENTRY main { - param_0 = f32[127,125]{1,0} parameter(0) + param_0 = f32[127,625]{1,0} parameter(0) constant_neg_inf = f32[] constant(-inf) reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation - broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0} - subtract = f32[127,125]{1,0} subtract(param_0, broadcast) - bitcasted_subtract = f32[127,5,25] bitcast(subtract) - exponential = f32[127,5,25] exponential(bitcasted_subtract) + broadcast = f32[127,625]{1,0} broadcast(reduce), dimensions={0} + subtract = f32[127,625]{1,0} subtract(param_0, broadcast) + bitcasted_subtract = f32[127,5,125] bitcast(subtract) + exponential = f32[127,5,125] exponential(bitcasted_subtract) constant_zero = f32[] constant(0) second_reduce = f32[127,5] reduce(exponential, constant_zero), dimensions={2}, to_apply=add_computation - second_broadcast = f32[127,5,25] broadcast(second_reduce), dimensions={0,1} - ROOT divide = f32[127,5,25] divide(exponential, second_broadcast) + second_broadcast = f32[127,5,125] broadcast(second_reduce), dimensions={0,1} + ROOT divide = f32[127,5,125] divide(exponential, second_broadcast) } )"; auto module = ParseAndReturnVerifiedModule(hlo_string).value(); @@ -789,6 +789,27 @@ ENTRY main { EXPECT_FALSE(fusion_rewriter.Run(module.get()).value()); } +TEST_F(SoftmaxRewriterTritonTest, DoNotFuseSoftmaxWithSmallRows) { + const std::string hlo_string = R"( +HloModule softmax +max_computation { + arg_0 = f32[] parameter(0) + arg_1 = f32[] parameter(1) + ROOT maximum = f32[] maximum(arg_0, arg_1) +} +ENTRY main { + param_0 = f32[127,50]{1,0} parameter(0) + constant_neg_inf = f32[] constant(-inf) + reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation + broadcast = f32[127,50]{1,0} broadcast(reduce), dimensions={0} + ROOT subtract = f32[127,50]{1,0} subtract(param_0, broadcast) +} +)"; + auto module = ParseAndReturnVerifiedModule(hlo_string).value(); + SoftmaxRewriterTriton fusion_rewriter(gpu_version_); + EXPECT_FALSE(fusion_rewriter.Run(module.get()).value()); +} + } // anonymous namespace } // namespace gpu } // namespace xla From d02daa397dc60ab5a6e91caaaaa11a8011e72338 Mon Sep 17 00:00:00 2001 From: Alan Kelly Date: Fri, 14 Jul 2023 02:39:31 -0700 Subject: [PATCH 308/376] Flex: Cache ValidateOutputTensorShapeConsistency PiperOrigin-RevId: 548068184 --- tensorflow/lite/delegates/flex/kernel.cc | 26 ++++++++++++++---------- tensorflow/lite/delegates/flex/kernel.h | 4 ++++ 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/tensorflow/lite/delegates/flex/kernel.cc b/tensorflow/lite/delegates/flex/kernel.cc index de0f698ddedd6a..7b4dc17298f1c7 100644 --- a/tensorflow/lite/delegates/flex/kernel.cc +++ b/tensorflow/lite/delegates/flex/kernel.cc @@ -85,7 +85,7 @@ class OpInputs { } forwardable_.resize(inputs_.size()); } - ~OpInputs() {} + ~OpInputs() = default; int Size() const { return inputs_.size(); } @@ -438,7 +438,7 @@ tensorflow::Status DelegateKernel::ExecuteOpKernelRunner( } DelegateKernel::DelegateKernel() : op_data_(new OpData) {} -DelegateKernel::~DelegateKernel() {} +DelegateKernel::~DelegateKernel() = default; TfLiteStatus DelegateKernel::Init(TfLiteContext* context, const TfLiteDelegateParams* params) { @@ -572,20 +572,24 @@ TfLiteStatus DelegateKernel::Prepare(TfLiteContext* context, TfLiteNode* node) { tensor_ref_count[tensor_index] += 2; } - const bool shapes_are_valid = - (ValidateOutputTensorShapeConsistency(context) == kTfLiteOk); - if (shapes_are_valid) { - TFLITE_LOG(tflite::TFLITE_LOG_INFO, - "FlexDelegate: All tensor shapes are consistent."); - } else { - TFLITE_LOG(tflite::TFLITE_LOG_WARNING, - "FlexDelegate: Some tensor shapes are inconsistent."); + // Output shapes which may have initially been inferable may no longer be + // after ResizeInputTensor has been called, so it must be checked again. + if (shapes_are_valid_) { + shapes_are_valid_ = + (ValidateOutputTensorShapeConsistency(context) == kTfLiteOk); + if (shapes_are_valid_) { + TFLITE_LOG(tflite::TFLITE_LOG_INFO, + "FlexDelegate: All tensor shapes are consistent."); + } else { + TFLITE_LOG(tflite::TFLITE_LOG_WARNING, + "FlexDelegate: Some tensor shapes are inconsistent."); + } } // All output tensors are allocated by TensorFlow, so we mark them as // kTfLiteDynamic. for (auto tensor_index : op_data_->subgraph_outputs) { - if (!shapes_are_valid) { + if (!shapes_are_valid_) { SetTensorToDynamic(&context->tensors[tensor_index]); } ++tensor_ref_count[tensor_index]; diff --git a/tensorflow/lite/delegates/flex/kernel.h b/tensorflow/lite/delegates/flex/kernel.h index fabb8367284306..ee162148af5094 100644 --- a/tensorflow/lite/delegates/flex/kernel.h +++ b/tensorflow/lite/delegates/flex/kernel.h @@ -60,6 +60,10 @@ class DelegateKernel : public SimpleDelegateKernelInterface { const std::map& GetTensorReleaseMap() const; std::unique_ptr op_data_; + + // Indicates that the output shapes may be inferred using the input shapes and + // May be allocated during Prepare. + bool shapes_are_valid_ = true; }; } // namespace flex From cdeabc5b68858f84fac7595c9db262d6d1c73d42 Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Fri, 14 Jul 2023 02:41:59 -0700 Subject: [PATCH 309/376] [XLA:GPU] Minor fix: do not trigger Triton GEMM on narrowing but unsupported output type conversions. PiperOrigin-RevId: 548068646 --- .../xla/service/gpu/gemm_rewriter_triton.cc | 2 ++ .../xla/service/gpu/gemm_rewriter_triton_test.cc | 13 +++++++++++++ 2 files changed, 15 insertions(+) diff --git a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.cc b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.cc index d133928c93879f..102e17bde7dd1b 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.cc @@ -1434,6 +1434,8 @@ bool ShouldTritonHandleGEMM(const HloInstruction& dot, // Data-narrowing conversion after the dot is profitable to fuse. if (dot.user_count() == 1 && dot.users()[0]->opcode() == HloOpcode::kConvert && + IsSupportedDataType(dot.users()[0]->shape().element_type(), + gpu_version) && InputMinusOutputBytes(*dot.users()[0]) > -kIoToleranceBytes) { return true; } diff --git a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton_test.cc b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton_test.cc index c86e78b14389ec..151199150b3e1f 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton_test.cc @@ -112,6 +112,19 @@ ENTRY e { GmockMatch(m::Fusion(m::Parameter(), m::Broadcast()))); } +TEST_F(GemmRewriterTritonTest, DoNotTriggerOnUnsupportedOutputConversions) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +ENTRY e { + p0 = f16[128,256] parameter(0) + p1 = f16[256,512] parameter(1) + r = f16[128,512] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT c = u8[128,512] convert(r) +})")); + EXPECT_FALSE(GemmRewriterTriton(gpu_version_).Run(module.get()).value()); +} + using TritonDotAnalysisTest = HloTestBase; TEST_F(TritonDotAnalysisTest, NopBitcasts) { From 70e59f884b41f9463e0e5fc474c619d3fc788c49 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Fri, 14 Jul 2023 03:13:44 -0700 Subject: [PATCH 310/376] Check whether there is a module before calling RegisterModule() Also clean up the signature for RegisterModule(). It doesn't make sense to pass in both module_id and module, and then CHECK that module->module_id() is equal to module_id. PiperOrigin-RevId: 548074722 --- tensorflow/compiler/xla/service/cpu/cpu_executable.cc | 8 ++++---- tensorflow/compiler/xla/service/gpu/gpu_executable.cc | 10 ++++++---- .../compiler/xla/service/xla_debug_info_manager.cc | 6 +++--- .../compiler/xla/service/xla_debug_info_manager.h | 4 ++-- .../xla/service/xla_debug_info_manager_test.cc | 6 +++--- 5 files changed, 18 insertions(+), 16 deletions(-) diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index bb74517b91cfb3..a358c25fa471d6 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -75,8 +75,8 @@ CpuExecutable::CpuExecutable( std::make_shared(assignment_->ToProto()); } if (has_module()) { - XlaDebugInfoManager::Get()->RegisterModule( - module().unique_id(), shared_module(), buffer_assignment_); + XlaDebugInfoManager::Get()->RegisterModule(shared_module(), + buffer_assignment_); } // Resolve symbols in the constructor rather than at execution time to avoid @@ -110,8 +110,8 @@ CpuExecutable::CpuExecutable( std::make_shared(assignment_->ToProto()); } if (has_module()) { - XlaDebugInfoManager::Get()->RegisterModule( - module().unique_id(), shared_module(), buffer_assignment_); + XlaDebugInfoManager::Get()->RegisterModule(shared_module(), + buffer_assignment_); } } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index 2b69d869020a5e..a475ec153b6878 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -136,8 +136,8 @@ GpuExecutable::GpuExecutable(GpuExecutable::Params params) *(uint64_t*)(&binary_[binary_.size() - 8]) = tsl::random::New64(); #endif if (has_module() && enable_debug_info_manager_) { - XlaDebugInfoManager::Get()->RegisterModule( - module().unique_id(), shared_module(), debug_buffer_assignment_); + XlaDebugInfoManager::Get()->RegisterModule(shared_module(), + debug_buffer_assignment_); } } @@ -928,8 +928,10 @@ GpuExecutable::GpuExecutable( constants_(std::move(constants)), output_info_(std::move(output_info)), enable_debug_info_manager_(true) { - XlaDebugInfoManager::Get()->RegisterModule( - module().unique_id(), shared_module(), debug_buffer_assignment_); + if (has_module()) { + XlaDebugInfoManager::Get()->RegisterModule(shared_module(), + debug_buffer_assignment_); + } } // Returns a list of functions exported from the `module` that should be loaded diff --git a/tensorflow/compiler/xla/service/xla_debug_info_manager.cc b/tensorflow/compiler/xla/service/xla_debug_info_manager.cc index b0df104a74284b..467d8364876d18 100644 --- a/tensorflow/compiler/xla/service/xla_debug_info_manager.cc +++ b/tensorflow/compiler/xla/service/xla_debug_info_manager.cc @@ -25,11 +25,11 @@ limitations under the License. namespace xla { void XlaDebugInfoManager::RegisterModule( - ModuleIdentifier module_id, std::shared_ptr hlo_module, + std::shared_ptr hlo_module, std::shared_ptr buffer_assignment) { - CHECK(hlo_module != nullptr && module_id == hlo_module->unique_id()); + CHECK(hlo_module != nullptr); absl::MutexLock lock(&mutex_); - auto result = modules_.try_emplace(module_id); + auto result = modules_.try_emplace(hlo_module->unique_id()); CHECK(result.second); XlaModuleEntry& m = result.first->second; m.hlo_module = std::move(hlo_module); diff --git a/tensorflow/compiler/xla/service/xla_debug_info_manager.h b/tensorflow/compiler/xla/service/xla_debug_info_manager.h index 6c16a374b1c962..08d7e8eb54b552 100644 --- a/tensorflow/compiler/xla/service/xla_debug_info_manager.h +++ b/tensorflow/compiler/xla/service/xla_debug_info_manager.h @@ -42,9 +42,9 @@ class XlaDebugInfoManager { } // Registers an active module to XlaDebugInfoManager. - // The module_id is expected to be unique per process. + // The module_id of the module is expected to be unique per process. void RegisterModule( - ModuleIdentifier module_id, std::shared_ptr hlo_module, + std::shared_ptr hlo_module, std::shared_ptr buffer_assignment); // Unregisters an active module. diff --git a/tensorflow/compiler/xla/service/xla_debug_info_manager_test.cc b/tensorflow/compiler/xla/service/xla_debug_info_manager_test.cc index 1aa459e29cb1ac..2fbb876e242802 100644 --- a/tensorflow/compiler/xla/service/xla_debug_info_manager_test.cc +++ b/tensorflow/compiler/xla/service/xla_debug_info_manager_test.cc @@ -27,9 +27,9 @@ namespace xla { class XlaDebugInfoManagerTestPeer { public: void RegisterModule( - ModuleIdentifier module_id, std::shared_ptr hlo_module, + std::shared_ptr hlo_module, std::shared_ptr buffer_assignment) { - return xla_debug_info_manager_.RegisterModule(module_id, hlo_module, + return xla_debug_info_manager_.RegisterModule(hlo_module, buffer_assignment); } @@ -85,7 +85,7 @@ class XlaDebugInfoManagerTest : public HloTestBase { debug_info.buffer_assignment = nullptr; ModuleIdentifier unique_id = debug_info.module->unique_id(); debug_info.unique_id = unique_id; - xla_debug_info_manager_.RegisterModule(unique_id, debug_info.module, + xla_debug_info_manager_.RegisterModule(debug_info.module, debug_info.buffer_assignment); external_references_.push_back(std::move(debug_info)); return unique_id; From 458c94a1907a3b480cdee86812be55a0085965ed Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Fri, 14 Jul 2023 03:14:41 -0700 Subject: [PATCH 311/376] HloFusionAnalysis: Don't return statuses where it's not necessary. Also don't repeatedly find tiled transposes quite as often. The interface is still not optimal, but at least we can stop virally returning statuses everywhere. PiperOrigin-RevId: 548074953 --- .../xla/service/gpu/gpu_performance_model.cc | 13 +- .../xla/service/gpu/hlo_fusion_analysis.cc | 119 ++++++++---------- .../xla/service/gpu/hlo_fusion_analysis.h | 74 ++++++++--- .../xla/service/gpu/ir_emitter_unnested.cc | 18 ++- 4 files changed, 126 insertions(+), 98 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/gpu_performance_model.cc b/tensorflow/compiler/xla/service/gpu/gpu_performance_model.cc index 308230501664b5..c0a20c96e5d9d7 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_performance_model.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_performance_model.cc @@ -138,11 +138,14 @@ std::optional EstimateThreadCount( bool use_experimental_block_size) { auto fusion = DynCast(instr); if (fusion != nullptr && cc.has_value()) { - HloFusionAnalysis fusion_analysis(fusion, &gpu_device_info, cc.value()); - auto launch_dimensions = - fusion_analysis.GetLaunchDimensions(use_experimental_block_size); - if (launch_dimensions.ok()) { - return launch_dimensions->launch_bound(); + auto analysis = + HloFusionAnalysis::Create(fusion, &gpu_device_info, cc.value()); + if (analysis.ok()) { + auto launch_dimensions = + analysis->GetLaunchDimensions(use_experimental_block_size); + if (launch_dimensions.ok()) { + return launch_dimensions->launch_bound(); + } } } return std::nullopt; diff --git a/tensorflow/compiler/xla/service/gpu/hlo_fusion_analysis.cc b/tensorflow/compiler/xla/service/gpu/hlo_fusion_analysis.cc index 5c8aacd93b257b..89483a71a85e15 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_fusion_analysis.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_fusion_analysis.cc @@ -48,41 +48,6 @@ const auto kDimX = TilingScheme::DimX; const auto kLinearIndexingX = TilingScheme::LinearIndexingX; const auto kStridedIndexingX = TilingScheme::StridedIndexingX; -// Returns true if the fusion has consistent transpose heros. -bool HasConsistentTransposeHeros(HloComputation* fusion) { - std::vector hlo_roots = GetFusionRoots(fusion); - if (!HasAnyTiledTransposeRoot(fusion)) { - return false; - } - const HloInstruction* first_transpose = &FindNonTrivialHero(**absl::c_find_if( - hlo_roots, - [](HloInstruction* instr) { return FindAnyTiledTranspose(*instr); })); - const Shape& transpose_in_shape = first_transpose->operand(0)->shape(); - std::optional first_tiled_transpose = - FindAnyTiledTranspose(*first_transpose); - - // We need the following invariant: - // For every tuple element: - // -> EITHER it's a kCopy: S{L} -> S{L'} - // -> OR it's an elementwise op of shape S{L} - for (HloInstruction* root : hlo_roots) { - std::optional tiled_transpose = - FindAnyTiledTranspose(*root); - if (tiled_transpose) { - if (*tiled_transpose != *first_tiled_transpose) { - return false; - } - } else { - if (!ShapeUtil::IsReshapeOrTransposeBitcast( - root->shape(), transpose_in_shape, - /*ignore_element_type=*/true)) { - return false; - } - } - } - return true; -} - // Returns true if the fusion output contains non-strided slices only. bool IsInputFusibleNonStridedSlices(const HloInstruction* root) { if (root->opcode() == HloOpcode::kTuple) { @@ -258,13 +223,47 @@ int64_t NearestPowerOfTwo(int64_t v) { } // namespace -StatusOr -HloFusionAnalysis::GetEmitterFusionKind() const { +// Returns true if the fusion has consistent transpose heros. +bool HloFusionAnalysis::HasConsistentTransposeHeros() const { + if (!tiled_transpose_) { + return false; + } + + auto* fusion = fusion_->fused_instructions_computation(); + std::vector hlo_roots = GetFusionRoots(fusion); + const HloInstruction* first_transpose = + &FindNonTrivialHero(*root_with_tiled_transpose_); + const Shape& transpose_in_shape = first_transpose->operand(0)->shape(); + std::optional first_tiled_transpose = + FindAnyTiledTranspose(*first_transpose); + + // We need the following invariant: + // For every tuple element: + // -> EITHER it's a kCopy: S{L} -> S{L'} + // -> OR it's an elementwise op of shape S{L} + for (HloInstruction* root : hlo_roots) { + std::optional tiled_transpose = + FindAnyTiledTranspose(*root); + if (tiled_transpose) { + if (*tiled_transpose != *first_tiled_transpose) { + return false; + } + } else { + if (!ShapeUtil::IsReshapeOrTransposeBitcast( + root->shape(), transpose_in_shape, + /*ignore_element_type=*/true)) { + return false; + } + } + } + return true; +} + +HloFusionAnalysis::EmitterFusionKind HloFusionAnalysis::GetEmitterFusionKind() + const { #if GOOGLE_CUDA - TF_ASSIGN_OR_RETURN(auto backend_config, - fusion_->backend_config()); - if (backend_config.kind() == kTritonGemmFusionKind || - backend_config.kind() == kTritonSoftmaxFusionKind) { + if (fusion_backend_config_.kind() == kTritonGemmFusionKind || + fusion_backend_config_.kind() == kTritonSoftmaxFusionKind) { return EmitterFusionKind::kTriton; } #endif @@ -273,7 +272,8 @@ HloFusionAnalysis::GetEmitterFusionKind() const { if (HasAnyUnnestedReductionRoot(fused_computation)) { return EmitterFusionKind::kReduction; } - if (HasConsistentTransposeHeros(fused_computation)) { + // We expect that the last dimension is swapped with a different dimension. + if (HasConsistentTransposeHeros() && tiled_transpose_->permutation[2] != 2) { return EmitterFusionKind::kTranspose; } @@ -297,7 +297,7 @@ HloFusionAnalysis::GetEmitterFusionKind() const { StatusOr HloFusionAnalysis::GetLaunchDimensions( bool use_experimental_block_size) { - TF_ASSIGN_OR_RETURN(auto emitter_fusion_kind, GetEmitterFusionKind()); + auto emitter_fusion_kind = GetEmitterFusionKind(); switch (emitter_fusion_kind) { case EmitterFusionKind::kLoop: { // Disable experimental block size if few_waves or row_vectorized enabled. @@ -309,8 +309,7 @@ StatusOr HloFusionAnalysis::GetLaunchDimensions( *loop_fusion_config); } case EmitterFusionKind::kReduction: { - TF_ASSIGN_OR_RETURN(auto reduction_codegen_info, - GetReductionCodegenInfo()); + auto* reduction_codegen_info = GetReductionCodegenInfo(); const TilingScheme& tiling_scheme = reduction_codegen_info->GetTilingScheme(); size_t blocks_y = reduction_codegen_info->GetIndexGroups().size(); @@ -321,7 +320,7 @@ StatusOr HloFusionAnalysis::GetLaunchDimensions( /*y=*/1, /*z=*/1}); } case EmitterFusionKind::kTranspose: { - TF_ASSIGN_OR_RETURN(auto tiling_scheme, GetTransposeTilingScheme()); + auto* tiling_scheme = GetTransposeTilingScheme(); return LaunchDimensions(tiling_scheme->GetNumberOfBlocksPhysical(), tiling_scheme->GetNumThreadsPerBlockPhysical()); } @@ -330,8 +329,7 @@ StatusOr HloFusionAnalysis::GetLaunchDimensions( } } -StatusOr -HloFusionAnalysis::GetReductionCodegenInfo() { +const ReductionCodegenInfo* HloFusionAnalysis::GetReductionCodegenInfo() { if (reduction_codegen_info_.has_value()) { return &reduction_codegen_info_.value(); } @@ -344,34 +342,27 @@ HloFusionAnalysis::GetReductionCodegenInfo() { // We always use the first reduce as representative to construct // ReductionCodegenInfo, since all the reductions are required to have the // same shape and layout as verified by `IsFusedReductionOutputConsistent()`. - TF_ASSIGN_OR_RETURN(auto reduction_codegen_info, - ComputeReductionCodegenInfo(first_reduce)); + auto reduction_codegen_info = ComputeReductionCodegenInfo(first_reduce); reduction_codegen_info_.emplace(std::move(reduction_codegen_info)); return &reduction_codegen_info_.value(); } -StatusOr HloFusionAnalysis::GetTransposeTilingScheme() { +const TilingScheme* HloFusionAnalysis::GetTransposeTilingScheme() { if (transpose_tiling_scheme_.has_value()) { return &transpose_tiling_scheme_.value(); } - std::optional dims_and_order = FindAnyTiledTranspose( - **absl::c_find_if(fusion_roots_, [](HloInstruction* instr) { - return FindAnyTiledTranspose(*instr); - })); - - // TODO(cheshire): have a more robust way of checking this. - TF_RET_CHECK(dims_and_order.has_value()); + if (!tiled_transpose_) { + return nullptr; + } constexpr int kNumRows = 4; - TF_RET_CHECK(WarpSize() % kNumRows == 0); + static_assert(WarpSize() % kNumRows == 0); // 3D view over the input shape. - Vector3 dims = dims_and_order->dimensions; - Vector3 order = dims_and_order->permutation; + Vector3 dims = tiled_transpose_->dimensions; + Vector3 order = tiled_transpose_->permutation; - // We expect that the last dimension is swapped with a different dimension. - TF_RET_CHECK(order[2] != 2); Vector3 permuted_dims = {dims[order[0]], dims[order[1]], dims[order[2]]}; Vector3 tile_sizes{1, 1, 1}; tile_sizes[order[2]] = WarpSize() / kNumRows; @@ -670,7 +661,7 @@ int HloFusionAnalysis::CalculateVirtualThreadScalingFactorForReduction( return 1; } -StatusOr HloFusionAnalysis::ComputeReductionCodegenInfo( +ReductionCodegenInfo HloFusionAnalysis::ComputeReductionCodegenInfo( HloInstruction* first_reduce) const { Shape input_shape = first_reduce->operand(0)->shape(); ReductionDimensions reduction_dimensions = diff --git a/tensorflow/compiler/xla/service/gpu/hlo_fusion_analysis.h b/tensorflow/compiler/xla/service/gpu/hlo_fusion_analysis.h index e647357eb078cb..225523e6901a02 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_fusion_analysis.h +++ b/tensorflow/compiler/xla/service/gpu/hlo_fusion_analysis.h @@ -17,10 +17,13 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_FUSION_ANALYSIS_H_ #include +#include #include #include "tensorflow/compiler/xla/hlo/ir/hlo_computation.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/gpu/gpu_device_info.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h" @@ -43,38 +46,69 @@ class HloFusionAnalysis { kScatter, }; - HloFusionAnalysis(const HloFusionInstruction* fusion, - const GpuDeviceInfo* device_info, - se::CudaComputeCapability compute_capability) - : fusion_(fusion), - fused_computation_(fusion->fused_instructions_computation()), - fusion_roots_(GetFusionRoots(fusion->fused_instructions_computation())), - device_info_(device_info), - compute_capability_(compute_capability) {} + static StatusOr Create( + const HloFusionInstruction* fusion, const GpuDeviceInfo* device_info, + se::CudaComputeCapability compute_capability) { + TF_ASSIGN_OR_RETURN(auto backend_config, + fusion->backend_config()); + + auto hlo_roots = GetFusionRoots(fusion->fused_instructions_computation()); + HloInstruction* root_with_tiled_transpose; + std::optional tiled_transpose; + + for (auto* root : hlo_roots) { + if ((tiled_transpose = FindAnyTiledTranspose(*root))) { + root_with_tiled_transpose = root; + break; + } + } + + return HloFusionAnalysis(fusion, std::move(backend_config), device_info, + compute_capability, root_with_tiled_transpose, + tiled_transpose); + } - // Simple getters. const HloComputation* fused_computation() const { return fused_computation_; } absl::Span fusion_roots() const { return absl::MakeSpan(fusion_roots_); } - // Determine the fusion type for the emitter. - StatusOr GetEmitterFusionKind() const; + // Determines the fusion type for the emitter. + EmitterFusionKind GetEmitterFusionKind() const; - // Determine the launch dimensions for the fusion. + // Determines the launch dimensions for the fusion. The fusion kind must be + // one of `kLoop`, `kReduction` or `kTranspose`. StatusOr GetLaunchDimensions( bool use_experimental_block_size = false); - // Calculate reduction information (kind: kReduction). - StatusOr GetReductionCodegenInfo(); + // Calculates the reduction information. Returns `nullptr` if the fusion is + // not a reduction. + const ReductionCodegenInfo* GetReductionCodegenInfo(); - // Calculate transpose tiling information (kind: kTranspose). - StatusOr GetTransposeTilingScheme(); + // Calculates the transpose tiling information. Returns `nullptr` if the + // fusion is not a transpose. + const TilingScheme* GetTransposeTilingScheme(); - // Calculate loop fusion config (kind: kLoop). + // Calculates the loop fusion config. Returns `nullptr` if the fusion is not a + // loop. const LaunchDimensionsConfig* GetLoopFusionConfig(); private: + HloFusionAnalysis(const HloFusionInstruction* fusion, + FusionBackendConfig fusion_backend_config, + const GpuDeviceInfo* device_info, + se::CudaComputeCapability compute_capability, + HloInstruction* root_with_tiled_transpose, + std::optional tiled_transpose) + : fusion_(fusion), + fusion_backend_config_(std::move(fusion_backend_config)), + fused_computation_(fusion->fused_instructions_computation()), + fusion_roots_(GetFusionRoots(fusion->fused_instructions_computation())), + device_info_(device_info), + compute_capability_(compute_capability), + root_with_tiled_transpose_(root_with_tiled_transpose), + tiled_transpose_(tiled_transpose) {} + const Shape& GetElementShape() const; int SmallestInputDtypeBits() const; int64_t MaxBeneficialColumnReductionUnrollBasedOnBlockSize() const; @@ -88,14 +122,18 @@ class HloFusionAnalysis { bool reduction_is_race_free) const; int CalculateVirtualThreadScalingFactorForReduction( const ReductionDimensions& reduction_dimensions) const; - StatusOr ComputeReductionCodegenInfo( + ReductionCodegenInfo ComputeReductionCodegenInfo( HloInstruction* first_reduce) const; + bool HasConsistentTransposeHeros() const; const HloFusionInstruction* fusion_; + FusionBackendConfig fusion_backend_config_; const HloComputation* fused_computation_; std::vector fusion_roots_; const GpuDeviceInfo* device_info_; se::CudaComputeCapability compute_capability_; + HloInstruction* root_with_tiled_transpose_; + std::optional tiled_transpose_; std::optional reduction_codegen_info_; std::optional transpose_tiling_scheme_; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 5c00bb6ab9aebd..aa8c17bc532970 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -75,7 +75,6 @@ limitations under the License. #include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_instructions.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_opcode.h" -#include "tensorflow/compiler/xla/hlo/utils/hlo_query.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/mlir_hlo/lhlo/IR/lhlo_ops.h" #include "tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.h" @@ -100,7 +99,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/gpu_device_info.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" #include "tensorflow/compiler/xla/service/gpu/gpu_fused_mha_runner.h" -#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h" #include "tensorflow/compiler/xla/service/gpu/hlo_fusion_analysis.h" #include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h" #include "tensorflow/compiler/xla/service/gpu/infeed_thunk.h" @@ -139,7 +137,6 @@ limitations under the License. #include "tensorflow/compiler/xla/translate/mhlo_to_hlo/location_exporter.h" #include "tensorflow/compiler/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" #include "tensorflow/compiler/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.h" -#include "tensorflow/compiler/xla/union_find.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/tsl/platform/errors.h" @@ -1989,8 +1986,7 @@ Status IrEmitterUnnested::EmitLoopFusion(mlir::lmhlo::FusionOp fusion, Status IrEmitterUnnested::EmitUnnestedTranspose( mlir::lmhlo::FusionOp fusion, HloFusionAnalysis& fusion_analysis) { - TF_ASSIGN_OR_RETURN(auto tiling_scheme, - fusion_analysis.GetTransposeTilingScheme()); + auto* tiling_scheme = fusion_analysis.GetTransposeTilingScheme(); TF_ASSIGN_OR_RETURN(auto launch_dimensions, fusion_analysis.GetLaunchDimensions()); @@ -2042,12 +2038,13 @@ Status IrEmitterUnnested::EmitFusion(mlir::Operation* op) { // Create HloFusionAnalysis instance. GpuDeviceInfo device_info = ir_emitter_context_->gpu_device_info(); - HloFusionAnalysis fusion_analysis( - &fusion, &device_info, ir_emitter_context_->cuda_compute_capability()); + TF_ASSIGN_OR_RETURN(auto fusion_analysis, + HloFusionAnalysis::Create( + &fusion, &device_info, + ir_emitter_context_->cuda_compute_capability())); // Dispatch to the fusion specific emitter. - TF_ASSIGN_OR_RETURN(auto emitter_fusion_kind, - fusion_analysis.GetEmitterFusionKind()); + auto emitter_fusion_kind = fusion_analysis.GetEmitterFusionKind(); switch (emitter_fusion_kind) { case HloFusionAnalysis::EmitterFusionKind::kTriton: { #if GOOGLE_CUDA @@ -4875,8 +4872,7 @@ Status IrEmitterUnnested::EmitIRForReduction( Status IrEmitterUnnested::EmitUnnestedReduction( mlir::lmhlo::FusionOp fusion, HloFusionAnalysis& fusion_analysis) { - TF_ASSIGN_OR_RETURN(auto reduction_codegen_info, - fusion_analysis.GetReductionCodegenInfo()); + auto* reduction_codegen_info = fusion_analysis.GetReductionCodegenInfo(); TF_ASSIGN_OR_RETURN(auto launch_dimensions, fusion_analysis.GetLaunchDimensions()); From 6e153325b66330dafea4e4e8b67b5d56b1a37852 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Fri, 14 Jul 2023 04:42:31 -0700 Subject: [PATCH 312/376] [XLA:GPU] Handle edge case in Triton Softmax rewriter where bitcast produces a scalar. This avoids crashing within last_dimension when attempting to match. PiperOrigin-RevId: 548090995 --- .../compiler/xla/service/gpu/softmax_rewriter_triton.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tensorflow/compiler/xla/service/gpu/softmax_rewriter_triton.cc b/tensorflow/compiler/xla/service/gpu/softmax_rewriter_triton.cc index efaa9c4a811b30..0c6bc5c8997182 100644 --- a/tensorflow/compiler/xla/service/gpu/softmax_rewriter_triton.cc +++ b/tensorflow/compiler/xla/service/gpu/softmax_rewriter_triton.cc @@ -69,6 +69,10 @@ bool TrivialEdge(HloInstruction** producer, HloInstruction* consumer, bool BitcastIsTilingNoop(HloInstruction* bitcast) { CHECK_EQ(bitcast->opcode(), HloOpcode::kBitcast); + if (bitcast->shape().rank() == 0) { + return true; + } + // In the Softmax rewriter for now, tiling is derived from a hero reduction // operation, which should be reducing its input on the last axis. Therefore, // a bitcast is always a no-op with regards to a tile if From 391b868f9012794f45e69400c6d7ad75c0f9eb29 Mon Sep 17 00:00:00 2001 From: Aliia Khasanova Date: Fri, 14 Jul 2023 05:14:17 -0700 Subject: [PATCH 313/376] Forward declare AsyncBundleTypeStorage for mhlo/IR/hlo_ops_typedefs.h.inc. PiperOrigin-RevId: 548096999 --- .../compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.cc | 15 +++++---------- .../compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.h | 15 +++++++++------ 2 files changed, 14 insertions(+), 16 deletions(-) diff --git a/tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.cc index 645366eecc9158..b45ee1443836e9 100644 --- a/tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.cc +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.cc @@ -100,9 +100,7 @@ using mlir::hlo::printDimSizes; #define GET_TYPEDEF_CLASSES #include "mhlo/IR/hlo_ops_typedefs.cc.inc" -namespace mlir { -namespace mhlo { - +namespace mlir::mhlo { namespace detail { /// A type representing a collection of other types. struct AsyncBundleTypeStorage final @@ -6279,8 +6277,7 @@ LogicalResult UniformDequantizeOp::inferReturnTypeComponents( using mlir::hlo::parseWindowAttributes; using mlir::hlo::printWindowAttributes; -} // namespace mhlo -} // namespace mlir +} // namespace mlir::mhlo using mlir::hlo::parseComplexOpType; using mlir::hlo::parseCustomCallTarget; @@ -6302,8 +6299,7 @@ using mlir::hlo::printVariadicSameOperandsAndResultType; #define GET_OP_CLASSES #include "mhlo/IR/hlo_ops.cc.inc" -namespace mlir { -namespace mhlo { +namespace mlir::mhlo { //===----------------------------------------------------------------------===// // mhlo Dialect Interfaces @@ -6344,7 +6340,7 @@ struct MhloHloDialectInterface : public hlo::HloDialectInterface { return TypeExtensionsAttr::get(getDialect()->getContext(), bounds); } }; -} // end anonymous namespace +} // namespace //===----------------------------------------------------------------------===// // mhlo Dialect Constructor @@ -7365,5 +7361,4 @@ LogicalResult MhloDialect::verifyOperationAttribute(Operation* op, return success(); } -} // namespace mhlo -} // namespace mlir +} // namespace mlir::mhlo diff --git a/tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.h b/tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.h index 9b54a8494a8309..4a5483a91c6d87 100644 --- a/tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.h +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.h @@ -36,6 +36,11 @@ limitations under the License. #include "mlir/Interfaces/SideEffectInterfaces.h" #include "stablehlo/dialect/Base.h" +// Forward declaration for hlo_ops_typedefs.h.inc. +namespace mlir::mhlo::detail { +struct AsyncBundleTypeStorage; +} // namespace mlir::mhlo::detail + // Include order below matters. #include "mhlo/IR/hlo_ops_enums.h.inc" #define GET_ATTRDEF_CLASSES @@ -92,21 +97,19 @@ void printConvolutionDimensions(AsmPrinter &p, Operation *, ParseResult parseConvolutionDimensions(AsmParser &parser, ConvDimensionNumbersAttr &dnums); -} // end namespace mhlo -} // end namespace mlir +} // namespace mhlo +} // namespace mlir #define GET_OP_CLASSES #include "mhlo/IR/hlo_ops.h.inc" -namespace mlir { -namespace mhlo { +namespace mlir::mhlo { SortOp createSortOp(PatternRewriter *rewriter, const Location &loc, const llvm::ArrayRef &operands, const llvm::ArrayRef &elementTypes, int64_t dimension, bool isStable, ComparisonDirection direction); -} // end namespace mhlo -} // end namespace mlir +} // namespace mlir::mhlo #endif // MLIR_HLO_MHLO_IR_HLO_OPS_H From 1b780c50100e7082d650178f79f0a16deded4162 Mon Sep 17 00:00:00 2001 From: Juanli Shen Date: Fri, 14 Jul 2023 08:15:07 -0700 Subject: [PATCH 314/376] Some minor cleanup PiperOrigin-RevId: 548130812 --- tensorflow/core/tfrt/fallback/BUILD | 3 +-- tensorflow/core/tfrt/fallback/cost_recorder.h | 8 ++------ tensorflow/core/tfrt/fallback/cost_recorder_test.cc | 3 +-- 3 files changed, 4 insertions(+), 10 deletions(-) diff --git a/tensorflow/core/tfrt/fallback/BUILD b/tensorflow/core/tfrt/fallback/BUILD index 765008be386279..84566c792ef35d 100644 --- a/tensorflow/core/tfrt/fallback/BUILD +++ b/tensorflow/core/tfrt/fallback/BUILD @@ -126,9 +126,8 @@ tf_cc_test( srcs = ["cost_recorder_test.cc"], deps = [ ":cost_recorder", + ":op_cost_map_proto_cc", "//tensorflow/core:lib", - "//tensorflow/core/platform:status", - "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", ], ) diff --git a/tensorflow/core/tfrt/fallback/cost_recorder.h b/tensorflow/core/tfrt/fallback/cost_recorder.h index 9929ca76e1028c..b275e0a5d35791 100644 --- a/tensorflow/core/tfrt/fallback/cost_recorder.h +++ b/tensorflow/core/tfrt/fallback/cost_recorder.h @@ -19,14 +19,12 @@ limitations under the License. #define TENSORFLOW_CORE_TFRT_FALLBACK_COST_RECORDER_H_ #include -#include #include #include "absl/container/flat_hash_map.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/thread_annotations.h" -#include "tensorflow/core/tfrt/fallback/op_cost_map.pb.h" namespace tensorflow { namespace tfrt_stub { @@ -45,8 +43,7 @@ class CostRecorder { // Returns the normalized average execution duration of the op keyed by // `op_key`. If there is no record for `op_key`, returns the uint32_t::max to // avoid stream merging. Note that we don't use uint64_t::max because - // otherwise adding op costs would cause overflow. (See details in - // go/tfrt-stream-analysis-doc.) + // otherwise adding op costs would cause overflow. uint64_t GetCost(int64_t op_key) const; // Writes the op cost map (in format of `OpCostMapProto`) to a file specified @@ -65,8 +62,7 @@ class CostRecorder { uint64_t normalize_ratio_; mutable tensorflow::mutex op_cost_map_mutex_; - // Map op key to {sum of op execution duration in nanoseconds, #occurences of - // the op}. + // Map op key to {sum of op execution duration, #occurences of the op}. absl::flat_hash_map> op_cost_map_ TF_GUARDED_BY(op_cost_map_mutex_); }; diff --git a/tensorflow/core/tfrt/fallback/cost_recorder_test.cc b/tensorflow/core/tfrt/fallback/cost_recorder_test.cc index 6f5d6486eb1c82..827259c0990fb6 100644 --- a/tensorflow/core/tfrt/fallback/cost_recorder_test.cc +++ b/tensorflow/core/tfrt/fallback/cost_recorder_test.cc @@ -19,9 +19,8 @@ limitations under the License. #include #include -#include "absl/strings/string_view.h" #include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/tfrt/fallback/op_cost_map.pb.h" namespace tensorflow { namespace tfrt_stub { From 78f425b7f30a62810f3b7e210e5c7b6c76e34665 Mon Sep 17 00:00:00 2001 From: George Necula Date: Fri, 14 Jul 2023 08:40:02 -0700 Subject: [PATCH 315/376] Separates ValidateStaticShapes from RefineDynamicShapes. In a recent change we have merged ValidateStaticShapes into RefineDynamicShapes. This has the disadvantage that we cannot perform partial shape refinement. In this change we separate ValidatStaticShapes. PiperOrigin-RevId: 548135749 --- .../compiler/tf2xla/kernels/xla_call_module_loader.cc | 4 ++++ .../compiler/tf2xla/kernels/xla_call_module_loader.h | 5 +++++ .../compiler/tf2xla/kernels/xla_call_module_op.cc | 1 + tensorflow/compiler/xla/python/mlir.cc | 11 +++++++---- .../compiler/xla/python/refine_polymorphic_shapes.cc | 7 ++++--- .../compiler/xla/python/refine_polymorphic_shapes.h | 5 ++++- tensorflow/compiler/xla/python/xla_client.py | 2 +- tensorflow/compiler/xla/python/xla_extension/mlir.pyi | 3 ++- 8 files changed, 28 insertions(+), 10 deletions(-) diff --git a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc index 98808caa60d551..e7def0a7f8bab4 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc @@ -554,6 +554,10 @@ tsl::Status XlaCallModuleLoader::ValidateDialect() { return tsl::OkStatus(); } +absl::Status XlaCallModuleLoader::ValidateStaticShapes() { + return xla::ValidateStaticShapes(*module_); +} + absl::Status XlaCallModuleLoader::LowerModuleToMhlo() { mlir::StatusScopedDiagnosticHandler diag_handler(module_->getContext()); diff --git a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.h b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.h index 2eab7d25faa2ee..8d8c30f96fbca8 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.h +++ b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.h @@ -58,6 +58,11 @@ class XlaCallModuleLoader { // Validates that the module only contains ops from valid dialects. tsl::Status ValidateDialect(); + // Validates that the module represents a statically-shaped StableHLO program, + // otherwise all sorts of weirdness might happen in the HLO exporter which is + // much easier to detect here. + absl::Status ValidateStaticShapes(); + // Lowers the StableHLO module to MHLO in place. absl::Status LowerModuleToMhlo(); diff --git a/tensorflow/compiler/tf2xla/kernels/xla_call_module_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_call_module_op.cc index cd1a09de334223..5445c93c7be650 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_call_module_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_call_module_op.cc @@ -217,6 +217,7 @@ class XlaCallModuleOp : public XlaOpKernel { input_shapes.push_back(*std::move(shape)); } OP_REQUIRES_OK(ctx, loader_->RefineDynamicShapes(input_shapes)); + OP_REQUIRES_OK(ctx, loader_->ValidateStaticShapes()); OP_REQUIRES_OK(ctx, loader_->LowerModuleToMhlo()); if (!function_list_.empty()) { OP_REQUIRES_OK(ctx, LowerTfFunctionCalls(ctx)); diff --git a/tensorflow/compiler/xla/python/mlir.cc b/tensorflow/compiler/xla/python/mlir.cc index 4c6526aa51cb60..12e2f13b7a234d 100644 --- a/tensorflow/compiler/xla/python/mlir.cc +++ b/tensorflow/compiler/xla/python/mlir.cc @@ -233,18 +233,21 @@ void BuildMlirSubmodule(py::module& m) { py::arg("mlir_module")); mlir_module.def( "refine_polymorphic_shapes", - [](std::string mlir_module, bool enable_shape_assertions) -> py::bytes { + [](std::string mlir_module, bool enable_shape_assertions, + bool validate_static_shapes) -> py::bytes { std::string buffer; llvm::raw_string_ostream os(buffer); - xla::ThrowIfError( - RefinePolymorphicShapes(mlir_module, os, enable_shape_assertions)); + xla::ThrowIfError(RefinePolymorphicShapes( + mlir_module, os, enable_shape_assertions, validate_static_shapes)); return py::bytes(buffer); }, py::arg("mlir_module"), py::arg("enable_shape_assertions") = true, + py::arg("validate_static_shapes") = true, R"(Refines the dynamic shapes for a module. The "main" function must have static shapes and all the intermediate dynamic shapes depend only on the input static - shapes. + shapes. Optionally, also validates that the resulting module has + only static shapes. )"); } diff --git a/tensorflow/compiler/xla/python/refine_polymorphic_shapes.cc b/tensorflow/compiler/xla/python/refine_polymorphic_shapes.cc index 5381caa8d5dffe..01a09c1c244909 100644 --- a/tensorflow/compiler/xla/python/refine_polymorphic_shapes.cc +++ b/tensorflow/compiler/xla/python/refine_polymorphic_shapes.cc @@ -271,12 +271,13 @@ absl::Status RefinePolymorphicShapes(mlir::ModuleOp module, absl::StrCat("Module shape refinement failed: ", diag_handler.ConsumeStatus().ToString())); } - return ValidateStaticShapes(module); + return absl::OkStatus(); } absl::Status RefinePolymorphicShapes(llvm::StringRef module_str, llvm::raw_ostream &os, - bool enable_shape_assertions) { + bool enable_shape_assertions, + bool validate_static_shapes) { mlir::MLIRContext context; if (VLOG_IS_ON(3)) context.disableMultithreading(); context.loadDialect(); @@ -294,10 +295,10 @@ absl::Status RefinePolymorphicShapes(llvm::StringRef module_str, return absl::InvalidArgumentError("Cannot parse module."); } TF_RETURN_IF_ERROR(RefinePolymorphicShapes(*module, enable_shape_assertions)); + if (validate_static_shapes) TF_RETURN_IF_ERROR(ValidateStaticShapes(*module)); if (mlir::failed(mlir::writeBytecodeToFile(*module, os))) { return absl::InternalError("Cannot serialize module."); } - return absl::OkStatus(); } diff --git a/tensorflow/compiler/xla/python/refine_polymorphic_shapes.h b/tensorflow/compiler/xla/python/refine_polymorphic_shapes.h index 726ce3aec07249..75237aeff01e89 100644 --- a/tensorflow/compiler/xla/python/refine_polymorphic_shapes.h +++ b/tensorflow/compiler/xla/python/refine_polymorphic_shapes.h @@ -33,9 +33,12 @@ absl::Status RefinePolymorphicShapes(mlir::ModuleOp module, bool enable_shape_assertions); // Like the above but with serialized input and output modules. +// If `validate_static_shapes` is true, then checks that only static shapes +// are left after refinement. absl::Status RefinePolymorphicShapes(llvm::StringRef module_str, llvm::raw_ostream &os, - bool enable_shape_assertions); + bool enable_shape_assertions, + bool validate_static_shapes); // Validates that the module has only static shapes. absl::Status ValidateStaticShapes(mlir::ModuleOp module); diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index c52fa3c3f9d74f..ca088373b92bef 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -47,7 +47,7 @@ _version = 167 # Version number for MLIR:Python components. -mlir_api_version = 52 +mlir_api_version = 53 xla_platform_names = { 'cpu': 'Host', diff --git a/tensorflow/compiler/xla/python/xla_extension/mlir.pyi b/tensorflow/compiler/xla/python/xla_extension/mlir.pyi index e3f073b1defbc9..f62c6565dc0a41 100644 --- a/tensorflow/compiler/xla/python/xla_extension/mlir.pyi +++ b/tensorflow/compiler/xla/python/xla_extension/mlir.pyi @@ -25,4 +25,5 @@ def stablehlo_to_mhlo(mlir_module: Union[bytes, str]) -> str: ... def serialize_portable_artifact(mlir_module: str, target:str) -> bytes: ... def deserialize_portable_artifact(mlir_module: bytes) -> str: ... def refine_polymorphic_shapes(mlir_module: Union[bytes, str], - enable_shape_assertions: bool = ...) -> bytes: ... + enable_shape_assertions: bool = ..., + validate_static_shapes: bool = ...) -> bytes: ... From 9743226b198993255ffddb367af71c6a2f51b656 Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Fri, 14 Jul 2023 08:42:57 -0700 Subject: [PATCH 316/376] [XLA:GPU] Unify two fusion traversals in the Triton GEMM rewriter. Two separate traversals from dot() with approximately same logic were used to first identify which dots should use Triton GEMM and then to do the actual fusion. As the logic of traversal becomes more complicated it's better to make them have the same implementation. PiperOrigin-RevId: 548136358 --- .../xla/service/gpu/gemm_rewriter_triton.cc | 302 +++++++++--------- .../xla/service/gpu/gemm_rewriter_triton.h | 2 +- 2 files changed, 145 insertions(+), 159 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.cc b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.cc index 102e17bde7dd1b..d1acb13ce8d759 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.cc @@ -694,7 +694,7 @@ FusionDecision CanFuse(const HloInstruction& hlo, bool as_input, void Fuse(HloInstruction& hlo, absl::flat_hash_map& old_to_new_mapping, - std::vector& call_operands, + std::vector& fusion_inputs, HloComputation::Builder& builder) { if (old_to_new_mapping.contains(&hlo)) { return; @@ -705,12 +705,12 @@ void Fuse(HloInstruction& hlo, it != old_to_new_mapping.end()) { return it->second; } - call_operands.push_back(&instr); + fusion_inputs.push_back(&instr); return old_to_new_mapping .insert({&instr, builder.AddInstruction(HloInstruction::CreateParameter( - call_operands.size() - 1, instr.shape(), - absl::StrCat("parameter_", call_operands.size() - 1)))}) + fusion_inputs.size() - 1, instr.shape(), + absl::StrCat("parameter_", fusion_inputs.size() - 1)))}) .first->second; }; if (hlo.opcode() == HloOpcode::kParameter || @@ -748,7 +748,7 @@ void FuseWithInputsRecursively( const GpuVersion gpu_version, absl::flat_hash_map& old_to_new_mapping, - std::vector& call_operands, + std::vector& fusion_inputs, HloComputation::Builder& builder) { absl::flat_hash_set visited; std::stack to_fuse; @@ -805,12 +805,131 @@ void FuseWithInputsRecursively( } } if (top_is_ready_to_fuse) { - Fuse(*hlo, old_to_new_mapping, call_operands, builder); + Fuse(*hlo, old_to_new_mapping, fusion_inputs, builder); to_fuse.pop(); } } } +// Fuses dot and the compatible and profitable to fuse operations around it +// into a new fusion computation constructed using the builder. fusion_inputs +// get populated with the non-fused instructions that become operands of the +// call to this fusion. fusion_output_ptr (if not nullptr) gets assigned the +// original instruction that has to be replaced by the call to the fusion. +StatusOr FuseDot(HloInstruction& dot, + const GpuVersion gpu_version, + HloComputation::Builder& builder, + std::vector& fusion_inputs, + HloInstruction** fusion_output_ptr) { + VLOG(5) << dot.ToString(); + if (FusionDecision can_handle = CanTritonHandleGEMM(dot, gpu_version); + !can_handle) { + VLOG(3) << can_handle.Explain(); + return can_handle; + } + + // Original instruction -> fused one. + absl::flat_hash_map + old_to_new_mapping; + + // Separate traversal from LHS and RHS inputs of the dot: they use + // differently shaped tiles but may go through same HLO graph nodes. + // Direct dot inputs have well defined dimension orders. + + auto fuse_inputs = [&](int operand_number) + -> StatusOr> { + absl::flat_hash_map dim_orders; + int operand_count_before = fusion_inputs.size(); + // Direct dot inputs have well defined dimension orders. + FuseWithInputsRecursively( + dot.mutable_operand(operand_number), + DimensionOrder::FromDotOperand(dot, operand_number), dim_orders, + gpu_version, old_to_new_mapping, fusion_inputs, builder); + TF_RET_CHECK(fusion_inputs.size() - operand_count_before <= + DotFusionAnalysis::kMaxParameterPerScope); + return dim_orders; + }; + // Check if non-contracting dimension originating from LHS operand in the + // output can be split. This currently requires this dimension being split + // in the operand the same way. + int64_t lhs_nc_split_major_part = -1; + { + TF_ASSIGN_OR_RETURN(const auto lhs_dim_orders, fuse_inputs(0)); + // Looking at first LHS parameter to find split non-contracting dimension + // is sufficient because currently all parameters of one scope have to use + // the same tiling. + auto first_lhs_parameter_it = lhs_dim_orders.cbegin(); + while (first_lhs_parameter_it != lhs_dim_orders.cend()) { + if (first_lhs_parameter_it->first->opcode() == HloOpcode::kParameter) { + break; + } + ++first_lhs_parameter_it; + } + if (first_lhs_parameter_it != lhs_dim_orders.cend()) { + const auto lhs_nc_iter_spec = DimensionOrderToTensorIterationSpec( + first_lhs_parameter_it->second)[NonContractingDimensionIndex(dot, 0)]; + if (lhs_nc_iter_spec.size() > 1) { + lhs_nc_split_major_part = lhs_nc_iter_spec.at(1).count; + } + } + } + TF_RET_CHECK(fuse_inputs(1).ok()); + + Fuse(dot, old_to_new_mapping, fusion_inputs, builder); + + // Fusion at dot's output. + + // These describe _outputs_ of corresponding HLOs. + absl::flat_hash_map out_dim_orders; + out_dim_orders.insert( + {&dot, DimensionOrder::FromDotOutput(dot, /*split_k=*/1, + lhs_nc_split_major_part)}); + HloInstruction* fusion_output = ˙ + bool output_changed = true; + while (output_changed) { + output_changed = false; + if (fusion_output->user_count() != 1) { + break; + } + HloInstruction* user = fusion_output->users()[0]; + if (!IsDistributiveOverAddition(*user)) { + break; + } + // Describes the output of `current_output` = input of `user`. + DimensionOrder dim_order(out_dim_orders.at(fusion_output)); + if (CanFuse(*user, /*as_input=*/false, dim_order, old_to_new_mapping, + gpu_version)) { + // Now it describes the output of the user. + CHECK(out_dim_orders.insert({user, dim_order}).second); + for (HloInstruction* operand : user->operands()) { + if (!old_to_new_mapping.contains(operand)) { + // Here we need again a dim order describing inputs of the user. + FuseWithInputsRecursively( + operand, DimensionOrder(out_dim_orders.at(fusion_output)), + out_dim_orders, gpu_version, old_to_new_mapping, fusion_inputs, + builder); + } + } + Fuse(*user, old_to_new_mapping, fusion_inputs, builder); + fusion_output = user; + output_changed = true; + } + } + if (fusion_output_ptr != nullptr) { + *fusion_output_ptr = fusion_output; + } + if (dot.GetModule()->config().debug_options().xla_gpu_triton_gemm_any()) { + return FusionDecision{}; + } + for (const auto& iter : old_to_new_mapping) { + if (iter.second->opcode() == HloOpcode::kConvert || + iter.second->opcode() == HloOpcode::kTranspose) { + return FusionDecision{}; + } + } + return "No profitable operations to fuse."; +} + // Extracts into fused computations parts of HLO graph including dot() // operations that can target the triton GEMM emitter. class GemmRewriterTritonVisitor : public DfsHloRewriteVisitor { @@ -821,126 +940,34 @@ class GemmRewriterTritonVisitor : public DfsHloRewriteVisitor { // if so - fuses all its compatible inputs and outputs as a new computation // and replaces the original dot() with a call to the computation. Status HandleDot(HloInstruction* dot) override { - VLOG(5) << dot->ToString(); - FusionDecision can_handle = CanTritonHandleGEMM(*dot, gpu_version_); - if (!can_handle) { - VLOG(3) << can_handle.Explain(); + std::string fusion_name = absl::StrCat("triton_gemm_", dot->name()); + HloComputation::Builder builder(absl::StrCat(fusion_name, "_computation")); + std::vector fusion_inputs; + HloInstruction* fusion_output = nullptr; + TF_ASSIGN_OR_RETURN( + const FusionDecision should_fuse, + FuseDot(*dot, gpu_version_, builder, fusion_inputs, &fusion_output)); + if (builder.last_added_instruction() == nullptr) { return OkStatus(); } - // If a GEMM requiring padding for cuBLAS is encountered here this // happened because earlier ShouldTritonHandleGEMM() accepted it and padding - // was skipped. Do not check ShouldTritonHandleGEMM() again then. + // was skipped. Accept it ignoring profitability checks. if (!CublasRequiresPadding( - *xla::Cast(dot), + *Cast(dot), std::get(gpu_version_)) && - !ShouldTritonHandleGEMM(*dot, gpu_version_)) { + !should_fuse) { return OkStatus(); } - std::string suggested_name = absl::StrCat("triton_gemm_", dot->name()); - HloComputation::Builder builder( - absl::StrCat(suggested_name, "_computation")); - std::vector call_operands; - // Original instruction -> fused one. - absl::flat_hash_map - old_to_new_mapping; - - // Separate traversal from LHS and RHS inputs of the dot: they use - // differently shaped tiles but may go through same HLO graph nodes. - // Direct dot inputs have well defined dimension orders. - - auto fuse_inputs = [&](int operand_number) - -> StatusOr< - absl::flat_hash_map> { - absl::flat_hash_map dim_orders; - int operand_count_before = call_operands.size(); - // Direct dot inputs have well defined dimension orders. - FuseWithInputsRecursively( - dot->mutable_operand(operand_number), - DimensionOrder::FromDotOperand(*dot, operand_number), dim_orders, - gpu_version_, old_to_new_mapping, call_operands, builder); - TF_RET_CHECK(call_operands.size() - operand_count_before <= - DotFusionAnalysis::kMaxParameterPerScope); - return dim_orders; - }; - // Check if non-contracting dimension originating from LHS operand in the - // output can be split. This currently requires this dimension being split - // in the operand the same way. - int64_t lhs_nc_split_major_part = -1; - { - TF_ASSIGN_OR_RETURN(const auto lhs_dim_orders, fuse_inputs(0)); - // Looking at first LHS parameter to find split non-contracting dimension - // is sufficient because currently all parameters of one scope have to use - // the same tiling. - auto first_lhs_parameter_it = lhs_dim_orders.cbegin(); - while (first_lhs_parameter_it != lhs_dim_orders.cend()) { - if (first_lhs_parameter_it->first->opcode() == HloOpcode::kParameter) { - break; - } - ++first_lhs_parameter_it; - } - if (first_lhs_parameter_it != lhs_dim_orders.cend()) { - const auto lhs_nc_iter_spec = DimensionOrderToTensorIterationSpec( - first_lhs_parameter_it - ->second)[NonContractingDimensionIndex(*dot, 0)]; - if (lhs_nc_iter_spec.size() > 1) { - lhs_nc_split_major_part = lhs_nc_iter_spec.at(1).count; - } - } - } - TF_RET_CHECK(fuse_inputs(1).ok()); - - Fuse(*dot, old_to_new_mapping, call_operands, builder); - - // Fusion at dot's output. - - // These describe _outputs_ of corresponding HLOs. - absl::flat_hash_map out_dim_orders; - out_dim_orders.insert( - {dot, DimensionOrder::FromDotOutput(*dot, /*split_k=*/1, - lhs_nc_split_major_part)}); - HloInstruction* fusion_output = dot; - bool output_changed = true; - while (output_changed) { - output_changed = false; - if (fusion_output->user_count() != 1) { - break; - } - HloInstruction* user = fusion_output->users()[0]; - if (!IsDistributiveOverAddition(*user)) { - break; - } - // Describes the output of `current_output` = input of `user`. - DimensionOrder dim_order(out_dim_orders.at(fusion_output)); - if (CanFuse(*user, /*as_input=*/false, dim_order, old_to_new_mapping, - gpu_version_)) { - // Now it describes the output of the user. - CHECK(out_dim_orders.insert({user, dim_order}).second); - for (HloInstruction* operand : user->operands()) { - if (!old_to_new_mapping.contains(operand)) { - // Here we need again a dim order describing inputs of the user. - FuseWithInputsRecursively( - operand, DimensionOrder(out_dim_orders.at(fusion_output)), - out_dim_orders, gpu_version_, old_to_new_mapping, call_operands, - builder); - } - } - Fuse(*user, old_to_new_mapping, call_operands, builder); - fusion_output = user; - output_changed = true; - } - } - HloComputation* computation = dot->GetModule()->AddComputationAndUnifyNamesAndIds(builder.Build(), /*is_entry=*/false); HloInstruction* dot_fusion = dot->parent()->AddInstruction(HloInstruction::CreateFusion( computation->root_instruction()->shape(), - HloInstruction::FusionKind::kCustom, call_operands, computation)); - dot_fusion->GetModule()->SetAndUniquifyInstrName(dot_fusion, - suggested_name); + HloInstruction::FusionKind::kCustom, fusion_inputs, computation)); + dot_fusion->GetModule()->SetAndUniquifyInstrName(dot_fusion, fusion_name); TF_ASSIGN_OR_RETURN(auto backend_config, dot_fusion->backend_config()); @@ -1425,53 +1452,12 @@ FusionDecision CanTritonHandleGEMM(const HloInstruction& dot, return FusionDecision{}; } -bool ShouldTritonHandleGEMM(const HloInstruction& dot, - const GpuVersion gpu_version) { - if (dot.GetModule()->config().debug_options().xla_gpu_triton_gemm_any()) { - return true; - } - - // Data-narrowing conversion after the dot is profitable to fuse. - if (dot.user_count() == 1 && - dot.users()[0]->opcode() == HloOpcode::kConvert && - IsSupportedDataType(dot.users()[0]->shape().element_type(), - gpu_version) && - InputMinusOutputBytes(*dot.users()[0]) > -kIoToleranceBytes) { - return true; - } - - // Traverse HLO graph part checking that it both can be fused - // and is worth fusing. - auto has_triton_fusible_inputs = [&gpu_version](const HloInstruction& dot, - const int operand_number) { - absl::flat_hash_map - old_to_new_mapping; - DimensionOrder dim_order = - DimensionOrder::FromDotOperand(dot, operand_number); - std::queue queue; - queue.push(dot.operand(operand_number)); - while (!queue.empty()) { - const HloInstruction* current = queue.front(); - queue.pop(); - if (!CanFuse(*current, /*as_input=*/true, dim_order, old_to_new_mapping, - gpu_version)) { - continue; - } - // The values in the map are not used by CanFuse(). - old_to_new_mapping.insert({current, nullptr}); - // Stop as soon as a profitable operation is fused. - if (current->opcode() == HloOpcode::kConvert || - current->opcode() == HloOpcode::kTranspose) { - return true; - } - for (const HloInstruction* operand : current->operands()) { - queue.push(operand); - } - } - return false; - }; - - return has_triton_fusible_inputs(dot, 0) || has_triton_fusible_inputs(dot, 1); +bool ShouldTritonHandleGEMM(HloInstruction& dot, const GpuVersion gpu_version) { + std::vector fusion_inputs; + HloComputation::Builder builder("disposable"); + return FuseDot(dot, gpu_version, builder, fusion_inputs, + /*fusion_output_ptr=*/nullptr) + ->CanFuse(); } StatusOr GemmRewriterTriton::Run( diff --git a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.h b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.h index 0afc939b43ede2..6619d16196d1b3 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.h +++ b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.h @@ -57,7 +57,7 @@ FusionDecision CanTritonHandleGEMM(const HloInstruction&, GpuVersion gpu_version); // Filters GEMMs which are better to handle using Triton. -bool ShouldTritonHandleGEMM(const HloInstruction&, GpuVersion gpu_version); +bool ShouldTritonHandleGEMM(HloInstruction&, GpuVersion gpu_version); class TensorIterationSpec { public: From 257a5fc02c8e76108d60dc288499b0f70299a0bb Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 14 Jul 2023 09:00:15 -0700 Subject: [PATCH 317/376] [PJRT C API] Move PlatformName test to the test factory (pjrt_c_api_test.cc). PiperOrigin-RevId: 548140094 --- .../xla/pjrt/c/pjrt_c_api_cpu_test.cc | 16 ++------- .../xla/pjrt/c/pjrt_c_api_gpu_test.cc | 16 ++------- .../compiler/xla/pjrt/c/pjrt_c_api_test.cc | 36 +++++++++++++++++-- .../compiler/xla/pjrt/c/pjrt_c_api_test.h | 4 ++- 4 files changed, 42 insertions(+), 30 deletions(-) diff --git a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_cpu_test.cc b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_cpu_test.cc index 6bb8f9fd37cf11..84d3629031f85a 100644 --- a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_cpu_test.cc +++ b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_cpu_test.cc @@ -24,8 +24,9 @@ namespace xla { namespace pjrt { namespace { -const bool kUnused = - (RegisterPjRtCApiTestFactory([]() { return GetPjrtApi(); }), true); +const bool kUnused = (RegisterPjRtCApiTestFactory([]() { return GetPjrtApi(); }, + /*platform_name=*/"cpu"), + true); class PjrtCApiCpuTest : public ::testing::Test { protected: @@ -66,17 +67,6 @@ class PjrtCApiCpuTest : public ::testing::Test { } }; -TEST_F(PjrtCApiCpuTest, PlatformName) { - PJRT_Client_PlatformName_Args args; - args.client = client_; - args.struct_size = PJRT_Client_PlatformName_Args_STRUCT_SIZE; - args.priv = nullptr; - PJRT_Error* error = api_->PJRT_Client_PlatformName(&args); - ASSERT_EQ(error, nullptr); - absl::string_view platform_name(args.platform_name, args.platform_name_size); - ASSERT_EQ("cpu", platform_name); -} - } // namespace } // namespace pjrt } // namespace xla diff --git a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_gpu_test.cc b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_gpu_test.cc index 285869bede2d28..d1e7ef9ab96b31 100644 --- a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_gpu_test.cc +++ b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_gpu_test.cc @@ -32,8 +32,9 @@ namespace xla { namespace pjrt { namespace { -const bool kUnused = - (RegisterPjRtCApiTestFactory([]() { return GetPjrtApi(); }), true); +const bool kUnused = (RegisterPjRtCApiTestFactory([]() { return GetPjrtApi(); }, + /*platform_name=*/"gpu"), + true); class PjrtCApiGpuTest : public ::testing::Test { protected: @@ -74,17 +75,6 @@ class PjrtCApiGpuTest : public ::testing::Test { } }; -TEST_F(PjrtCApiGpuTest, PlatformName) { - PJRT_Client_PlatformName_Args args; - args.client = client_; - args.struct_size = PJRT_Client_PlatformName_Args_STRUCT_SIZE; - args.priv = nullptr; - PJRT_Error* error = api_->PJRT_Client_PlatformName(&args); - ASSERT_EQ(error, nullptr); - absl::string_view platform_name(args.platform_name, args.platform_name_size); - ASSERT_EQ("gpu", platform_name); -} - std::unique_ptr<::pjrt::PJRT_KeyValueCallbackData> CreateTestCKVCallback( absl::flat_hash_map* kv_store, absl::Mutex& mu) { PjRtClient::KeyValueGetCallback kv_get = diff --git a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_test.cc b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_test.cc index 4bde253fee9a68..4daf5669954bfe 100644 --- a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_test.cc +++ b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_test.cc @@ -74,10 +74,14 @@ ENTRY %TupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f32[3], f class TestCApiFactory { public: - void Register(std::function factory) { + void Register(std::function factory, + absl::string_view platform_name) { absl::MutexLock lock(&mu_); CHECK(!factory_); factory_ = std::move(factory); + CHECK(platform_name_.empty()) << "Platform name already provided"; + CHECK(!platform_name.empty()) << "Provided platform name is empty"; + platform_name_ = platform_name; } std::function Get() const { @@ -86,9 +90,17 @@ class TestCApiFactory { return factory_; } + std::string GetPlatformName() const { + absl::MutexLock lock(&mu_); + CHECK(!platform_name_.empty()) + << "Test didn't call RegisterPjRtCApiTestFactory()"; + return platform_name_; + } + private: mutable absl::Mutex mu_; std::function factory_ ABSL_GUARDED_BY(mu_); + std::string platform_name_; }; TestCApiFactory& GetGlobalTestCApiFactory() { @@ -98,10 +110,15 @@ TestCApiFactory& GetGlobalTestCApiFactory() { const PJRT_Api* GetCApi() { return GetGlobalTestCApiFactory().Get()(); } +std::string GetPlatformName() { + return GetGlobalTestCApiFactory().GetPlatformName(); +} + } // namespace -void RegisterPjRtCApiTestFactory(std::function factory) { - GetGlobalTestCApiFactory().Register(std::move(factory)); +void RegisterPjRtCApiTestFactory(std::function factory, + absl::string_view platform_name) { + GetGlobalTestCApiFactory().Register(std::move(factory), platform_name); } namespace { @@ -109,6 +126,7 @@ class PjrtCApiTest : public ::testing::Test { protected: const PJRT_Api* api_; PJRT_Client* client_; + std::string platform_name_; // We directly access the internal C++ client to test if the C API has the // same behavior as the C++ API. xla::PjRtClient* cc_client_; @@ -117,6 +135,7 @@ class PjrtCApiTest : public ::testing::Test { void SetUp() override { api_ = GetCApi(); client_ = make_client(); + platform_name_ = GetPlatformName(); } void TearDown() override { destroy_client(client_); } @@ -384,6 +403,17 @@ TEST_F(PjrtCApiTest, ApiVersion) { // ---------------------------------- Client ----------------------------------- +TEST_F(PjrtCApiTest, PlatformName) { + PJRT_Client_PlatformName_Args args; + args.client = client_; + args.struct_size = PJRT_Client_PlatformName_Args_STRUCT_SIZE; + args.priv = nullptr; + PJRT_Error* error = api_->PJRT_Client_PlatformName(&args); + ASSERT_EQ(error, nullptr); + absl::string_view platform_name(args.platform_name, args.platform_name_size); + ASSERT_EQ(platform_name_, platform_name); +} + TEST_F(PjrtCApiTest, ClientProcessIndex) { PJRT_Client_ProcessIndex_Args process_index_args = PJRT_Client_ProcessIndex_Args{ diff --git a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_test.h b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_test.h index a2f4b7c1334652..742bad437d7b4d 100644 --- a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_test.h +++ b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_test.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h" namespace xla { @@ -28,7 +29,8 @@ namespace pjrt { // all the tests in this test factory with the PJRT_Api generated by the input // to RegisterPjRtCApiTestFactory. See // tensorflow/compiler/xla/pjrt/c/pjrt_c_api_cpu_test.cc for an example usage -void RegisterPjRtCApiTestFactory(std::function factory); +void RegisterPjRtCApiTestFactory(std::function factory, + absl::string_view platform_name); } // namespace pjrt } // namespace xla From a003528c19a5c2d5e9ec4bd9ac4ddccfa58d8d28 Mon Sep 17 00:00:00 2001 From: Juan Martinez Castellanos Date: Fri, 14 Jul 2023 10:07:39 -0700 Subject: [PATCH 318/376] Make all Python targets under tensorflow/compiler/mlir/tfrt/jit/python_binding/**/ have strict dependencies. PiperOrigin-RevId: 548155762 --- tensorflow/compiler/mlir/stablehlo/BUILD | 7 ++++--- tensorflow/compiler/mlir/tensorflow/BUILD | 2 ++ .../compiler/mlir/tfrt/jit/python_binding/BUILD | 15 ++++++--------- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/tensorflow/compiler/mlir/stablehlo/BUILD b/tensorflow/compiler/mlir/stablehlo/BUILD index 162572e5d7c298..ddea76059bfdcd 100644 --- a/tensorflow/compiler/mlir/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/stablehlo/BUILD @@ -1,4 +1,5 @@ -load("//tensorflow:pytype.default.bzl", "pytype_library") +load("//tensorflow:strict.default.bzl", "py_strict_test") +load("//tensorflow:pytype.default.bzl", "pytype_strict_library") load("//tensorflow/tsl:tsl.default.bzl", "tsl_pybind_extension") package( @@ -36,7 +37,7 @@ tsl_pybind_extension( ], ) -pytype_library( +pytype_strict_library( name = "stablehlo", srcs = ["stablehlo.py"], srcs_version = "PY3", @@ -46,7 +47,7 @@ pytype_library( ], ) -py_test( +py_strict_test( name = "stablehlo_test", srcs = ["stablehlo_test.py"], python_version = "PY3", diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 207ec7970aa1c1..474ab34155e8eb 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -1,3 +1,4 @@ +load("//tensorflow:strict.default.bzl", "py_strict_library") load("//tensorflow:tensorflow.default.bzl", "filegroup", "get_compatible_with_portable") load("@bazel_skylib//rules:build_test.bzl", "build_test") load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") @@ -2329,6 +2330,7 @@ tf_gen_op_wrapper_py( name = "gen_mlir_passthrough_op_py", out = "gen_mlir_passthrough_op.py", compatible_with = [], + py_lib_rule = py_strict_library, deps = [":mlir_passthrough_op"], ) diff --git a/tensorflow/compiler/mlir/tfrt/jit/python_binding/BUILD b/tensorflow/compiler/mlir/tfrt/jit/python_binding/BUILD index 6314fb77c9052a..868f9814127bd2 100644 --- a/tensorflow/compiler/mlir/tfrt/jit/python_binding/BUILD +++ b/tensorflow/compiler/mlir/tfrt/jit/python_binding/BUILD @@ -1,22 +1,19 @@ load("//tensorflow:tensorflow.default.bzl", "pybind_extension", "pybind_library") -load("//tensorflow:strict.default.bzl", "py_strict_test") - -licenses(["notice"]) +load("//tensorflow:strict.default.bzl", "py_strict_library", "py_strict_test") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [":__subpackages__"], ) -py_library( +licenses(["notice"]) + +py_strict_library( name = "tf_jitrt", testonly = 1, srcs = ["tf_jitrt.py"], visibility = ["//tensorflow/compiler/mlir/tfrt:__subpackages__"], - deps = [ - ":_tf_jitrt_executor", - "//third_party/py/numpy", - ], + deps = [":_tf_jitrt_executor"], ) py_strict_test( @@ -59,7 +56,7 @@ pybind_extension( ], ) -py_library( +py_strict_library( name = "tfrt_fallback", testonly = True, srcs = ["tfrt_fallback.py"], From d254eab530fd0edb5dbd40bf89a24451fb9361e2 Mon Sep 17 00:00:00 2001 From: Son Tuan Vu Date: Fri, 14 Jul 2023 10:18:09 -0700 Subject: [PATCH 319/376] [XLA:GPU][NFC] Remove unused includes PiperOrigin-RevId: 548158482 --- tensorflow/compiler/xla/hlo/ir/hlo_sharding.cc | 2 +- tensorflow/compiler/xla/pjrt/BUILD | 1 + .../xla/pjrt/tracked_tfrt_cpu_device_buffer.h | 3 ++- .../xla/pjrt/tracked_tfrt_cpu_device_buffer_test.cc | 1 - tensorflow/compiler/xla/service/BUILD | 1 + tensorflow/compiler/xla/service/gpu/BUILD | 7 ------- .../xla/service/gpu/autotuner_compile_util.cc | 3 --- .../compiler/xla/service/gpu/autotuner_compile_util.h | 11 ----------- .../xla/service/gpu/cudnn_fused_mha_rewriter.cc | 1 + tensorflow/compiler/xla/service/gpu/gpu_compiler.cc | 9 +-------- tensorflow/compiler/xla/service/gpu/gpu_compiler.h | 4 ---- .../compiler/xla/service/gpu/ir_emission_utils.cc | 2 -- .../compiler/xla/service/gpu/ir_emission_utils.h | 2 -- tensorflow/compiler/xla/service/gpu/matmul_utils.cc | 2 +- tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc | 2 -- tensorflow/compiler/xla/service/gpu/nvptx_compiler.h | 2 -- .../compiler/xla/service/gpu/triton_autotuner.cc | 2 -- tensorflow/compiler/xla/service/hlo_module_config.h | 1 + tensorflow/compiler/xla/shape_util.cc | 2 +- tensorflow/compiler/xla/shape_util.h | 6 ------ tensorflow/compiler/xla/tests/BUILD | 1 + tensorflow/compiler/xla/tests/literal_test_util.cc | 1 + 22 files changed, 12 insertions(+), 54 deletions(-) diff --git a/tensorflow/compiler/xla/hlo/ir/hlo_sharding.cc b/tensorflow/compiler/xla/hlo/ir/hlo_sharding.cc index cc8771ce430d80..196fb2ded19cbf 100644 --- a/tensorflow/compiler/xla/hlo/ir/hlo_sharding.cc +++ b/tensorflow/compiler/xla/hlo/ir/hlo_sharding.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -40,6 +39,7 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/tsl/platform/errors.h" +#include "tensorflow/tsl/platform/protobuf.h" namespace xla { namespace { diff --git a/tensorflow/compiler/xla/pjrt/BUILD b/tensorflow/compiler/xla/pjrt/BUILD index 46fbf312d3aa1d..600dce75cbad40 100644 --- a/tensorflow/compiler/xla/pjrt/BUILD +++ b/tensorflow/compiler/xla/pjrt/BUILD @@ -495,6 +495,7 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/runtime:cpu_event", + "//tensorflow/tsl/platform:env", "//tensorflow/tsl/platform:platform_port", "@com_google_absl//absl/base", "@com_google_absl//absl/container:inlined_vector", diff --git a/tensorflow/compiler/xla/pjrt/tracked_tfrt_cpu_device_buffer.h b/tensorflow/compiler/xla/pjrt/tracked_tfrt_cpu_device_buffer.h index b8c6bdee610fe6..25cf4d8ddff523 100644 --- a/tensorflow/compiler/xla/pjrt/tracked_tfrt_cpu_device_buffer.h +++ b/tensorflow/compiler/xla/pjrt/tracked_tfrt_cpu_device_buffer.h @@ -21,13 +21,14 @@ limitations under the License. #include #include "absl/container/inlined_vector.h" -#include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/cpu_function_runtime.h" #include "tensorflow/compiler/xla/runtime/cpu_event.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/tsl/platform/env.h" #include "tensorflow/tsl/platform/mem.h" +#include "tensorflow/tsl/platform/threadpool.h" #include "tfrt/host_context/async_value_ref.h" // from @tf_runtime namespace xla { diff --git a/tensorflow/compiler/xla/pjrt/tracked_tfrt_cpu_device_buffer_test.cc b/tensorflow/compiler/xla/pjrt/tracked_tfrt_cpu_device_buffer_test.cc index 0bc6162b819b8d..05d47449c6cd3d 100644 --- a/tensorflow/compiler/xla/pjrt/tracked_tfrt_cpu_device_buffer_test.cc +++ b/tensorflow/compiler/xla/pjrt/tracked_tfrt_cpu_device_buffer_test.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include -#include #include namespace xla { diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 36924224482207..07a4a6032d8e37 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -5102,6 +5102,7 @@ cc_library( "//tensorflow/compiler/xla:shape_layout", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla:xla_proto_cc", + "//tensorflow/tsl/platform:protobuf", "//tensorflow/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index bc1af2d1fc567b..7528ce37be3599 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -566,7 +566,6 @@ cc_library( ":buffer_comparator", ":gemm_rewriter", ":gemm_rewriter_triton", - ":gpu_asm_opts_util", ":gpu_device_info", ":gpu_float_support", ":gpu_fusible", @@ -1328,15 +1327,12 @@ cc_library( "@com_google_absl//absl/types:span", "//tensorflow/compiler/xla:autotune_results_proto_cc", "//tensorflow/compiler/xla:autotuning_proto_cc", - "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_proto_cc", "//tensorflow/compiler/xla/hlo/ir:hlo", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:hlo_module_config", - "//tensorflow/compiler/xla/service:platform_util", - "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/stream_executor", "//tensorflow/compiler/xla/stream_executor/gpu:gpu_stream_header", "//tensorflow/compiler/xla/stream_executor/gpu:gpu_timer_header", @@ -2383,7 +2379,6 @@ cc_library( "@llvm-project//llvm:Core", "@llvm-project//llvm:Support", "@llvm-project//llvm:TransformUtils", - "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", @@ -2398,7 +2393,6 @@ cc_library( "//tensorflow/compiler/xla/hlo/transforms:hlo_constant_splitter", "//tensorflow/compiler/xla/mlir/backends/gpu/transforms:passes", "//tensorflow/compiler/xla/mlir/runtime/transforms:compilation_pipeline_gpu", - "//tensorflow/compiler/xla/mlir_hlo:transforms_gpu_passes", "//tensorflow/compiler/xla/runtime:jit_executable", "//tensorflow/compiler/xla/service:algebraic_simplifier", "//tensorflow/compiler/xla/service:all_gather_broadcast_reorder", @@ -2488,7 +2482,6 @@ cc_library( "//tensorflow/compiler/xla/stream_executor:device_description_proto_cc_impl", "//tensorflow/compiler/xla/stream_executor:stream_executor_headers", "//tensorflow/compiler/xla/stream_executor/cuda:cuda_platform_id", - "//tensorflow/compiler/xla/stream_executor/rocm:rocm_platform_id", "//tensorflow/compiler/xla/translate/hlo_to_mhlo:hlo_utils", "//tensorflow/compiler/xla/translate/mhlo_to_hlo:location_exporter", "//tensorflow/compiler/xla/translate/mhlo_to_lhlo_with_xla", diff --git a/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.cc b/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.cc index 87fed6db01161e..fd4440bcaeefaa 100644 --- a/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.cc +++ b/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.cc @@ -41,9 +41,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" -#include "tensorflow/compiler/xla/service/platform_util.h" -#include "tensorflow/compiler/xla/service/shaped_buffer.h" -#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/stream_executor/device_memory.h" #include "tensorflow/compiler/xla/stream_executor/gpu/gpu_stream.h" #include "tensorflow/compiler/xla/stream_executor/gpu/gpu_timer.h" diff --git a/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.h b/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.h index 41330dcc1ba4a2..1ad8ddb5ba9acc 100644 --- a/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.h +++ b/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.h @@ -16,19 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_AUTOTUNER_COMPILE_UTIL_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_AUTOTUNER_COMPILE_UTIL_H_ -#include - -#include -#include -#include -#include -#include #include #include -#include -#include -#include -#include #include #include "tensorflow/compiler/xla/autotune_results.pb.h" diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_fused_mha_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_fused_mha_rewriter.cc index f2b7d0a8ac85ad..49bbbf562e91a2 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_fused_mha_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_mha_rewriter.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/cublas_cudnn.h" #include "tensorflow/compiler/xla/service/gpu/matmul_utils.h" #include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/stream_executor/stream_executor.h" #include "tensorflow/tsl/platform/errors.h" #include "tensorflow/tsl/platform/statusor.h" diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index 3058f2c60f4846..bd1adab6c50ce6 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -19,12 +19,10 @@ limitations under the License. #include #include #include -#include #include #include #include #include -#include // NOLINT #include #include #include @@ -40,9 +38,8 @@ limitations under the License. #include "llvm/IR/Verifier.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Utils/SplitModule.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "tensorflow/compiler/xla/hlo/ir/hlo_casting_utils.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_instructions.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_module.h" @@ -50,7 +47,6 @@ limitations under the License. #include "tensorflow/compiler/xla/hlo/transforms/hlo_constant_splitter.h" #include "tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.h" #include "tensorflow/compiler/xla/mlir/runtime/transforms/compilation_pipeline_gpu.h" -#include "tensorflow/compiler/xla/mlir_hlo/transforms/gpu_passes.h" #include "tensorflow/compiler/xla/runtime/jit_executable.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" #include "tensorflow/compiler/xla/service/all_gather_broadcast_reorder.h" @@ -115,8 +111,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" -#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" -#include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h" #include "tensorflow/compiler/xla/service/gpu/matmul_utils.h" #include "tensorflow/compiler/xla/service/gpu/metrics.h" #include "tensorflow/compiler/xla/service/gpu/move_copy_to_users.h" @@ -180,7 +174,6 @@ limitations under the License. #include "tensorflow/compiler/xla/stream_executor/device_description.h" #include "tensorflow/compiler/xla/stream_executor/device_description.pb.h" #include "tensorflow/compiler/xla/stream_executor/dnn.h" -#include "tensorflow/compiler/xla/stream_executor/rocm/rocm_platform_id.h" #include "tensorflow/compiler/xla/stream_executor/stream_executor.h" #include "tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h" #include "tensorflow/compiler/xla/util.h" diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h index bea4b498b0f1b0..7ec383489e0cfb 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h @@ -23,24 +23,20 @@ limitations under the License. #include #include -#include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "tensorflow/compiler/xla/autotune_results.pb.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_module.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/gpu/executable.pb.h" #include "tensorflow/compiler/xla/service/gpu/gpu_device_info.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" -#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" #include "tensorflow/compiler/xla/service/llvm_compiler.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/stream_executor/device_description.h" #include "tensorflow/compiler/xla/stream_executor/device_description.pb.h" #include "tensorflow/compiler/xla/stream_executor/stream_executor.h" #include "tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h" -#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 1c63929cf17674..7a8db3574b7a16 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -17,7 +17,6 @@ limitations under the License. #include #include -#include #include #include #include @@ -36,7 +35,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_type_conversion_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" -#include "tensorflow/compiler/xla/stream_executor/device_description.h" #include "tensorflow/compiler/xla/translate/mhlo_to_hlo/type_to_shape.h" #include "tensorflow/compiler/xla/xla_data.pb.h" diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h index ba031cb1504aaf..4df85639667806 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h @@ -18,7 +18,6 @@ limitations under the License. #include #include -#include #include #include "llvm/IR/IRBuilder.h" @@ -26,7 +25,6 @@ limitations under the License. #include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" #include "tensorflow/compiler/xla/mlir_hlo/lhlo/IR/lhlo_ops.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" -#include "tensorflow/compiler/xla/stream_executor/stream_executor.h" namespace xla { namespace gpu { diff --git a/tensorflow/compiler/xla/service/gpu/matmul_utils.cc b/tensorflow/compiler/xla/service/gpu/matmul_utils.cc index 99d31cf1f5e6c5..d1615663ae9f5f 100644 --- a/tensorflow/compiler/xla/service/gpu/matmul_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/matmul_utils.cc @@ -26,7 +26,6 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" -#include "tensorflow/compiler/xla/hlo/ir/hlo_module.h" #include "tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/shape.h" @@ -36,6 +35,7 @@ limitations under the License. #include "tensorflow/compiler/xla/stream_executor/blas.h" #include "tensorflow/compiler/xla/stream_executor/device_memory.h" #include "tensorflow/compiler/xla/stream_executor/numeric_options.h" +#include "tensorflow/compiler/xla/stream_executor/stream_executor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index 458a46e5313b1f..fda453bbe47e89 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -15,8 +15,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/nvptx_compiler.h" -#include - #include #include #include diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h index d4bde5bca73d31..3b8906bbe46d41 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h @@ -16,8 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NVPTX_COMPILER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NVPTX_COMPILER_H_ -#include -#include #include #include #include diff --git a/tensorflow/compiler/xla/service/gpu/triton_autotuner.cc b/tensorflow/compiler/xla/service/gpu/triton_autotuner.cc index 467bbf41718dd7..934d550238e1c1 100644 --- a/tensorflow/compiler/xla/service/gpu/triton_autotuner.cc +++ b/tensorflow/compiler/xla/service/gpu/triton_autotuner.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -46,7 +45,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h" #include "tensorflow/compiler/xla/service/gpu/gemm_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.h" -#include "tensorflow/compiler/xla/service/gpu/gpu_asm_opts_util.h" #include "tensorflow/compiler/xla/service/gpu/gpu_device_info.h" #include "tensorflow/compiler/xla/service/gpu/gpu_float_support.h" #include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h" diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h index 845368579ee490..e1293baba97728 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.h +++ b/tensorflow/compiler/xla/service/hlo_module_config.h @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/xla.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/tsl/platform/protobuf.h" namespace xla { diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 66b3b497139905..853e1c7451be3d 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -22,7 +22,6 @@ limitations under the License. #include #include #include -#include #include #include @@ -38,6 +37,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/tsl/platform/cpu_info.h" #include "tensorflow/tsl/platform/threadpool.h" namespace xla { diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 82345566e79f8e..a499299b31c751 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -19,19 +19,16 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SHAPE_UTIL_H_ #define TENSORFLOW_COMPILER_XLA_SHAPE_UTIL_H_ -#include #include #include #include #include #include #include -#include #include #include #include -#include "absl/base/macros.h" #include "absl/container/inlined_vector.h" #include "absl/functional/function_ref.h" #include "absl/types/span.h" @@ -40,9 +37,6 @@ limitations under the License. #include "tensorflow/compiler/xla/printer.h" #include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/tsl/platform/cpu_info.h" -#include "tensorflow/tsl/platform/env.h" -#include "tensorflow/tsl/platform/threadpool.h" namespace xla { diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 11474f18b09552..e3d9decb530f13 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -114,6 +114,7 @@ cc_library( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/tsl/platform:env", "//tensorflow/tsl/platform:errors", "//tensorflow/tsl/platform:path", "//tensorflow/tsl/platform:test", diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc index ea81983429479a..be96a4b2982aad 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util.cc @@ -17,6 +17,7 @@ limitations under the License. #include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/literal_comparison.h" +#include "tensorflow/tsl/platform/env.h" #include "tensorflow/tsl/platform/path.h" #include "tensorflow/tsl/platform/test.h" From aca9fe60048f2577769ae71dfd415f4ce2fe42c5 Mon Sep 17 00:00:00 2001 From: Reed Wanderman-Milne Date: Fri, 14 Jul 2023 10:23:28 -0700 Subject: [PATCH 320/376] Fix some tests that were broken with XLA:GPU. PiperOrigin-RevId: 548159914 --- tensorflow/python/kernel_tests/control_flow/scan_ops_test.py | 5 ++++- .../python/kernel_tests/math_ops/aggregate_ops_test.py | 2 ++ tensorflow/python/kernel_tests/nn_ops/pooling_ops_test.py | 1 + tensorflow/python/kernel_tests/nn_ops/softplus_op_test.py | 3 ++- 4 files changed, 9 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/kernel_tests/control_flow/scan_ops_test.py b/tensorflow/python/kernel_tests/control_flow/scan_ops_test.py index 7747af49cdcbb9..b6ff099c646c79 100644 --- a/tensorflow/python/kernel_tests/control_flow/scan_ops_test.py +++ b/tensorflow/python/kernel_tests/control_flow/scan_ops_test.py @@ -228,7 +228,10 @@ def _compare(self, x, axis, exclusive, reverse): with self.cached_session(): tf_out = math_ops.cumprod(x, axis, exclusive, reverse).eval() - self.assertAllClose(np_out, tf_out) + atol = rtol = 1e-6 + if x.dtype == dtypes.bfloat16.as_numpy_dtype: + atol = rtol = 1e-2 + self.assertAllClose(np_out, tf_out, atol=atol, rtol=rtol) def _compareAll(self, x, axis): for exclusive in [True, False]: diff --git a/tensorflow/python/kernel_tests/math_ops/aggregate_ops_test.py b/tensorflow/python/kernel_tests/math_ops/aggregate_ops_test.py index 42ff056175d03a..4edfdb2e2678f0 100644 --- a/tensorflow/python/kernel_tests/math_ops/aggregate_ops_test.py +++ b/tensorflow/python/kernel_tests/math_ops/aggregate_ops_test.py @@ -73,6 +73,8 @@ def testAddN(self): self.assertAllCloseAccordingToType( expected, actual, + float_rtol=5e-6, + float_atol=5e-6, half_rtol=5e-3, half_atol=5e-3, ) diff --git a/tensorflow/python/kernel_tests/nn_ops/pooling_ops_test.py b/tensorflow/python/kernel_tests/nn_ops/pooling_ops_test.py index 550e384338615f..cc1800755ed2fa 100644 --- a/tensorflow/python/kernel_tests/nn_ops/pooling_ops_test.py +++ b/tensorflow/python/kernel_tests/nn_ops/pooling_ops_test.py @@ -2333,6 +2333,7 @@ def _testAvgPoolGradSamePadding3_1(self, data_format, use_gpu): data_format=data_format, use_gpu=use_gpu) + @test_util.disable_xla("Xla does not raise error on out of bounds access") def testAvgPoolGradOutputMemoryOutOfBounds(self): with self.assertRaisesRegex( errors_impl.InvalidArgumentError, diff --git a/tensorflow/python/kernel_tests/nn_ops/softplus_op_test.py b/tensorflow/python/kernel_tests/nn_ops/softplus_op_test.py index 97e7fcbb0418ab..40a745691acefc 100644 --- a/tensorflow/python/kernel_tests/nn_ops/softplus_op_test.py +++ b/tensorflow/python/kernel_tests/nn_ops/softplus_op_test.py @@ -39,7 +39,8 @@ def _testSoftplus(self, np_features, use_gpu=False): softplus = nn_ops.softplus(np_features) tf_softplus = self.evaluate(softplus) self.assertAllCloseAccordingToType( - np_softplus, tf_softplus, bfloat16_rtol=5e-2, bfloat16_atol=5e-2 + np_softplus, tf_softplus, half_rtol=5e-3, half_atol=5e-3, + bfloat16_rtol=5e-2, bfloat16_atol=5e-2 ) self.assertTrue(np.all(tf_softplus > 0)) self.assertShapeEqual(np_softplus, softplus) From ea054a460ee3c0d3f695ae1448dbfad360baa3cb Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Fri, 14 Jul 2023 10:37:30 -0700 Subject: [PATCH 321/376] [XLA] Remove To/FromAbslStatus functions from XLA. These functions are nop's -- TF status is now absl::Status. No functional change. PiperOrigin-RevId: 548163566 --- tensorflow/compiler/xla/service/gpu/runtime/collectives.cc | 4 ++-- tensorflow/compiler/xla/status.h | 2 -- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/runtime/collectives.cc b/tensorflow/compiler/xla/service/gpu/runtime/collectives.cc index e1ea335ba179fd..47780a2613e578 100644 --- a/tensorflow/compiler/xla/service/gpu/runtime/collectives.cc +++ b/tensorflow/compiler/xla/service/gpu/runtime/collectives.cc @@ -193,10 +193,10 @@ absl::Status P2PImplCommon(const ServiceExecutableRunOptions* run_options, NcclCollectiveThunk::GetDeviceString(params); auto comm = GetNcclComm(params, group_mode, op_id, replica_group_offsets, replica_group_values, is_async); - if (!comm.ok()) return ToAbslStatus(comm.status()); + if (!comm.ok()) return comm.status(); auto device_buffers = device_buffers_getter(args); - if (!device_buffers.ok()) return ToAbslStatus(device_buffers.status()); + if (!device_buffers.ok()) return device_buffers.status(); if (device_buffers->size() != 1) { return absl::InternalError(absl::StrFormat( "Expected device buffer size: 1, got %d", device_buffers->size())); diff --git a/tensorflow/compiler/xla/status.h b/tensorflow/compiler/xla/status.h index fe7fbef419c54e..ec1d45034a5ccc 100644 --- a/tensorflow/compiler/xla/status.h +++ b/tensorflow/compiler/xla/status.h @@ -20,10 +20,8 @@ limitations under the License. namespace xla { // NOLINTBEGIN(misc-unused-using-decls) -using tsl::FromAbslStatus; using tsl::OkStatus; using tsl::Status; // TENSORFLOW_STATUS_OK -using tsl::ToAbslStatus; // NOLINTEND(misc-unused-using-decls) } // namespace xla From 9b31c22e4275cd0ed9ea872cd4db83b15cd49ea7 Mon Sep 17 00:00:00 2001 From: Frederik Gossen Date: Fri, 14 Jul 2023 10:46:28 -0700 Subject: [PATCH 322/376] [KernelGen] JIT-compile most the MLIR-generated GPU kernels JIT-compile all MLIR-generated kernels for which the build rules can be reconfigured easily. For now, this excludes i64-indexed kernels and kernels with different input and output types. PiperOrigin-RevId: 548166059 --- tensorflow/core/kernels/mlir_generated/BUILD | 329 ++++++++++--------- tensorflow/python/kernel_tests/linalg/BUILD | 2 +- tensorflow/python/ops/BUILD | 4 +- 3 files changed, 182 insertions(+), 153 deletions(-) diff --git a/tensorflow/core/kernels/mlir_generated/BUILD b/tensorflow/core/kernels/mlir_generated/BUILD index c4c5b6f5d99435..b94f1ffd77aa32 100644 --- a/tensorflow/core/kernels/mlir_generated/BUILD +++ b/tensorflow/core/kernels/mlir_generated/BUILD @@ -639,13 +639,14 @@ gpu_kernel_library( gpu_kernel_library( name = "gpu_atan2_kernels", - op = "atan2", - tile_size = "256", - types = [ + jit_types = [ "f16", "f32", "f64", ], + op = "atan2", + tile_size = "256", + types = [], unroll_factors = "4", ) @@ -748,25 +749,27 @@ gpu_kernel_library( gpu_kernel_library( name = "gpu_ceil_kernels", - op = "ceil", - tile_size = "256", - types = [ + jit_types = [ "f16", "f32", "f64", ], + op = "ceil", + tile_size = "256", + types = [], unroll_factors = "4", ) gpu_kernel_library( name = "gpu_floor_kernels", - op = "floor", - tile_size = "256", - types = [ + jit_types = [ "f16", "f32", "f64", ], + op = "floor", + tile_size = "256", + types = [], unroll_factors = "4", ) @@ -792,26 +795,28 @@ gpu_kernel_library( gpu_kernel_library( name = "gpu_rint_kernels", - jit_types = ["f16"], - op = "rint", - tile_size = "1024", - types = [ + jit_types = [ + "f16", "f32", "f64", ], + op = "rint", + tile_size = "1024", + types = [], ) gpu_kernel_library( name = "gpu_round_kernels", - op = "round", - tile_size = "1024", - types = [ + jit_types = [ "f16", "f32", "f64", "i32", "i64", ], + op = "round", + tile_size = "1024", + types = [], ) # Predicate kernels @@ -1029,12 +1034,13 @@ gpu_kernel_library( gpu_kernel_library( name = "gpu_conj_kernels", - op = "conj", - tile_size = "256", - types = [ + jit_types = [ "c64", "c128", ], + op = "conj", + tile_size = "256", + types = [], unroll_factors = "2", ) @@ -1171,10 +1177,6 @@ gpu_kernel_library( "ui16", "ui32", "ui64", - ], - op = "maximum", - tile_size = "1024", - types = [ "f16", "f32", "f64", @@ -1182,6 +1184,9 @@ gpu_kernel_library( "i64", "ui8", ], + op = "maximum", + tile_size = "1024", + types = [], unroll_factors = "4", ) @@ -1192,10 +1197,6 @@ gpu_kernel_library( "ui16", "ui32", "ui64", - ], - op = "minimum", - tile_size = "1024", - types = [ "f16", "f32", "f64", @@ -1203,6 +1204,9 @@ gpu_kernel_library( "i64", "ui8", ], + op = "minimum", + tile_size = "1024", + types = [], unroll_factors = "4", ) @@ -1254,9 +1258,7 @@ gpu_kernel_library( gpu_kernel_library( name = "gpu_neg_kernels", - op = "neg", - tile_size = "256", - types = [ + jit_types = [ "f16", "f32", "f64", @@ -1267,6 +1269,9 @@ gpu_kernel_library( "c64", "c128", ], + op = "neg", + tile_size = "256", + types = [], unroll_factors = "4", ) @@ -1275,22 +1280,19 @@ gpu_kernel_library( jit_types = [ "i8", "i16", - ], - op = "pow", - tile_size = "1024", - types = [ "f16", "f32", "f64", "i64", ], + op = "pow", + tile_size = "1024", + types = [], ) gpu_kernel_library( name = "gpu_reciprocal_kernels", - op = "reciprocal", - tile_size = "256", - types = [ + jit_types = [ "c64", "c128", "f16", @@ -1298,6 +1300,9 @@ gpu_kernel_library( "f64", "i64", ], + op = "reciprocal", + tile_size = "256", + types = [], unroll_factors = "4", ) @@ -1306,10 +1311,6 @@ gpu_kernel_library( jit_types = [ "i8", "i16", - ], - op = "sign", - tile_size = "256", - types = [ "f16", "f32", "f64", @@ -1318,6 +1319,9 @@ gpu_kernel_library( "c64", "c128", ], + op = "sign", + tile_size = "256", + types = [], unroll_factors = "4", ) @@ -1361,80 +1365,86 @@ gpu_kernel_library( gpu_kernel_library( name = "gpu_xdivy_kernels", - op = "xdivy", - tile_size = "1024", - types = [ + jit_types = [ "f16", "f32", "f64", "c64", "c128", ], + op = "xdivy", + tile_size = "1024", + types = [], unroll_factors = "4", ) # Logarithmic and exponential kernels gpu_kernel_library( name = "gpu_exp_kernels", - op = "exp", - tile_size = "256", - types = [ + jit_types = [ "f16", "f32", "f64", "c64", "c128", ], + op = "exp", + tile_size = "256", + types = [], unroll_factors = "4", ) gpu_kernel_library( name = "gpu_expm1_kernels", - op = "expm1", - tile_size = "256", - types = [ + jit_types = [ "f16", "f32", "f64", ], + op = "expm1", + tile_size = "256", + types = [], unroll_factors = "4", ) gpu_kernel_library( name = "gpu_log_kernels", - op = "log", - tile_size = "256", - types = [ + jit_types = [ "f16", "f32", "f64", ], + op = "log", + tile_size = "256", + types = [], unroll_factors = "4", ) gpu_kernel_library( name = "gpu_log1p_kernels", - op = "log1p", - tile_size = "256", - types = [ + jit_types = [ "f16", "f32", "f64", ], + op = "log1p", + tile_size = "256", + types = [], unroll_factors = "4", ) gpu_kernel_library( name = "gpu_xlogy_kernels", - op = "xlogy", - tile_size = "1024", - types = [ + jit_types = [ "f16", "f32", "f64", "c64", "c128", ], + op = "xlogy", + tile_size = "1024", + types = [], unroll_factors = "4", # For complex XlogyOp kernels, we don't use unrolling, it would only cause # slowdowns. @@ -1446,15 +1456,16 @@ gpu_kernel_library( gpu_kernel_library( name = "gpu_xlog1py_kernels", - op = "xlog1py", - tile_size = "1024", - types = [ + jit_types = [ "f16", "f32", "f64", "c64", "c128", ], + op = "xlog1py", + tile_size = "1024", + types = [], unroll_factors = "4", # For complex Xlog1pyOp kernels, we don't use unrolling, it would only cause # slowdowns. @@ -1468,25 +1479,27 @@ gpu_kernel_library( gpu_kernel_library( name = "gpu_sqrt_kernels", - op = "sqrt", - tile_size = "256", - types = [ + jit_types = [ "f16", "f32", "f64", ], + op = "sqrt", + tile_size = "256", + types = [], unroll_factors = "4", ) gpu_kernel_library( name = "gpu_rsqrt_kernels", - op = "rsqrt", - tile_size = "256", - types = [ + jit_types = [ "f16", "f32", "f64", ], + op = "rsqrt", + tile_size = "256", + types = [], unroll_factors = "4", ) @@ -1499,28 +1512,28 @@ gpu_kernel_library( "ui16", "ui32", "ui64", - ], - op = "square", - tile_size = "1024", - types = [ "f16", "f32", "f64", "i64", ], + op = "square", + tile_size = "1024", + types = [], unroll_factors = "4", ) gpu_kernel_library( name = "gpu_squared_difference_kernels", - op = "squared_difference", - tile_size = "1024", - types = [ + jit_types = [ "f16", "f32", "f64", "i64", ], + op = "squared_difference", + tile_size = "1024", + types = [], unroll_factors = "4", ) @@ -1528,74 +1541,77 @@ gpu_kernel_library( gpu_kernel_library( name = "gpu_bitwise_and_kernels", - op = "bitwise_and", - tile_size = "1024", - types = [ + jit_types = [ "i8", "i16", "i32", "i64", ], + op = "bitwise_and", + tile_size = "1024", + types = [], unroll_factors = "4", ) gpu_kernel_library( name = "gpu_bitwise_or_kernels", - op = "bitwise_or", - tile_size = "1024", - types = [ + jit_types = [ "i8", "i16", "i32", "i64", ], + op = "bitwise_or", + tile_size = "1024", + types = [], unroll_factors = "4", ) gpu_kernel_library( name = "gpu_bitwise_xor_kernels", - op = "bitwise_xor", - tile_size = "1024", - types = [ + jit_types = [ "i8", "i16", "i32", "i64", ], + op = "bitwise_xor", + tile_size = "1024", + types = [], unroll_factors = "4", ) gpu_kernel_library( name = "gpu_invert_kernels", - op = "invert", - tile_size = "1024", - types = [ + jit_types = [ "i8", "i16", "i32", "i64", ], + op = "invert", + tile_size = "1024", + types = [], unroll_factors = "4", ) gpu_kernel_library( name = "gpu_left_shift_kernels", - op = "left_shift", - tile_size = "1024", - types = [ + jit_types = [ "i8", "i16", "i32", "i64", ], + op = "left_shift", + tile_size = "1024", + types = [], unroll_factors = "4", ) gpu_kernel_library( name = "gpu_right_shift_kernels", - op = "right_shift", - tile_size = "1024", - types = [ + jit_types = [ "i8", "i16", "i32", @@ -1605,6 +1621,9 @@ gpu_kernel_library( "ui32", "ui64", ], + op = "right_shift", + tile_size = "1024", + types = [], unroll_factors = "4", ) @@ -1612,52 +1631,57 @@ gpu_kernel_library( gpu_kernel_library( name = "gpu_logical_not_kernels", + jit_types = ["i1"], op = "logical_not", tile_size = "256", - types = ["i1"], + types = [], ) gpu_kernel_library( name = "gpu_logical_and_kernels", - op = "logical_and", - tile_size = "1024", - types = [ + jit_types = [ "i1", ], + op = "logical_and", + tile_size = "1024", + types = [], ) gpu_kernel_library( name = "gpu_logical_or_kernels", - op = "logical_or", - tile_size = "1024", - types = [ + jit_types = [ "i1", ], + op = "logical_or", + tile_size = "1024", + types = [], ) # Erf kernels gpu_kernel_library( name = "gpu_erf_kernels", - op = "erf", - tile_size = "256", - types = [ + jit_types = [ "f16", "f32", "f64", ], + op = "erf", + tile_size = "256", + types = [], unroll_factors = "4", ) gpu_kernel_library( name = "gpu_erfc_kernels", - op = "erfc", - tile_size = "256", - types = [ + jit_types = [ "f16", "f32", "f64", ], + op = "erfc", + tile_size = "256", + types = [], unroll_factors = "4", ) @@ -1665,45 +1689,49 @@ gpu_kernel_library( gpu_kernel_library( name = "gpu_polygamma_kernels", - op = "polygamma", - tile_size = "256", - types = [ + jit_types = [ "f32", "f64", ], + op = "polygamma", + tile_size = "256", + types = [], ) gpu_kernel_library( name = "gpu_digamma_kernels", - op = "digamma", - tile_size = "256", - types = [ + jit_types = [ "f16", "f32", "f64", ], + op = "digamma", + tile_size = "256", + types = [], ) gpu_kernel_library( name = "gpu_lgamma_kernels", - op = "lgamma", - tile_size = "256", - types = [ + jit_types = [ "f16", "f32", "f64", ], + op = "lgamma", + tile_size = "256", + types = [], ) gpu_kernel_library( # The zeta kernels needs many registers so tile at 256. name = "gpu_zeta_kernels", - op = "zeta", - tile_size = "256", - types = [ + jit_types = [ "f32", "f64", ], + op = "zeta", + tile_size = "256", + types = [], # TODO(b/178388085): Enable unrolling after vectorization is fixed. # unroll_factors = "4", ) @@ -1730,61 +1758,64 @@ gpu_kernel_library( "ui16", "ui32", "ui64", - ], - op = "relu", - tile_size = "256", - types = [ "f16", "f32", "f64", ], + op = "relu", + tile_size = "256", + types = [], unroll_factors = "16B", ) gpu_kernel_library( name = "gpu_elu_kernels", - op = "elu", - tile_size = "256", - types = [ + jit_types = [ "f16", "f32", "f64", ], + op = "elu", + tile_size = "256", + types = [], ) gpu_kernel_library( name = "gpu_selu_kernels", - op = "selu", - tile_size = "256", - types = [ + jit_types = [ "f16", "f32", "f64", ], + op = "selu", + tile_size = "256", + types = [], ) gpu_kernel_library( name = "gpu_sigmoid_kernels", - op = "sigmoid", - tile_size = "256", - types = [ + jit_types = [ "f16", "f32", "f64", ], + op = "sigmoid", + tile_size = "256", + types = [], ) # Kernels that support all floating-point types. [ gpu_kernel_library( name = "gpu_" + op + "_kernels", - op = op, - tile_size = "256", - types = [ + jit_types = [ "f16", "f32", "f64", ], + op = op, + tile_size = "256", + types = [], unroll_factors = "4", ) for op in [ @@ -1836,11 +1867,6 @@ gpu_kernel_library( "ui16", "ui32", "ui64", - ], - max_supported_rank = 8, - op = "select_v2", - tile_size = "256", - types = [ "i1", "i32", "i64", @@ -1850,6 +1876,10 @@ gpu_kernel_library( "c64", "c128", ], + max_supported_rank = 8, + op = "select_v2", + tile_size = "256", + types = [], ) gpu_kernel_library( @@ -1861,10 +1891,6 @@ gpu_kernel_library( "ui16", "ui32", "ui64", - ], - op = "zeros_like", - tile_size = "1024", - types = [ "i1", "i64", "f16", @@ -1873,6 +1899,9 @@ gpu_kernel_library( "c64", "c128", ], + op = "zeros_like", + tile_size = "1024", + types = [], ) gpu_kernel_library( @@ -1884,10 +1913,6 @@ gpu_kernel_library( "ui16", "ui32", "ui64", - ], - op = "ones_like", - tile_size = "1024", - types = [ "i1", "i64", "f16", @@ -1896,14 +1921,18 @@ gpu_kernel_library( "c64", "c128", ], + op = "ones_like", + tile_size = "1024", + types = [], ) gpu_kernel_library( name = "gpu_next_after_kernels", - op = "next_after", - tile_size = "1024", - types = [ + jit_types = [ "f32", "f64", ], + op = "next_after", + tile_size = "1024", + types = [], ) diff --git a/tensorflow/python/kernel_tests/linalg/BUILD b/tensorflow/python/kernel_tests/linalg/BUILD index ada2c64403c01b..42fd131137743f 100644 --- a/tensorflow/python/kernel_tests/linalg/BUILD +++ b/tensorflow/python/kernel_tests/linalg/BUILD @@ -271,7 +271,7 @@ cuda_py_strict_test( name = "linear_operator_circulant_test", size = "medium", srcs = ["linear_operator_circulant_test.py"], - shard_count = 15, + shard_count = 32, tags = [ "no_cuda11", # TODO(b/197522782): reenable test after fixing. "optonly", # times out, b/79171797 diff --git a/tensorflow/python/ops/BUILD b/tensorflow/python/ops/BUILD index a81d14afbbe481..4943ebe4bfa040 100644 --- a/tensorflow/python/ops/BUILD +++ b/tensorflow/python/ops/BUILD @@ -3159,7 +3159,7 @@ py_strict_library( cuda_py_strict_test( name = "bitwise_ops_test", - size = "small", + size = "medium", srcs = ["bitwise_ops_test.py"], main = "bitwise_ops_test.py", python_version = "PY3", @@ -3504,7 +3504,7 @@ cuda_py_strict_test( cuda_py_strict_test( name = "math_grad_test", - size = "small", + size = "medium", srcs = ["math_grad_test.py"], main = "math_grad_test.py", python_version = "PY3", From 659612616e3ce5e4a63444b314549aa54f346da7 Mon Sep 17 00:00:00 2001 From: Marat Dukhan Date: Fri, 14 Jul 2023 10:59:51 -0700 Subject: [PATCH 323/376] Update cpuinfo dependency PiperOrigin-RevId: 548169546 --- tensorflow/workspace2.bzl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/workspace2.bzl b/tensorflow/workspace2.bzl index 4023312c9f6b7b..30c0daa23dc4b3 100644 --- a/tensorflow/workspace2.bzl +++ b/tensorflow/workspace2.bzl @@ -169,9 +169,9 @@ def _tf_repositories(): tf_http_archive( name = "cpuinfo", - strip_prefix = "cpuinfo-3dc310302210c1891ffcfb12ae67b11a3ad3a150", - sha256 = "ba668f9f8ea5b4890309b7db1ed2e152aaaf98af6f9a8a63dbe1b75c04e52cb9", - urls = tf_mirror_urls("https://github.com/pytorch/cpuinfo/archive/3dc310302210c1891ffcfb12ae67b11a3ad3a150.zip"), + strip_prefix = "cpuinfo-87d8234510367db49a65535021af5e1838a65ac2", + sha256 = "609fc42c47482c1fc125dccac65e843f640e792540162581c4b7eb6ff81c826a", + urls = tf_mirror_urls("https://github.com/pytorch/cpuinfo/archive/87d8234510367db49a65535021af5e1838a65ac2.zip"), ) tf_http_archive( From 446d38dca20c1b8f89c91bd536f2d0daea45843c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 14 Jul 2023 11:03:40 -0700 Subject: [PATCH 324/376] Explicitly include stdlib.h in lite/util.h This header calls free(...), so we should explicitly include the corresponding header file. If stdlib.h is not included, this file doesn't compile if Clang modules are enabled. PiperOrigin-RevId: 548170716 --- tensorflow/lite/util.h | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/lite/util.h b/tensorflow/lite/util.h index 2ba25f84588dfc..6e8264974501ce 100644 --- a/tensorflow/lite/util.h +++ b/tensorflow/lite/util.h @@ -22,6 +22,7 @@ limitations under the License. #define TENSORFLOW_LITE_UTIL_H_ #include +#include #include #include From 873e387e4dc9aa94997356894a1b3056820dfec4 Mon Sep 17 00:00:00 2001 From: Sizhi Tan Date: Fri, 14 Jul 2023 11:03:42 -0700 Subject: [PATCH 325/376] Explicitly disable use of tfrt for failing test. PiperOrigin-RevId: 548170728 --- tensorflow/python/eager/polymorphic_function/BUILD | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tensorflow/python/eager/polymorphic_function/BUILD b/tensorflow/python/eager/polymorphic_function/BUILD index 0f8b19c819d187..6a093f105f51ca 100644 --- a/tensorflow/python/eager/polymorphic_function/BUILD +++ b/tensorflow/python/eager/polymorphic_function/BUILD @@ -331,6 +331,10 @@ tf_py_strict_test( tf_xla_py_strict_test( name = "polymorphic_function_xla_jit_test", srcs = ["polymorphic_function_xla_jit_test.py"], + # copybara:uncomment_begin + # #TODO(b/185944215) # Remove once the bug is fixed. + # disable_tpu_tfrt = True, + # copybara:uncomment_end disabled_backends = [ "cpu_ondemand", ], From 02f24749e521682de1b83bcd51a720d00e0c2eac Mon Sep 17 00:00:00 2001 From: Sizhi Tan Date: Fri, 14 Jul 2023 11:06:01 -0700 Subject: [PATCH 326/376] Explicitly disable use of tfrt for failing test. PiperOrigin-RevId: 548171370 --- tensorflow/python/distribute/BUILD | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD index 05b88a5e37cb7b..e64f90b1ca9be0 100644 --- a/tensorflow/python/distribute/BUILD +++ b/tensorflow/python/distribute/BUILD @@ -2166,6 +2166,11 @@ cuda_py_strict_test( tpu_py_strict_test( name = "collective_all_reduce_strategy_test_tpu", srcs = ["collective_all_reduce_strategy_test.py"], + # copybara:uncomment_begin + # args = [ + # "--tpu_use_tfrt=false", #TODO(b/227404010): Remove once the bug is fixed. + # ], + # copybara:uncomment_end # FIXME(b/227404010): On TFRT TPU, eager CollectiveReduceV2 is broken. disable_tfrt = True, main = "collective_all_reduce_strategy_test.py", From 9647183dd42868e33d35aa80abc28d7c1a77af62 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 14 Jul 2023 11:28:08 -0700 Subject: [PATCH 327/376] Add logging information to Ph1 call sites PiperOrigin-RevId: 548177410 --- tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc | 7 ------- 1 file changed, 7 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc index 6eaa7a4b04980f..06877ad8e42a2b 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc @@ -31,7 +31,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/core/framework/metrics.h" #include "tensorflow/core/platform/error_payloads.h" -#include "tensorflow/core/platform/stacktrace.h" #include "tensorflow/core/protobuf/core_platform_payloads.pb.h" #include "tensorflow/core/util/debug_data_dumper.h" @@ -291,8 +290,6 @@ void CreateTPUBridgePipelineV1(OpPassManager &pm) { tensorflow::Status TPUBridge(ModuleOp module, bool fallback_enabled, llvm::StringRef module_name) { - VLOG(1) << "TPU Bridge called stack trace is :" - << tensorflow::CurrentStackTrace(); Status status = RunTFXLABridge( module, [module_name](OpPassManager &pm) { @@ -316,8 +313,6 @@ tensorflow::Status TPUBridge(ModuleOp module, bool fallback_enabled, return status; } tensorflow::Status TPUBridgeV1Compat(ModuleOp module, bool fallback_enabled) { - VLOG(1) << "TPU V1 Compat Bridge called stack trace is :" - << tensorflow::CurrentStackTrace(); Status status = RunTFXLABridge(module, [](OpPassManager &pm) { CreateTPUBridgePipelineV1(pm); // Add set of passes to lower back to graph (from tf_executor). @@ -492,8 +487,6 @@ void CreateTFXLABridgePipeline(OpPassManager &pm) { tensorflow::Status RunTFXLABridge(ModuleOp module, llvm::StringRef module_name) { - VLOG(1) << "CPU/GPU Bridge called stack trace is :" - << tensorflow::CurrentStackTrace(); Status status = mlir::TFTPU::RunTFXLABridge( module, [](OpPassManager &pm) { From 2a69423d15752efa007aaaa0171719f6c67c7752 Mon Sep 17 00:00:00 2001 From: Reed Wanderman-Milne Date: Fri, 14 Jul 2023 11:48:46 -0700 Subject: [PATCH 328/376] Fix multi-GPU FP8 crash. The issue was that each device shared a cublasLtMatmulDesc_t. Normally this is fine since cublasLtMatmulDesc_t mostly just holds general information about the matmul like the compute type but for FP8, it also holds pointers in device memory to the scales. The pointers are set on the cublasLtMatmulDesc_t every time the matmul is run, in cuda_blas_lt.c. Since the pointers are different for each device, a race could occur where a GPU would try to access scale pointers in another GPU's memory. To fix, now each stream has a different cublasLtMatmulDesc_t. Fixes https://github.com/openxla/xla/issues/4001 PiperOrigin-RevId: 548182888 --- tensorflow/compiler/xla/service/gpu/BUILD | 1 + .../xla/service/gpu/cublas_lt_matmul_thunk.cc | 28 ++++++- .../xla/service/gpu/cublas_lt_matmul_thunk.h | 16 +++- .../xla/service/gpu/ir_emitter_unnested.cc | 19 ++--- .../compiler/xla/service/gpu/matmul_utils.h | 78 +++++++++---------- 5 files changed, 85 insertions(+), 57 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 7528ce37be3599..7d491f8063173b 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -1247,6 +1247,7 @@ cc_library( "//tensorflow/compiler/xla/stream_executor:stream_executor_headers", "//tensorflow/compiler/xla/stream_executor/cuda:cublas_lt_header", "//tensorflow/compiler/xla/stream_executor/cuda:cublas_plugin", + "//tensorflow/tsl/platform:statusor", ]) + ["//tensorflow/tsl/platform:logging"], ) diff --git a/tensorflow/compiler/xla/service/gpu/cublas_lt_matmul_thunk.cc b/tensorflow/compiler/xla/service/gpu/cublas_lt_matmul_thunk.cc index 881f40c52232d8..cd5551b467e870 100644 --- a/tensorflow/compiler/xla/service/gpu/cublas_lt_matmul_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/cublas_lt_matmul_thunk.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/cublas_lt_matmul_thunk.h" +#include #include #include "tensorflow/compiler/xla/service/gpu/matmul_utils.h" @@ -28,7 +29,8 @@ namespace xla { namespace gpu { CublasLtMatmulThunk::CublasLtMatmulThunk( - ThunkInfo thunk_info, cublas_lt::MatmulPlan plan, int64_t algorithm_idx, + ThunkInfo thunk_info, GemmConfig gemm_config, + se::cuda::BlasLt::Epilogue epilogue, int64_t algorithm_idx, BufferAllocation::Slice a_buffer, BufferAllocation::Slice b_buffer, BufferAllocation::Slice c_buffer, BufferAllocation::Slice d_buffer, BufferAllocation::Slice bias_buffer, BufferAllocation::Slice aux_buffer, @@ -36,7 +38,8 @@ CublasLtMatmulThunk::CublasLtMatmulThunk( BufferAllocation::Slice c_scale, BufferAllocation::Slice d_scale, BufferAllocation::Slice d_amax) : Thunk(Kind::kCublasLtMatmul, thunk_info), - plan_(std::move(plan)), + gemm_config_(std::move(gemm_config)), + epilogue_(epilogue), algorithm_idx_(algorithm_idx), a_buffer_(a_buffer), b_buffer_(b_buffer), @@ -51,10 +54,12 @@ CublasLtMatmulThunk::CublasLtMatmulThunk( d_amax_buffer_(d_amax) {} Status CublasLtMatmulThunk::ExecuteOnStream(const ExecuteParams& params) { + TF_ASSIGN_OR_RETURN(cublas_lt::MatmulPlan * plan, + GetMatmulPlan(params.stream)); if (!algorithm_) { TF_ASSIGN_OR_RETURN( std::vector algorithms, - plan_.GetAlgorithms(params.stream)); + plan->GetAlgorithms(params.stream)); TF_RET_CHECK(algorithm_idx_ >= 0 && algorithm_idx_ < algorithms.size()); algorithm_ = algorithms[algorithm_idx_]; } @@ -89,12 +94,27 @@ Status CublasLtMatmulThunk::ExecuteOnStream(const ExecuteParams& params) { se::OwningScratchAllocator<> scratch_allocator(allocs.device_ordinal(), allocs.memory_allocator()); - return plan_.ExecuteOnStream( + return plan->ExecuteOnStream( params.stream, allocs.GetDeviceAddress(a_buffer_), allocs.GetDeviceAddress(b_buffer_), allocs.GetDeviceAddress(c_buffer_), allocs.GetDeviceAddress(d_buffer_), bias, aux, a_scale, b_scale, c_scale, d_scale, d_amax, *algorithm_, scratch_allocator); } +StatusOr CublasLtMatmulThunk::GetMatmulPlan( + const stream_executor::Stream* stream) { + absl::MutexLock lock(&matmul_plans_cache_mutex_); + auto it = matmul_plans_cache_.find(stream); + if (it == matmul_plans_cache_.end()) { + TF_ASSIGN_OR_RETURN(cublas_lt::MatmulPlan plan, + cublas_lt::MatmulPlan::From(gemm_config_, epilogue_)); + it = matmul_plans_cache_ + .insert({stream, + std::make_unique(std::move(plan))}) + .first; + } + return it->second.get(); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/cublas_lt_matmul_thunk.h b/tensorflow/compiler/xla/service/gpu/cublas_lt_matmul_thunk.h index da156e322ee395..98295e5767d8b6 100644 --- a/tensorflow/compiler/xla/service/gpu/cublas_lt_matmul_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/cublas_lt_matmul_thunk.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUBLAS_LT_MATMUL_THUNK_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUBLAS_LT_MATMUL_THUNK_H_ +#include #include #include @@ -24,13 +25,15 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_blas_lt.h" +#include "tensorflow/tsl/platform/statusor.h" namespace xla { namespace gpu { class CublasLtMatmulThunk : public Thunk { public: - CublasLtMatmulThunk(ThunkInfo thunk_info, cublas_lt::MatmulPlan plan, + CublasLtMatmulThunk(ThunkInfo thunk_info, GemmConfig gemm_config, + se::cuda::BlasLt::Epilogue epilogue, int64_t algorithm_idx, BufferAllocation::Slice a_buffer, BufferAllocation::Slice b_buffer, BufferAllocation::Slice c_buffer, @@ -46,7 +49,16 @@ class CublasLtMatmulThunk : public Thunk { Status ExecuteOnStream(const ExecuteParams& params) override; private: - cublas_lt::MatmulPlan plan_; + StatusOr GetMatmulPlan( + const stream_executor::Stream* stream); + + absl::Mutex matmul_plans_cache_mutex_; + absl::flat_hash_map> + matmul_plans_cache_ ABSL_GUARDED_BY(matmul_plans_cache_mutex_); + + GemmConfig gemm_config_; + se::cuda::BlasLt::Epilogue epilogue_; int64_t algorithm_idx_; BufferAllocation::Slice a_buffer_; BufferAllocation::Slice b_buffer_; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index aa8c17bc532970..299dd9fdb10785 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -1140,11 +1140,12 @@ Status IrEmitterUnnested::EmitCublasLtMatmulThunk(mlir::Operation* op) { TF_ASSIGN_OR_RETURN(aux, GetAllocationSlice(matmul.getAux())); } - TF_ASSIGN_OR_RETURN(cublas_lt::MatmulPlan plan, - cublas_lt::MatmulPlan::For(matmul)); + TF_ASSIGN_OR_RETURN(GemmConfig gemm_config, GemmConfig::For(matmul)); + TF_ASSIGN_OR_RETURN(se::cuda::BlasLt::Epilogue epilogue, + cublas_lt::AsBlasLtEpilogue(matmul.getEpilogue())); auto thunk = std::make_unique( - GetThunkInfo(op), std::move(plan), matmul.getAlgorithm(), a, b, c, d, - bias, aux, a_scale, b_scale, c_scale, d_scale, d_amax); + GetThunkInfo(op), std::move(gemm_config), epilogue, matmul.getAlgorithm(), + a, b, c, d, bias, aux, a_scale, b_scale, c_scale, d_scale, d_amax); AddThunkToThunkSequence(std::move(thunk)); return OkStatus(); @@ -1180,12 +1181,12 @@ Status IrEmitterUnnested::EmitCublasLtMatmulThunkF8(mlir::Operation* op) { BufferAllocation::Slice aux; // Not used. - TF_ASSIGN_OR_RETURN(cublas_lt::MatmulPlan plan, - cublas_lt::MatmulPlan::For(matmul)); - + TF_ASSIGN_OR_RETURN(GemmConfig gemm_config, GemmConfig::For(matmul)); + TF_ASSIGN_OR_RETURN(se::cuda::BlasLt::Epilogue epilogue, + cublas_lt::AsBlasLtEpilogue(matmul.getEpilogue())); auto thunk = std::make_unique( - GetThunkInfo(op), std::move(plan), matmul.getAlgorithm(), a, b, c, d, - bias, aux, a_scale, b_scale, c_scale, d_scale, d_amax); + GetThunkInfo(op), std::move(gemm_config), epilogue, matmul.getAlgorithm(), + a, b, c, d, bias, aux, a_scale, b_scale, c_scale, d_scale, d_amax); AddThunkToThunkSequence(std::move(thunk)); return OkStatus(); diff --git a/tensorflow/compiler/xla/service/gpu/matmul_utils.h b/tensorflow/compiler/xla/service/gpu/matmul_utils.h index 3afe513fc74e12..b82f54f24fe5f6 100644 --- a/tensorflow/compiler/xla/service/gpu/matmul_utils.h +++ b/tensorflow/compiler/xla/service/gpu/matmul_utils.h @@ -111,6 +111,42 @@ struct GemmConfig { double alpha_imag, double beta, std::optional algorithm, int64_t compute_precision); + template ::value || + std::is_same::value>> + static StatusOr For(CublasLtMatmulMaybeF8Op op) { + mlir::mhlo::DotDimensionNumbersAttr dot_dims = op.getDotDimensionNumbers(); + + int64_t compute_precision = 0; // Default + if (op.getPrecisionConfig().has_value()) { + auto precision_config = op.getPrecisionConfig(); + for (auto attr : precision_config.value()) { + int64_t value = static_cast( + attr.template cast().getValue()); + if (value > compute_precision) { + compute_precision = value; + } + } + } + + Shape bias_shape; + if (op.getBias() != nullptr) { + bias_shape = GetShape(op.getBias()); + } + return GemmConfig::For( + GetShape(op.getA()), dot_dims.getLhsBatchingDimensions(), + dot_dims.getLhsContractingDimensions(), GetShape(op.getB()), + dot_dims.getRhsBatchingDimensions(), + dot_dims.getRhsContractingDimensions(), GetShape(op.getC()), + op.getBias() == nullptr ? nullptr : &bias_shape, GetShape(op.getD()), + op.getAlphaReal().convertToDouble(), + op.getAlphaImag().convertToDouble(), op.getBeta().convertToDouble(), + op.getAlgorithm(), compute_precision); + } + MatrixLayout lhs_layout; MatrixLayout rhs_layout; MatrixLayout c_layout; @@ -162,48 +198,6 @@ StatusOr AsBlasLtEpilogue( class MatmulPlan { public: - template ::value || - std::is_same::value>> - static StatusOr For(CublasLtMatmulMaybeF8Op op) { - mlir::mhlo::DotDimensionNumbersAttr dot_dims = op.getDotDimensionNumbers(); - - int64_t compute_precision = 0; // Default - if (op.getPrecisionConfig().has_value()) { - auto precision_config = op.getPrecisionConfig(); - for (auto attr : precision_config.value()) { - int64_t value = static_cast( - attr.template cast().getValue()); - if (value > compute_precision) { - compute_precision = value; - } - } - } - - Shape bias_shape; - if (op.getBias() != nullptr) { - bias_shape = GetShape(op.getBias()); - } - TF_ASSIGN_OR_RETURN( - GemmConfig config, - GemmConfig::For( - GetShape(op.getA()), dot_dims.getLhsBatchingDimensions(), - dot_dims.getLhsContractingDimensions(), GetShape(op.getB()), - dot_dims.getRhsBatchingDimensions(), - dot_dims.getRhsContractingDimensions(), GetShape(op.getC()), - op.getBias() == nullptr ? nullptr : &bias_shape, - GetShape(op.getD()), op.getAlphaReal().convertToDouble(), - op.getAlphaImag().convertToDouble(), op.getBeta().convertToDouble(), - op.getAlgorithm(), compute_precision)); - - TF_ASSIGN_OR_RETURN(se::cuda::BlasLt::Epilogue epilogue, - AsBlasLtEpilogue(op.getEpilogue())); - return From(config, epilogue); - } - static StatusOr From(const GemmConfig& config, se::cuda::BlasLt::Epilogue epilogue); From 4bc6cdedea7fc1350d0dafc431b851e38362b6b2 Mon Sep 17 00:00:00 2001 From: Anlun Xu Date: Fri, 14 Jul 2023 12:02:10 -0700 Subject: [PATCH 329/376] [xla:gpu] Compute the transitive reduction of the dependency graph PiperOrigin-RevId: 548186731 --- .../gpu/transforms/dataflow_analysis.cc | 46 +++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/dataflow_analysis.cc b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/dataflow_analysis.cc index 02873b88258588..f4b23864a1fe08 100644 --- a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/dataflow_analysis.cc +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/dataflow_analysis.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/mlir/backends/gpu/transforms/dataflow_analysis.h" #include +#include #include #include @@ -167,6 +168,50 @@ bool HasDependency(llvm::ArrayRef buffer_uses_a, return false; } +bool Reachable(const DataflowAnalysis::DataflowGraph& graph, size_t from_index, + size_t to_index) { + std::queue bfs_queue; + bfs_queue.push(from_index); + + while (!bfs_queue.empty()) { + size_t index = bfs_queue.front(); + bfs_queue.pop(); + if (index == to_index) return true; + + const DataflowAnalysis::Node& node = graph[index]; + for (size_t child_index : node.children) { + bfs_queue.push(child_index); + } + } + + return false; +} + +// Remove edges that are redundant for determining the execution order of +// kernels. We use the following algorithm to compute the transitive reduction: +// +// for edge (u,v) do +// if there is a path from u to v in that does not use edge (u,v) then +// remove edge (u,v) +// +// TODO(b/288594057): Use a more efficient algorithm. +void TransitiveReduction(DataflowAnalysis::DataflowGraph& graph) { + for (DataflowAnalysis::Node& node : graph) { + auto is_reducible = [&](size_t to_index) -> bool { + for (size_t child_index : node.children) { + if (child_index != to_index) { + if (Reachable(graph, child_index, to_index)) return true; + } + } + return false; + }; + + node.children.erase(std::remove_if(node.children.begin(), + node.children.end(), is_reducible), + node.children.end()); + } +} + } // namespace DataflowAnalysis::DataflowGraph DataflowAnalysis::GetDataflowGraph( @@ -194,6 +239,7 @@ DataflowAnalysis::DataflowGraph DataflowAnalysis::GetDataflowGraph( } } + TransitiveReduction(graph); return graph; } From 335bc7eb1f4db6c228a94a253e4208ab0a0fc5de Mon Sep 17 00:00:00 2001 From: Son Tuan Vu Date: Fri, 14 Jul 2023 12:22:53 -0700 Subject: [PATCH 330/376] [XLA:GPU][NFC] Use more concise TF_CHECK_OK macro PiperOrigin-RevId: 548192429 --- tensorflow/compiler/xla/shape_util.cc | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 853e1c7451be3d..d85a38bec85935 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -1731,9 +1731,8 @@ ShapeUtil::DecomposeBitcastToTrt(const Shape& input_shape, absl::Span count, absl::Span incr, const ForEachParallelVisitorFunction& visitor_function) { // The parallel version of ForEachIndexInternal can never fail. - CHECK( - ForEachIndexParallelWithStatus(shape, base, count, incr, visitor_function) - .ok()); + TF_CHECK_OK(ForEachIndexParallelWithStatus(shape, base, count, incr, + visitor_function)); } /* static */ Status ShapeUtil::ForEachIndexParallelWithStatus( @@ -1748,7 +1747,7 @@ ShapeUtil::DecomposeBitcastToTrt(const Shape& input_shape, /* static */ void ShapeUtil::ForEachIndexParallel( const Shape& shape, const ForEachParallelVisitorFunction& visitor_function) { - CHECK(ForEachIndexParallelWithStatus(shape, visitor_function).ok()); + TF_CHECK_OK(ForEachIndexParallelWithStatus(shape, visitor_function)); } /* static */ Status ShapeUtil::ForEachIndexParallelWithStatus( From 28372efec3cc07a8e1db81f6b55b87ca4d87219b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 14 Jul 2023 12:37:43 -0700 Subject: [PATCH 331/376] Check `indices` are in range in `tf.TensorArray.gather` method Fix #60148, Segmentation fault. PiperOrigin-RevId: 548196238 --- tensorflow/core/kernels/list_kernels.h | 6 +++++ .../data_structures/list_ops_test.py | 22 +++++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/tensorflow/core/kernels/list_kernels.h b/tensorflow/core/kernels/list_kernels.h index 369a81bb1d5105..09814bcbb6026a 100644 --- a/tensorflow/core/kernels/list_kernels.h +++ b/tensorflow/core/kernels/list_kernels.h @@ -691,6 +691,12 @@ class TensorListGather : public OpKernel { if (!tensor_list->element_shape.IsFullyDefined()) { for (int index = 0; index < indices.NumElements(); ++index) { const int i = indices.flat()(index); + + OP_REQUIRES(c, 0 <= i && i < tensor_list->tensors().size(), + absl::InvalidArgumentError(absl::StrCat( + "Trying to gather element ", i, " in a list with ", + tensor_list->tensors().size(), " elements."))); + const Tensor& t = tensor_list->tensors()[i]; if (t.dtype() != DT_INVALID) { PartialTensorShape tmp = partial_element_shape; diff --git a/tensorflow/python/kernel_tests/data_structures/list_ops_test.py b/tensorflow/python/kernel_tests/data_structures/list_ops_test.py index 3d64a891251f4f..1d3eb1ab96c884 100644 --- a/tensorflow/python/kernel_tests/data_structures/list_ops_test.py +++ b/tensorflow/python/kernel_tests/data_structures/list_ops_test.py @@ -479,6 +479,28 @@ def testGatherUsingSpecifiedElementShape(self): self.assertEqual(t.shape.as_list(), [3]) self.assertAllEqual(self.evaluate(t), np.zeros((3,))) + def testGatherWithInvalidIndicesFails(self): + l = list_ops.tensor_list_reserve( + element_dtype=dtypes.float32, element_shape=None, num_elements=3 + ) + + # Should raise an error when the input index is negative. + with self.assertRaisesRegex( + errors.InvalidArgumentError, + "Trying to gather element -1 in a list with 3 elements.", + ): + t = list_ops.tensor_list_gather(l, [-1], element_dtype=dtypes.float32) + self.evaluate(t) + + # Should raise an error when the input index is larger than the number of + # elements in the list. + with self.assertRaisesRegex( + errors.InvalidArgumentError, + "Trying to gather element 3 in a list with 3 elements.", + ): + t = list_ops.tensor_list_gather(l, [3], element_dtype=dtypes.float32) + self.evaluate(t) + def testScatterOutputListSize(self): c0 = constant_op.constant([1.0, 2.0]) l = list_ops.tensor_list_scatter(c0, [1, 3], []) From 3803e2371c037d761b852e4d661db78da56cf396 Mon Sep 17 00:00:00 2001 From: Victor Stone Date: Fri, 14 Jul 2023 12:55:07 -0700 Subject: [PATCH 332/376] Remove the distinction between rematerialization which happens before and after multi-output fusion since the distinction is not used. PiperOrigin-RevId: 548200573 --- .../xla/service/gpu/compile_module_to_llvm_ir.cc | 1 - .../compiler/xla/service/hlo_rematerialization.h | 15 +-------------- .../xla/service/hlo_rematerialization_test.cc | 2 -- 3 files changed, 1 insertion(+), 17 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/compile_module_to_llvm_ir.cc b/tensorflow/compiler/xla/service/gpu/compile_module_to_llvm_ir.cc index 1a23214065617d..e8771d4c4b2bb8 100644 --- a/tensorflow/compiler/xla/service/gpu/compile_module_to_llvm_ir.cc +++ b/tensorflow/compiler/xla/service/gpu/compile_module_to_llvm_ir.cc @@ -327,7 +327,6 @@ Status CompileModuleToLlvmIrImpl( }, // Assume 75% of the total device memory is available for XLA. /*memory_limit_bytes=*/gpu_device_info.device_memory_size * 0.75, - HloRematerialization::RematerializationPass::kPostFusion, /*block_size_limit=*/1, /*block_rematerialization_factor=*/1, /*compact_shape_function=*/nullptr, HloRematerialization::RematerializationMode::kRecomputeAndCompress); diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h index 611740c56e83b8..618027423de37e 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.h +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h @@ -59,19 +59,11 @@ class HloRematerialization : public HloModulePass { kRecomputeAndCompress // Consider both kRecompute and kRemat. }; - // Enum to specify whether this rematerialization pass occurs before or after - // multi-output fusion. - enum class RematerializationPass { - kPreFusion, // Rematerialization pass before multi-output fusion. - kPostFusion // Rematerialization pass after multi-output fusion. - }; - static Shape DefaultCompactShapeFunction(const Shape& shape) { return shape; } struct Options { explicit Options(const ShapeSizeFunction& size_function, - int64_t memory_limit_bytes, - RematerializationPass pass_location, int block_size_limit, + int64_t memory_limit_bytes, int block_size_limit, int block_rematerialization_factor, CompactShapeFunction compact_shape_function = nullptr, RematerializationMode mode = @@ -79,7 +71,6 @@ class HloRematerialization : public HloModulePass { int64_t min_remat_size = 0) : size_function(size_function), memory_limit_bytes(memory_limit_bytes), - pass_location(pass_location), block_size_limit(block_size_limit), block_rematerialization_factor(block_rematerialization_factor), compact_shape_function(compact_shape_function == nullptr @@ -96,10 +87,6 @@ class HloRematerialization : public HloModulePass { // from this. int64_t memory_limit_bytes; - // Specifies whether this rematerialization pass occurs before or after - // multi-output fusion. - RematerializationPass pass_location; - // Maximum number of consecutive instructions to consider for // rematerialization. int block_size_limit; diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc index 8c61c6cdd4a55f..e487d840259e45 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc @@ -52,7 +52,6 @@ class HloRematerializationTest : public RematerializationTestBase { HloRematerialization::Options options( ByteSizeOf, memory_limit_bytes, - HloRematerialization::RematerializationPass::kPreFusion, /*block_size_limit=*/1, /*block_rematerialization_factor=*/1, nullptr, HloRematerialization::RematerializationMode::kRecomputeAndCompress, min_remat_size); @@ -609,7 +608,6 @@ class CompressingRematerializationTest : public RematerializationTestBase { TF_EXPECT_OK(verifier().Run(module).status()); HloRematerialization::Options options( ShapeSizePadMinorTo64, memory_limit_bytes, - HloRematerialization::RematerializationPass::kPreFusion, /*block_size_limit=*/1, /*block_rematerialization_factor=*/1, ChooseCompactLayoutForShape, HloRematerialization::RematerializationMode::kCompressOnly, From d9de2bc16c99e0829170ff326a3f910635e9557c Mon Sep 17 00:00:00 2001 From: Armando Ugalde Velasco Date: Fri, 14 Jul 2023 13:00:44 -0700 Subject: [PATCH 333/376] Ensure there is a Model instance in standalone::Iterator's context Currently, the model instance in the IteratorContext of a standalone::Iterator is null. To solve this, we need to create one manually when initializing the IteratorContext. Then, since RootDataset creates its own Model instance, we will replace it with the one passed from our IteratorContext. This is to make sure this instance is the one used down the chain and that we can also access it in our IteratorContext. By passing this model instance, we also ensure that when executing IteratorBase::InitializeBase we are creating a RootDataset node without a parent in the model. PiperOrigin-RevId: 548201925 --- tensorflow/core/data/root_dataset.cc | 21 +++++++++++---------- tensorflow/core/data/standalone.cc | 1 + tensorflow/core/framework/dataset.cc | 3 ++- 3 files changed, 14 insertions(+), 11 deletions(-) diff --git a/tensorflow/core/data/root_dataset.cc b/tensorflow/core/data/root_dataset.cc index 9b03486e80552d..f010348410effb 100644 --- a/tensorflow/core/data/root_dataset.cc +++ b/tensorflow/core/data/root_dataset.cc @@ -156,16 +156,6 @@ class RootDataset::Iterator : public DatasetIterator { public: explicit Iterator(const Params& params) : DatasetIterator(params) { - if (dataset()->params_.autotune) { - model_ = std::make_shared(); - auto experiments = GetExperiments(); - if (experiments.contains("stage_based_autotune_v2")) { - model_->AddExperiment("stage_based_autotune_v2"); - } - if (experiments.contains("autotune_buffer_optimization")) { - model_->AddExperiment("autotune_buffer_optimization"); - } - } if (dataset()->params_.max_intra_op_parallelism >= 0) { max_intra_op_parallelism_ = value_or_default(dataset()->params_.max_intra_op_parallelism, 0, @@ -187,6 +177,17 @@ class RootDataset::Iterator : public DatasetIterator { bool SymbolicCheckpointCompatible() const override { return true; } Status Initialize(IteratorContext* ctx) override { + if (dataset()->params_.autotune) { + model_ = ctx->model() != nullptr ? ctx->model() + : std::make_shared(); + absl::flat_hash_set experiments = GetExperiments(); + if (experiments.contains("stage_based_autotune_v2")) { + model_->AddExperiment("stage_based_autotune_v2"); + } + if (experiments.contains("autotune_buffer_optimization")) { + model_->AddExperiment("autotune_buffer_optimization"); + } + } IteratorContext iter_ctx(CreateParams(ctx)); TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(&iter_ctx, this, prefix(), &input_impl_)); diff --git a/tensorflow/core/data/standalone.cc b/tensorflow/core/data/standalone.cc index a1c702b0c77b66..ab3eb810af6b98 100644 --- a/tensorflow/core/data/standalone.cc +++ b/tensorflow/core/data/standalone.cc @@ -195,6 +195,7 @@ Status Dataset::MakeIterator( std::back_inserter(params.split_providers)); params.thread_factory = unbounded_thread_pool_.get_thread_factory(); params.thread_pool = &unbounded_thread_pool_; + params.model = std::make_shared(); ctx = std::make_unique(std::move(params)); SerializationContext::Params serialization_params(&op_ctx); auto serialization_ctx = diff --git a/tensorflow/core/framework/dataset.cc b/tensorflow/core/framework/dataset.cc index a197e73ed8ed47..02b6953d4ea1ab 100644 --- a/tensorflow/core/framework/dataset.cc +++ b/tensorflow/core/framework/dataset.cc @@ -545,7 +545,8 @@ Status IteratorBase::InitializeBase(IteratorContext* ctx, auto factory = [ctx, this](model::Node::Args args) { return CreateNode(ctx, std::move(args)); }; - model->AddNode(std::move(factory), prefix(), parent->model_node(), &node_); + model->AddNode(std::move(factory), prefix(), + parent == nullptr ? nullptr : parent->model_node(), &node_); cleanup_fns_.push_back([this, model]() { model->RemoveNode(node_); }); } return OkStatus(); From 24ddf0aa95b1bba32ac7fb77fa22090c55312e0f Mon Sep 17 00:00:00 2001 From: Reed Wanderman-Milne Date: Fri, 14 Jul 2023 14:20:59 -0700 Subject: [PATCH 334/376] Disable cluster_coordinator_test with XLA. It fails with the error "Duplicate variable passed to XLA cluster". PiperOrigin-RevId: 548221793 --- tensorflow/python/distribute/coordinator/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/python/distribute/coordinator/BUILD b/tensorflow/python/distribute/coordinator/BUILD index b3c9f3f836ae80..a2dc307a546684 100644 --- a/tensorflow/python/distribute/coordinator/BUILD +++ b/tensorflow/python/distribute/coordinator/BUILD @@ -96,6 +96,7 @@ distribute_py_strict_test( "notpu", "notsan", # TODO(b/171040359): Flaky timeout, even if maximum shards ], + xla_enable_strict_auto_jit = False, # TODO(b/291174864) xla_tags = [ "no_cuda_asan", # Race condition on async test ], From c509b9a5c4ab83bd5dd56bf39a351f7e9a30c9df Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 14 Jul 2023 14:40:11 -0700 Subject: [PATCH 335/376] Return Floating-Point Tensor for Dot-Like Hybrid Ops In ConvertMHLOQuantToInt PiperOrigin-RevId: 548226581 --- .../bridge/convert_mhlo_quant_to_int.cc | 49 ++++++------------- .../bridge/convert-mhlo-quant-to-int.mlir | 5 +- 2 files changed, 16 insertions(+), 38 deletions(-) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_mhlo_quant_to_int.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_mhlo_quant_to_int.cc index 43a5d568132894..b54fc3bfc2ff22 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_mhlo_quant_to_int.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_mhlo_quant_to_int.cc @@ -447,21 +447,15 @@ template LogicalResult matchAndRewriteDotLikeHybridOp( OpType &op, OpAdaptorType &adaptor, ConversionPatternRewriter &rewriter, const quant::UniformQuantizedType &rhs_element_type) { - auto res_float32_tensor_type_or = GetSameShapeTensorType( - op, op.getResult().getType().template cast(), - rewriter.getF32Type(), rewriter); - if (failed(res_float32_tensor_type_or)) { - return failure(); - } - - Value lhs_float32_tensor = adaptor.getLhs(); - Value rhs = adaptor.getRhs(); - // For dot like hybrid ops, lhs is float type, rhs is uniform // quantized type and result is float type. // For weight-only quantization: // result = hybridOp(lhs, dequant(rhs)) - // + Value lhs_float32_tensor = adaptor.getLhs(); + Value rhs = adaptor.getRhs(); + auto res_float32_tensor_type = + op.getResult().getType().template cast(); + // Get scales and zero points for rhs. Value rhs_zero_point = rewriter.create( op->getLoc(), @@ -472,36 +466,24 @@ LogicalResult matchAndRewriteDotLikeHybridOp( // Dequantize rhs_float32_tensor. Value rhs_float32_tensor = rewriter.create( - op->getLoc(), *res_float32_tensor_type_or, rhs); + op->getLoc(), res_float32_tensor_type, rhs); rhs_float32_tensor = rewriter.create( - op->getLoc(), *res_float32_tensor_type_or, rhs_float32_tensor, - rhs_zero_point, nullptr); + op->getLoc(), res_float32_tensor_type, rhs_float32_tensor, rhs_zero_point, + nullptr); rhs_float32_tensor = rewriter.create( - op->getLoc(), *res_float32_tensor_type_or, rhs_float32_tensor, + op->getLoc(), res_float32_tensor_type, rhs_float32_tensor, rhs_scale_constant, nullptr); // Execute conversion target op. SmallVector operands{lhs_float32_tensor, rhs_float32_tensor}; Value res_float32 = rewriter.create( - op->getLoc(), *res_float32_tensor_type_or, operands, op->getAttrs()); + op->getLoc(), res_float32_tensor_type, operands, op->getAttrs()); Value half = rewriter.create( op->getLoc(), rewriter.getF32FloatAttr(0.5f)); res_float32 = rewriter.create( - op->getLoc(), *res_float32_tensor_type_or, res_float32, half, nullptr); - res_float32 = rewriter.create(op->getLoc(), res_float32); - - // Cast res_float_tensor_type to res_int_tensor_type. - auto res_int32_tensor_type_or = GetSameShapeTensorType( - op, op.getResult().getType().template cast(), - rewriter.getI32Type(), rewriter); - if (failed(res_int32_tensor_type_or)) { - return failure(); - } - - rewriter.replaceOpWithNewOp(op, *res_int32_tensor_type_or, - res_float32); - + op->getLoc(), res_float32_tensor_type, res_float32, half, nullptr); + rewriter.replaceOpWithNewOp(op, res_float32); return success(); } @@ -660,8 +642,7 @@ class ConvertUniformQuantizedDotOp : public OpConversionPattern { LogicalResult matchAndRewrite( mhlo::DotOp op, mhlo::DotOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - return matchAndRewriteDotLikeOp( - op, adaptor, rewriter); + return matchAndRewriteDotLikeOp(op, adaptor, rewriter); } }; @@ -673,9 +654,7 @@ class ConvertUniformQuantizedConvolutionOp LogicalResult matchAndRewrite( mhlo::ConvolutionOp op, mhlo::ConvolutionOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - return matchAndRewriteDotLikeOp(op, adaptor, - rewriter); + return matchAndRewriteDotLikeOp(op, adaptor, rewriter); } }; diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-mhlo-quant-to-int.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-mhlo-quant-to-int.mlir index 4766956e56b00f..d01c18bce6c607 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-mhlo-quant-to-int.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-mhlo-quant-to-int.mlir @@ -312,17 +312,16 @@ func.func @uniform_quantized_convolution(%arg0: tensor, %arg1: tens // ----- // CHECK-LABEL: func @uniform_quantize_dot_hybrid -func.func @uniform_quantize_dot_hybrid(%arg0: tensor, %arg1: tensor) { +func.func @uniform_quantize_dot_hybrid(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK: %[[VAL1:.*]] = mhlo.convert %[[VAL0:.*]] : (tensor) -> tensor // CHECK: %[[VAL3:.*]] = chlo.broadcast_subtract %[[VAL1]], %[[VAL2:.*]] : (tensor, tensor) -> tensor // CHECK: %[[VAL5:.*]] = chlo.broadcast_multiply %[[VAL3]], %[[VAL4:.*]] : (tensor, tensor) -> tensor // CHECK: %[[VAL7:.*]] = "mhlo.dot"(%[[VAL6:.*]], %[[VAL5]]) : (tensor, tensor) -> tensor // CHECK: %[[VAL9:.*]] = chlo.broadcast_add %[[VAL7]], %[[VAL8:.*]] : (tensor, tensor) -> tensor // CHECK: %[[VAL10:.*]] = mhlo.floor %[[VAL9]] : tensor - // CHECK: %[[VAL11:.*]] = mhlo.convert %[[VAL10]] : (tensor) -> tensor %0 = mhlo.uniform_quantize %arg1 : (tensor) -> tensor> %1 = "mhlo.dot" (%arg0, %0): (tensor, tensor>) -> tensor - return + return %1: tensor } // ----- From 387f5f44350e080adeed06056bef3d7cd635a8d2 Mon Sep 17 00:00:00 2001 From: Adam Cogdell Date: Fri, 14 Jul 2023 14:41:21 -0700 Subject: [PATCH 336/376] Add fingerprints generated by different compilation modes to chunked SM tests. PiperOrigin-RevId: 548226906 --- tensorflow/cc/saved_model/BUILD | 1 + .../saved_model/pywrap_saved_model_fingerprinting_test.py | 8 ++++++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD index 27177124e80b0d..c33a59027c5f31 100644 --- a/tensorflow/cc/saved_model/BUILD +++ b/tensorflow/cc/saved_model/BUILD @@ -575,6 +575,7 @@ cc_library( # ], # deps = [ # ":fingerprinting", +# "@com_google_absl//absl/container:flat_hash_set", # "//tensorflow/core:protos_all_cc", # "//tensorflow/core:test", # "//tensorflow/core/platform:path", diff --git a/tensorflow/python/saved_model/pywrap_saved_model_fingerprinting_test.py b/tensorflow/python/saved_model/pywrap_saved_model_fingerprinting_test.py index ea34b16b77a36c..94425f2ab92174 100644 --- a/tensorflow/python/saved_model/pywrap_saved_model_fingerprinting_test.py +++ b/tensorflow/python/saved_model/pywrap_saved_model_fingerprinting_test.py @@ -97,9 +97,13 @@ def test_read_chunked_saved_model_fingerprint(self): fingerprint = fingerprint_pb2.FingerprintDef().FromString( pywrap_fingerprinting.CreateFingerprintDef(export_dir)) self.assertGreater(fingerprint.saved_model_checksum, 0) - self.assertEqual(fingerprint.graph_def_program_hash, 906548630859202535) + # We test for multiple fingerprints due to non-determinism when building + # with different compilation_mode flag options. + self.assertIn(fingerprint.graph_def_program_hash, + [906548630859202535, 9562420523583756263]) self.assertEqual(fingerprint.signature_def_hash, 1043582354059066488) - self.assertEqual(fingerprint.saved_object_graph_hash, 11894619660760763927) + self.assertIn(fingerprint.saved_object_graph_hash, + [11894619660760763927, 2766043449526180728]) self.assertEqual(fingerprint.checkpoint_hash, 0) From 40f01670c823de04fe8302fa9bc6987a376bbf4f Mon Sep 17 00:00:00 2001 From: Arian Arfaian Date: Fri, 14 Jul 2023 14:47:50 -0700 Subject: [PATCH 337/376] Deprecate `experimental_from_jax` in favor of `jax2tf` + `from_saved_model`. The recommended way to convert JAX models includes using PAX/SAX/T5X/Orbax Export. For a JAX-native model, `jax2tf` can be used directly. Once a TF SavedModel is procured using the aforementioned tools, then `lite.TFLiteConverter.from_saved_model` can be used without the need to specify any additional flags. PiperOrigin-RevId: 548228454 --- tensorflow/lite/python/lite.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py index 8c31c7f973fcff..014dfe73fe5761 100644 --- a/tensorflow/lite/python/lite.py +++ b/tensorflow/lite/python/lite.py @@ -2127,6 +2127,11 @@ def from_keras_model(cls, model): return TFLiteKerasModelConverterV2(model) @classmethod + @_deprecation.deprecated( + None, + "Use `jax2tf.convert` and (`lite.TFLiteConverter.from_saved_model`" + " or `lite.TFLiteConverter.from_concrete_functions`) instead.", + ) def experimental_from_jax(cls, serving_funcs, inputs): # Experimental API, subject to changes. # TODO(b/197690428): Currently only support single function. From 5006a30cce50b6f613317a2ba5db16edd34cf02e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 14 Jul 2023 15:59:55 -0700 Subject: [PATCH 338/376] Improve ragged_cross_op input ragged splits check and fix flaky ragged_cross_op_tests. PiperOrigin-RevId: 548243985 --- tensorflow/core/kernels/ragged_cross_op.cc | 30 ++++++++++-- .../python/ops/ragged/ragged_cross_op_test.py | 48 ++++++++++++------- 2 files changed, 56 insertions(+), 22 deletions(-) diff --git a/tensorflow/core/kernels/ragged_cross_op.cc b/tensorflow/core/kernels/ragged_cross_op.cc index 31af55a893a562..71deb58c3c12d0 100644 --- a/tensorflow/core/kernels/ragged_cross_op.cc +++ b/tensorflow/core/kernels/ragged_cross_op.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include "absl/status/status.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" @@ -386,13 +387,32 @@ class RaggedCrossOp : public OpKernel { // Validate tensor shapes. for (int i = 0; i < num_ragged; ++i) { - if (!TensorShapeUtils::IsVector(ragged_values_list[i].shape())) { - return errors::InvalidArgument( + if (!TensorShapeUtils::IsVector(ragged_values_list[i].shape()) || + !TensorShapeUtils::IsVector(ragged_splits_list[i].shape())) { + return absl::InvalidArgumentError( "tf.ragged.cross only supports inputs with rank=2."); } - if (!TensorShapeUtils::IsVector(ragged_splits_list[i].shape()) || - (ragged_splits_list[i].NumElements() == 0)) { - return errors::InvalidArgument("Invalid RaggedTensor"); + if (ragged_splits_list[i].NumElements() == 0) { + return absl::InvalidArgumentError( + "Invalid RaggedTensor: Ragged splits must be non-empty."); + } + auto flat_row_splits = ragged_splits_list[i].flat(); + if (flat_row_splits(0) != 0) { + return absl::InvalidArgumentError( + "Invalid RaggedTensor: Ragged splits must start from 0."); + } + int64_t num_values = ragged_values_list[i].NumElements(); + if (flat_row_splits(flat_row_splits.size() - 1) != num_values) { + return absl::InvalidArgumentError( + "Invalid RaggedTensor: " + "Ragged splits must end with the number of values."); + } + for (int i = 1; i < flat_row_splits.size(); ++i) { + if (flat_row_splits(i - 1) > flat_row_splits(i)) { + return absl::InvalidArgumentError( + "Invalid RaggedTensor: " + "Ragged splits must be sorted in ascending order."); + } } } for (int i = 0; i < num_sparse; ++i) { diff --git a/tensorflow/python/ops/ragged/ragged_cross_op_test.py b/tensorflow/python/ops/ragged/ragged_cross_op_test.py index 0408051f1f5c89..c098c13644f342 100644 --- a/tensorflow/python/ops/ragged/ragged_cross_op_test.py +++ b/tensorflow/python/ops/ragged/ragged_cross_op_test.py @@ -27,7 +27,6 @@ from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_ragged_array_ops -from tensorflow.python.ops import random_ops from tensorflow.python.ops import sparse_ops from tensorflow.python.ops.ragged import ragged_array_ops from tensorflow.python.ops.ragged import ragged_factory_ops @@ -350,29 +349,32 @@ def testRaggedCrossLargeBatch(self): dict( testcase_name='BadDType', inputs=[ragged_const([[1.1], [2.2, 3.3]])], - message=r'Unexpected dtype for inputs\[0\]'), + message=r'Unexpected dtype for inputs\[0\]', + ), dict( testcase_name='StaticBatchSizeMismatch1', - inputs=[ragged_const([[1]]), - ragged_const([[2], [3]])], + inputs=[ragged_const([[1]]), ragged_const([[2], [3]])], exception=(ValueError, errors.InvalidArgumentError), - message='inputs must all have the same batch dimension size'), + message='inputs must all have the same batch dimension size', + ), dict( testcase_name='StaticBatchSizeMismatch2', - inputs=[ragged_const([[1]]), - dense_const([[2], [3]])], + inputs=[ragged_const([[1]]), dense_const([[2], [3]])], exception=(ValueError, errors.InvalidArgumentError), - message='inputs must all have the same batch dimension size'), + message='inputs must all have the same batch dimension size', + ), dict( testcase_name='3DDenseTensor', inputs=[dense_const([[[1]]])], exception=(ValueError, errors.InvalidArgumentError), - message='tf.ragged.cross only supports inputs with rank=2'), + message='tf.ragged.cross only supports inputs with rank=2', + ), dict( testcase_name='0DDenseTensor', inputs=[dense_const(1)], exception=(ValueError, errors.InvalidArgumentError), - message='tf.ragged.cross only supports inputs with rank=2'), + message='tf.ragged.cross only supports inputs with rank=2', + ), ]) def testStaticError(self, inputs, exception=ValueError, message=None): with self.assertRaisesRegex(exception, message): @@ -382,25 +384,29 @@ def testStaticError(self, inputs, exception=ValueError, message=None): dict( testcase_name='3DRaggedTensor', inputs=[ragged_const([[[1]]], ragged_rank=1)], - message='tf.ragged.cross only supports inputs with rank=2'), + message='tf.ragged.cross only supports inputs with rank=2', + ), dict( testcase_name='0DDenseTensor', inputs=[dense_const(1)], signature=[[tensor_spec.TensorSpec(None, dtypes.int32)]], exception=(ValueError, errors.InvalidArgumentError), - message='tf.ragged.cross only supports inputs with rank=2'), + message='tf.ragged.cross only supports inputs with rank=2', + ), dict( testcase_name='1DDenseTensor', inputs=[dense_const([1])], signature=[[tensor_spec.TensorSpec(None, dtypes.int32)]], exception=(ValueError, errors.InvalidArgumentError), - message='tf.ragged.cross only supports inputs with rank=2'), + message='tf.ragged.cross only supports inputs with rank=2', + ), dict( testcase_name='3DDenseTensor', inputs=[dense_const([[[1]]])], signature=[[tensor_spec.TensorSpec(None, dtypes.int32)]], exception=(ValueError, errors.InvalidArgumentError), - message='tf.ragged.cross only supports inputs with rank=2'), + message='tf.ragged.cross only supports inputs with rank=2', + ), ]) def testRuntimeError(self, inputs, @@ -458,7 +464,15 @@ def testRaggedValuesAndSplitsMustMatch(self): out_values_type=dtypes.string, out_row_splits_type=dtypes.int64)) - def testRaggedCrossInvalidValue(self): + @parameterized.named_parameters([ + dict(testcase_name='EmptySplits', ragged_splits=[]), + dict( + testcase_name='NegativeSplits', ragged_splits=[-216, -114, -58, -54] + ), + dict(testcase_name='TooLargeValueSplits', ragged_splits=[0, 1, 2, 10]), + dict(testcase_name='UnsortedSplits', ragged_splits=[0, 2, 2, 1]), + ]) + def testRaggedCrossInvalidRaggedSplits(self, ragged_splits): # Test case in GitHub isseu 59114. with self.assertRaisesRegex( (ValueError, errors.InvalidArgumentError), 'Invalid RaggedTensor' @@ -468,8 +482,8 @@ def testRaggedCrossInvalidValue(self): ragged_values = [ ragged_values_0, ] - ragged_row_splits_0_tensor = random_ops.random_uniform( - [4], minval=-256, maxval=257, dtype=dtypes.int64 + ragged_row_splits_0_tensor = ragged_const( + ragged_splits, dtype=dtypes.int64 ) ragged_row_splits_0 = array_ops.identity(ragged_row_splits_0_tensor) ragged_row_splits = [ From 4af196bdf7602a3987fbac7e114e1cb01ed25484 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 14 Jul 2023 16:09:10 -0700 Subject: [PATCH 339/376] Add metrics in LegalizeTfTypesPass to help debugging PiperOrigin-RevId: 548246041 --- .../mlir/tf2xla/transforms/legalize_tf_types.cc | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_types.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_types.cc index d8fc61604cebcb..bab46615f90003 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_types.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_types.cc @@ -20,6 +20,10 @@ limitations under the License. // TODO(b/180234029): The rewrite here should be part of the LegalizeTF pass // rather than its own pass. +#include +#include +#include + #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/TypeSwitch.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project @@ -31,6 +35,7 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/core/lib/monitoring/counter.h" #define DEBUG_TYPE "xla-legalize-tf-types" @@ -38,6 +43,12 @@ namespace mlir { namespace mhlo { namespace { +// TODO: b/290366702 - Temporarily added metrics for debugging. +auto *mlir_tf_quant_op_count = tensorflow::monitoring::Counter<1>::New( + "/tensorflow/core/tf2xla/tf_quant_op_count" /*metric_name*/, + "Counts the number of ops that has qint types" /*metric description*/, + "op_name" /*metric label*/); + bool IsIllegalElementType(Type type) { return type .isagetResults()); + // TODO: b/290366702 - Temporarily added metrics for debugging. + if (llvm::any_of(op->getResultTypes(), IsIllegalType)) { + mlir_tf_quant_op_count->GetCell(std::string(op->getName().getStringRef())) + ->IncrementBy(1); + } return success(); } }; From c060992f381f0b2728118532f9b680d826682b51 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 14 Jul 2023 16:38:26 -0700 Subject: [PATCH 340/376] Added attributes within hlo and xla data proto to keep track of desired statistics. PiperOrigin-RevId: 548251788 --- .../compiler/xla/hlo/ir/hlo_instruction.cc | 38 +++++++++ .../compiler/xla/hlo/ir/hlo_instruction.h | 26 ++++++ tensorflow/compiler/xla/service/hlo.proto | 6 +- .../compiler/xla/service/hlo_graph_dumper.cc | 74 +++++++++++++++-- .../xla/service/hlo_graph_dumper_test.cc | 19 +++++ .../xla/service/hlo_instruction_test.cc | 82 +++++++++++++++++++ tensorflow/compiler/xla/service/hlo_parser.cc | 81 ++++++++++++++++++ tensorflow/compiler/xla/service/hlo_parser.h | 5 ++ .../compiler/xla/service/hlo_parser_test.cc | 12 ++- tensorflow/compiler/xla/xla_data.proto | 17 ++++ 10 files changed, 352 insertions(+), 8 deletions(-) diff --git a/tensorflow/compiler/xla/hlo/ir/hlo_instruction.cc b/tensorflow/compiler/xla/hlo/ir/hlo_instruction.cc index 48282e2db005a3..1f583a923fb6cc 100644 --- a/tensorflow/compiler/xla/hlo/ir/hlo_instruction.cc +++ b/tensorflow/compiler/xla/hlo/ir/hlo_instruction.cc @@ -1049,6 +1049,10 @@ StatusOr> HloInstruction::CreateFromProto( instruction->set_frontend_attributes(proto.frontend_attributes()); } + if (proto.has_statistics_viz()) { + instruction->set_statistics_viz(proto.statistics_viz()); + } + return std::move(instruction); } @@ -1761,6 +1765,7 @@ HloInstruction::CreateBroadcastSequence( broadcast->copy_sharding(operand); } broadcast->set_frontend_attributes(operand->frontend_attributes()); + broadcast->set_statistics_viz(operand->statistics_viz()); return broadcast; } // Do explicit broadcast for degenerate broadcast. @@ -1787,6 +1792,7 @@ HloInstruction::CreateBroadcastSequence( reshaped_operand->copy_sharding(operand); } reshaped_operand->set_frontend_attributes(operand->frontend_attributes()); + reshaped_operand->set_statistics_viz(operand->statistics_viz()); // Broadcast 'reshape' up to the larger size. auto broadcast = HloInstruction::CreateBroadcast( broadcast_shape, reshaped_operand, broadcast_dimensions); @@ -1795,6 +1801,7 @@ HloInstruction::CreateBroadcastSequence( broadcast->copy_sharding(operand); } broadcast->set_frontend_attributes(operand->frontend_attributes()); + broadcast->set_statistics_viz(operand->statistics_viz()); return broadcast; } @@ -1878,6 +1885,7 @@ void HloInstruction::SetupDerivedInstruction( } derived_instruction->set_metadata(metadata_); derived_instruction->set_frontend_attributes(frontend_attributes_); + derived_instruction->set_statistics_viz(statistics_viz_); } bool HloInstruction::IsRoot() const { @@ -3518,6 +3526,12 @@ void HloInstruction::PrintExtraAttributes( printer->Append("}"); }); } + + if (!statistics_viz_.statistics().empty()) { + printer.Next([this](Printer* printer) { + AppendCat(printer, "statistics=", StatisticsVizToString(statistics_viz_)); + }); + } } std::vector HloInstruction::ExtraAttributesToString( @@ -3585,6 +3599,8 @@ HloInstructionProto HloInstruction::ToProto() const { *proto.mutable_frontend_attributes() = frontend_attributes_; + *proto.mutable_statistics_viz() = statistics_viz_; + return proto; } @@ -4268,6 +4284,28 @@ std::string FrontendAttributesToString( absl::StrJoin(sorted_attributes, ",", formatter)); } +std::string StatisticsVizToString(const StatisticsViz& statistics_viz) { + // Statistics is either empty, or always starts with the index of the + // statistic that is rendered on the graph, followed by the statistics that + // are being tracked. The index is 0 based, starting from the first statistic + // being tracked. The index and statistics are within a comma-separated list + // of attribute=value pairs, + // e.g., statistics={visualizing_index=0, count_nan=100, count_inf=200}. + + if (statistics_viz.statistics().empty()) return "{}"; + + std::vector all_statistics(statistics_viz.statistics().begin(), + statistics_viz.statistics().end()); + + const auto formatter = [](std::string* out, const Statistic& item) { + absl::StrAppend(out, item.stat_name(), "=", item.stat_val()); + }; + return absl::StrFormat("{%s,%s}", + absl::StrCat("visualizing_index=", + statistics_viz.stat_index_to_visualize()), + absl::StrJoin(all_statistics, ",", formatter)); +} + std::string PaddingConfigToString(const PaddingConfig& padding) { bool has_interior_padding = absl::c_any_of(padding.dimensions(), diff --git a/tensorflow/compiler/xla/hlo/ir/hlo_instruction.h b/tensorflow/compiler/xla/hlo/ir/hlo_instruction.h index e0bafcbcf46322..443698af0219c2 100644 --- a/tensorflow/compiler/xla/hlo/ir/hlo_instruction.h +++ b/tensorflow/compiler/xla/hlo/ir/hlo_instruction.h @@ -1835,6 +1835,27 @@ class HloInstruction { return frontend_attributes_; } + void add_single_statistic(Statistic statistic) { + *statistics_viz_.add_statistics() = std::move(statistic); + } + + void set_stat_index_to_visualize(int64_t index) { + statistics_viz_.set_stat_index_to_visualize(index); + } + + bool has_statistics() const { return !statistics_viz_.statistics().empty(); } + + const Statistic& statistic_to_visualize() const { + return statistics_viz_.statistics().at( + statistics_viz_.stat_index_to_visualize()); + } + + void set_statistics_viz(StatisticsViz statistics_viz) { + statistics_viz_ = std::move(statistics_viz); + } + + const StatisticsViz& statistics_viz() const { return statistics_viz_; } + // Getter/setter for raw JSON-encoded backend config. Prefer the // functions above that deal in proto Messages where possible. const std::string& raw_backend_config_string() const { @@ -2434,6 +2455,10 @@ class HloInstruction { // z' = const(20), frontend_attributes={?} FrontendAttributes frontend_attributes_; + // Used to render an HLO graph when tracking the propagation desired values + // through it. + StatisticsViz statistics_viz_; + // String identifier for instruction. std::string name_; @@ -2468,6 +2493,7 @@ StatusOr StringToFusionKind( std::string PaddingConfigToString(const PaddingConfig& padding); std::string FrontendAttributesToString( const FrontendAttributes& frontend_attributes); +std::string StatisticsVizToString(const StatisticsViz& statistics_viz); std::string RandomAlgorithmToString(const RandomAlgorithm& algorithm); std::string RandomDistributionToString(const RandomDistribution& distribution); std::string PrecisionToString(const PrecisionConfig::Precision& precision); diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 0b5398220a8364..cb27cdc1eda6a7 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -111,7 +111,7 @@ enum CustomCallApiVersion { } // Serialization of HloInstruction. -// Next ID: 82 +// Next ID: 83 message HloInstructionProto { reserved 10; reserved "parameter_name"; @@ -363,6 +363,10 @@ message HloInstructionProto { // Represents the K value for top-k. int64 k = 81; + + // Represents the information for tracking propagation of values within HLO + // graph. + xla.StatisticsViz statistics_viz = 82; } // Serialization of HloComputation. diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 5c4c4ba2a02566..eb0caab8c81cfb 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -202,6 +202,45 @@ NodeColors NodeColorsForScheme(ColorScheme color) { } } +// Given a Statistic object, returns a hex string for the fill color of the node +// with that statistic. +const char* NodeFillColorForStatistic(const Statistic& statistic) { + auto stat_val = statistic.stat_val(); + if (stat_val == 0) { + return "#f5f5f5"; + } else if (stat_val < 10) { + return "#f7d4cc"; + } else if (stat_val < 20) { + return "#f8b2a3"; + } else if (stat_val < 30) { + return "#f9a28f"; + } else if (stat_val < 40) { + return "#fa917b"; + } else if (stat_val < 50) { + return "#fb8066"; + } else if (stat_val < 60) { + return "#fc7052"; + } else if (stat_val < 70) { + return "#fd5f3d"; + } else if (stat_val < 80) { + return "#fd4e29"; + } else if (stat_val < 90) { + return "#fe3e14"; + } else { + return "#ff2d00"; + } +} + +// Given a Statistic object, returns a hex string for the font color of the node +// with that statistic. +const char* NodeFontColorForStatistic(const Statistic& statistic) { + if (statistic.stat_val() < 60) { + return "black"; + } else { + return "white"; + } +} + // Given a ColorScheme, returns an attribute string for a node of that color. // Sets the node's style and fill/stroke/text colors. // @@ -658,7 +697,13 @@ std::string HloDotDumper::DumpSubcomputation( bool highlight = filter_.Highlight(parent_instr); const char* fillcolor; const char* strokecolor; - if (debug_options_.xla_hlo_graph_sharding_color() && !highlight) { + + if (!highlight && parent_instr->has_statistics()) { + // Use color from the statistic + fillcolor = + NodeFillColorForStatistic(parent_instr->statistic_to_visualize()); + strokecolor = "#c2c2c2"; + } else if (debug_options_.xla_hlo_graph_sharding_color() && !highlight) { // Use the sharding color, if the node isn't highlighted. NodeColors node_colors = NodeColorsForScheme(GetInstructionColor(parent_instr)); @@ -837,6 +882,22 @@ std::string HloDotDumper::DumpInstruction(const HloInstruction* instr) { color = kDarkRed; } } + + NodeColors node_colors = NodeColorsForScheme(color); + if (instr->has_statistics()) { + // override node's color to show statistics + const auto& statistic_to_visualize = instr->statistic_to_visualize(); + node_colors.fill_color = NodeFillColorForStatistic(statistic_to_visualize); + node_colors.stroke_color = "#c2c2c2"; + node_colors.font_color = NodeFontColorForStatistic(statistic_to_visualize); + } + + // Build the node style + std::string node_style = + StrFormat(R"(style="%s", fontcolor="%s", color="%s", fillcolor="%s")", + node_colors.style, node_colors.font_color, + node_colors.stroke_color, node_colors.fill_color); + // Build the text that will be displayed inside the node. std::string node_body = node_label; for (const std::string& s : {trivial_subcomputation, extra_info, @@ -849,7 +910,7 @@ std::string HloDotDumper::DumpInstruction(const HloInstruction* instr) { return StrFormat(R"(%s [label=<%s>, shape=%s, tooltip="%s", %s];)" "\n", InstructionId(instr), node_body, node_shape, node_metadata, - NodeColorAttributes(color)); + node_style); } std::string HloDotDumper::GetInstructionNodeInlinedOperands( @@ -2075,10 +2136,11 @@ void RegisterFusionState(const HloComputation& computation, fusion_progress.AddState(dot_txt, label, producer_to_highlight); } -StatusOr RenderGraph( - const HloComputation& computation, absl::string_view label, - const DebugOptions& debug_options, RenderedGraphFormat format, - HloRenderOptions hlo_render_options) { +StatusOr RenderGraph(const HloComputation& computation, + absl::string_view label, + const DebugOptions& debug_options, + RenderedGraphFormat format, + HloRenderOptions hlo_render_options) { absl::MutexLock lock(&url_renderer_mu); if (format == RenderedGraphFormat::kUrl && url_renderer == nullptr) { return Unavailable("Can't render as URL; no URL renderer was registered."); diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc index 901a5aabf302c0..e97f2495edf1fc 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc @@ -164,6 +164,25 @@ TEST_F(HloGraphDumperTest, Compare) { EXPECT_THAT(graph, HasSubstr("direction=LT")); } +TEST_F(HloGraphDumperTest, HasStatisticsViz) { + const char* hlo_string = R"( + HloModule comp + + ENTRY comp { + param.0 = f32[10] parameter(0), statistics={visualizing_index=0,stat-0=0.5} + param.1 = f32[10] parameter(1), statistics={visualizing_index=1,stat-0=55.5,stat-1=44.4} + ROOT lt = pred[10] compare(param.0, param.1), direction=LT + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + // Just check that it doesn't crash. + TF_ASSERT_OK_AND_ASSIGN( + std::string graph, + RenderGraph(*module->entry_computation(), /*label=*/"tuple_constant", + DebugOptions(), RenderedGraphFormat::kDot)); +} + TEST_F(HloGraphDumperTest, RootIsConstant) { const char* hlo_string = R"( HloModule indexed_conditional diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index c9dc160999c2cc..ab72e6aef292d9 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -1582,6 +1582,88 @@ TEST_F(HloInstructionTest, Stringification) { "true_computation=%TransposeDot, false_computation=%TransposeDot"); } +TEST_F(HloInstructionTest, GetSetStatisticsViz) { + const Shape shape = ShapeUtil::MakeShape(F32, {5, 10}); + + HloComputation::Builder builder(TestName()); + HloInstruction* x = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x")); + + StatisticsViz statistics_viz; + statistics_viz.set_stat_index_to_visualize(-1); + + x->set_statistics_viz(statistics_viz); + + EXPECT_FALSE(x->has_statistics()); + EXPECT_EQ(x->statistics_viz().stat_index_to_visualize(), -1); + + Statistic statistic; + statistic.set_stat_name("stat-1"); + statistic.set_stat_val(30.0); + + x->add_single_statistic(statistic); + x->set_stat_index_to_visualize(0); + + EXPECT_TRUE(x->has_statistics()); + EXPECT_TRUE( + protobuf_util::ProtobufEquals(x->statistic_to_visualize(), statistic)); + + statistic.set_stat_val(40.0); + *statistics_viz.add_statistics() = statistic; + + x->set_statistics_viz(statistics_viz); + + EXPECT_TRUE( + protobuf_util::ProtobufEquals(x->statistics_viz(), statistics_viz)); +} + +TEST_F(HloInstructionTest, StringifyStatisticsViz) { + const Shape shape = ShapeUtil::MakeShape(F32, {5, 10}); + + HloComputation::Builder builder(TestName()); + HloInstruction* x = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x")); + HloInstruction* y = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "y")); + HloInstruction* add = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, x, y)); + + // Empty statistics viz must not print "statistics={}" + add->set_statistics_viz({}); + EXPECT_EQ(add->ToString(), + "%add = f32[5,10]{1,0} add(f32[5,10]{1,0} %x, f32[5,10]{1,0} %y)"); + + auto CreateStatisticsVizWithStatistics = + [](int64_t stat_index_to_visualize, + std::initializer_list> statistics) + -> StatisticsViz { + StatisticsViz statistics_viz; + statistics_viz.set_stat_index_to_visualize(stat_index_to_visualize); + + auto create_statistic = [](absl::string_view statistic_name, + double statistic_value) { + Statistic statistic; + statistic.set_stat_name(std::string(statistic_name)); + statistic.set_stat_val(statistic_value); + return statistic; + }; + + for (const auto& [statistic_name, statistic_value] : statistics) { + *statistics_viz.add_statistics() = + create_statistic(statistic_name, statistic_value); + } + + return statistics_viz; + }; + + add->set_statistics_viz(CreateStatisticsVizWithStatistics( + 1, {{"stat-1", 33.0}, {"stat-2", 44.0}})); + + EXPECT_EQ(add->ToString(), + "%add = f32[5,10]{1,0} add(f32[5,10]{1,0} %x, f32[5,10]{1,0} %y), " + "statistics={visualizing_index=1,stat-1=33,stat-2=44}"); +} + TEST_F(HloInstructionTest, StringifyGather_0) { Shape input_tensor_shape = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46}); Shape start_indices_tensor_shape = diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 2e18b6c9c226cd..bbfd2c83ba065b 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -234,6 +234,7 @@ class HloParserImpl : public HloParser { StatusOr ParseShapeOnly(); StatusOr ParseShardingOnly(); StatusOr ParseFrontendAttributesOnly(); + StatusOr ParseStatisticsVizOnly(); StatusOr> ParseParameterReplicationOnly(); StatusOr ParseBooleanListOrSingleBooleanOnly(); StatusOr ParseWindowOnly(); @@ -262,6 +263,7 @@ class HloParserImpl : public HloParser { kConvolutionDimensionNumbers, kSharding, kFrontendAttributes, + kStatisticsViz, kBracedBoolListOrBool, kParameterReplication, kInstructionList, @@ -467,6 +469,7 @@ class HloParserImpl : public HloParser { bool ParseListShardingType(std::vector* types); bool ParseSharding(OpSharding* sharding); bool ParseFrontendAttributes(FrontendAttributes* frontend_attributes); + bool ParseStatisticsViz(StatisticsViz* statistics_viz); bool ParseSingleSharding(OpSharding* sharding, bool lbrace_pre_lexed); bool ParseParameterReplication(ParameterReplication* parameter_replication); bool ParseBooleanListOrSingleBoolean(BoolList* boolean_list); @@ -1204,9 +1207,12 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder, absl::flat_hash_map attrs; optional sharding; optional frontend_attributes; + optional statistics_viz; attrs["sharding"] = {/*required=*/false, AttrTy::kSharding, &sharding}; attrs["frontend_attributes"] = { /*required=*/false, AttrTy::kFrontendAttributes, &frontend_attributes}; + attrs["statistics"] = {/*required=*/false, AttrTy::kStatisticsViz, + &statistics_viz}; optional parameter_replication; attrs["parameter_replication"] = {/*required=*/false, AttrTy::kParameterReplication, @@ -1289,6 +1295,9 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder, if (frontend_attributes) { instruction->set_frontend_attributes(*frontend_attributes); } + if (statistics_viz) { + instruction->set_statistics_viz(*statistics_viz); + } return AddInstruction(name, instruction, name_loc); } @@ -3109,6 +3118,52 @@ bool HloParserImpl::ParseFrontendAttributes( "expects '}' at the end of frontend attributes"); } +// statistics +// ::= '{' /*empty*/ '}' +// ::= '{' index, single_statistic '}' +// index ::= 'visualizing_index=' value +// single_statistic ::= statistic '=' value (',' statistic '=' value)* +bool HloParserImpl::ParseStatisticsViz(StatisticsViz* statistics_viz) { + CHECK(statistics_viz != nullptr); + if (!ParseToken(TokKind::kLbrace, "expected '{' to start statistics")) { + return false; + } + if (lexer_.GetKind() == TokKind::kRbrace) { + // empty + } else { + // index must exist + std::string visualizing_index_attr_name; + if (!ParseAttributeName(&visualizing_index_attr_name)) { + return false; + } + if (lexer_.GetKind() != TokKind::kInt) { + return false; + } + statistics_viz->set_stat_index_to_visualize(lexer_.GetInt64Val()); + lexer_.Lex(); + + // then process statistics + while (EatIfPresent(TokKind::kComma)) { + std::string stat_name; + if (!ParseAttributeName(&stat_name)) { + return false; + } + if (lexer_.GetKind() != TokKind::kDecimal && + lexer_.GetKind() != TokKind::kInt) { + return false; + } + Statistic statistic; + statistic.set_stat_name(stat_name); + statistic.set_stat_val(lexer_.GetKind() == TokKind::kDecimal + ? lexer_.GetDecimalVal() + : lexer_.GetInt64Val()); + lexer_.Lex(); + *statistics_viz->add_statistics() = std::move(statistic); + } + } + return ParseToken(TokKind::kRbrace, "expects '}' at the end of statistics"); +} + // ::= '{' 'replicated'? 'manual'? 'maximal'? ('device=' int)? shape? // ('devices=' ('[' dims ']')* device_list)? // ('metadata=' metadata)* '}' @@ -4458,6 +4513,15 @@ bool HloParserImpl::ParseAttributeHelper( ->emplace(frontend_attributes); return true; } + case AttrTy::kStatisticsViz: { + StatisticsViz statistics_viz; + if (!ParseStatisticsViz(&statistics_viz)) { + return false; + } + static_cast*>(attr_out_ptr) + ->emplace(statistics_viz); + return true; + } case AttrTy::kParameterReplication: { ParameterReplication parameter_replication; if (!ParseParameterReplication(¶meter_replication)) { @@ -6206,6 +6270,18 @@ StatusOr HloParserImpl::ParseFrontendAttributesOnly() { return attributes; } +StatusOr HloParserImpl::ParseStatisticsVizOnly() { + lexer_.Lex(); + StatisticsViz statistics_viz; + if (!ParseStatisticsViz(&statistics_viz)) { + return InvalidArgument("Syntax error:\n%s", GetError()); + } + if (lexer_.GetKind() != TokKind::kEof) { + return InvalidArgument("Syntax error:\nExtra content after statistics"); + } + return statistics_viz; +} + StatusOr> HloParserImpl::ParseParameterReplicationOnly() { lexer_.Lex(); ParameterReplication parameter_replication; @@ -6366,6 +6442,11 @@ StatusOr ParseFrontendAttributes(absl::string_view str) { return parser.ParseFrontendAttributesOnly(); } +StatusOr ParseStatisticsViz(absl::string_view str) { + HloParserImpl parser(str); + return parser.ParseStatisticsVizOnly(); +} + StatusOr> ParseParameterReplication(absl::string_view str) { HloParserImpl parser(str); return parser.ParseParameterReplicationOnly(); diff --git a/tensorflow/compiler/xla/service/hlo_parser.h b/tensorflow/compiler/xla/service/hlo_parser.h index 0ab47a4d276755..f295beb606310f 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.h +++ b/tensorflow/compiler/xla/service/hlo_parser.h @@ -57,6 +57,11 @@ StatusOr ParseSharding(absl::string_view str); // "{attr_a=a,attr_b=b}". StatusOr ParseFrontendAttributes(absl::string_view str); +// Parses statistics viz from str. str is supposed to contain the body of the +// statistics visualization, i.e. just the rhs of the "statistics={...}" +// attribute string, e.g., "{visualizing_index=1,nan_percent=50}". +StatusOr ParseStatisticsViz(absl::string_view str); + // Parses parameter replication from str. str is supposed to contain the body of // the parameter replication, i.e. just the rhs of the // "parameter_replication={...}" attribute string, e.g., "{true, false}". diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index 075c60bf08310d..332dd413835fbb 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -132,7 +132,6 @@ ENTRY %constant_pred_array () -> pred[2,3] { )" }, - // s32 constant { "ConstantS32", @@ -142,6 +141,17 @@ ENTRY %constant_s32 () -> s32[] { ROOT %constant = s32[] constant(-42) } +)" +}, +// s32 constant with statistics +{ +"ConstantS32WithStatistics", +R"(HloModule constant_s32_module, entry_computation_layout={()->s32[]} + +ENTRY %constant_s32 () -> s32[] { + ROOT %constant = s32[] constant(-42), statistics={visualizing_index=1,stat-1=33,stat-2=44} +} + )" }, // f32 constant, but the value is not a decimal and there is a backend diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index ce89c5fc61b2f4..42d6766439d0e8 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -759,6 +759,23 @@ message FrontendAttributes { map map = 1; } +// Represents a single statistic to track. +message Statistic { + // Must be a single word consisting of any alphanumeric characters + string stat_name = 1; + // Must be within a range of [0, 100], in order for the graph dumper to + // properly render the statistic onto the graph. + double stat_val = 2; +} + +// Represents the information needed to visualize propagation statistics when +// rendering an HLO graph. This includes an array of statistics as well as the +// index of the statistic to render. +message StatisticsViz { + int64 stat_index_to_visualize = 1; + repeated Statistic statistics = 2; +} + // LINT.IfChange message OpSharding { enum Type { From 282f113f004394ac8cb9fb169835a4da1858b537 Mon Sep 17 00:00:00 2001 From: Yishuang Pang Date: Fri, 14 Jul 2023 18:08:30 -0700 Subject: [PATCH 341/376] legalize mhlo.dynamic_broadcast_in_dim to tf.broadcast_to and tf.expand_dims PiperOrigin-RevId: 548265763 --- .../mlir/tensorflow/tests/legalize_hlo.mlir | 36 +++++++++++++++++++ .../tensorflow/transforms/legalize_hlo.cc | 28 +++++++++++++++ .../transforms/legalize_hlo_patterns.td | 12 +++++++ 3 files changed, 76 insertions(+) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir index 1da0ee417402b9..f369fbde5f6fa3 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir @@ -1555,6 +1555,42 @@ func.func @broadcast_in_dim_general_case(%arg0: tensor<3x1x16xf32>) -> tensor<3x func.return %0 : tensor<3x8x8x16xf32> } +// CHECK-LABEL: func @dynamic_broadcast_in_dim_tf_style( +// CHECK-SAME: %[[ARG_0:.*]]: tensor, +// CHECK-SAME: %[[ARG_1:.*]]: tensor<5xi32>) -> tensor { +// CHECK %[[VAL_0:.*]] = "tf.BroadcastTo"(%[[ARG_0]], %[[ARG_1]]) : (tensor, tensor<5xi32>) -> tensor +// CHECK return %[[VAL_0]] : tensor +func.func @dynamic_broadcast_in_dim_tf_style(%arg0: tensor, %arg1: tensor<5xi32>) -> tensor { + %0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %arg1) {broadcast_dimensions = dense<[0, 1, 2, 3, 4]> : tensor<5xi64>} : (tensor, tensor<5xi32>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: func @dynamic_broadcast_in_dim_general_case_expand_back_dims( +// CHECK-SAME: %[[ARG_0:.*]]: tensor, +// CHECK-SAME: %[[ARG_1:.*]]: tensor<4xi32>) -> tensor { +// CHECK %[[CST_0:.*]] = "tf.Const"() {value = dense<2> : tensor} : () -> tensor +// CHECK %[[VAL_0:.*]] = "tf.ExpandDims"(%[[ARG_0]], %[[CST_0]]) : (tensor, tensor) -> tensor +// CHECK %[[CST_1:.*]] = "tf.Const"() {value = dense<3> : tensor} : () -> tensor +// CHECK %[[VAL_1:.*]] = "tf.ExpandDims"(%[[VAL_0]], %[[CST_1]]) : (tensor, tensor) -> tensor +// CHECK %[[VAL_2:.*]] = "tf.BroadcastTo"(%[[VAL_1]], %[[ARG_1]]) : (tensor, tensor<4xi32>) -> tensor +// CHECK return %[[VAL_2]] : tensor +func.func @dynamic_broadcast_in_dim_general_case_expand_back_dims(%arg0: tensor, %arg1: tensor<4xi32>) -> tensor { + %0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %arg1) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor, tensor<4xi32>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: func @dynamic_broadcast_in_dim_general_case_expand_middle_dim( +// CHECK-SAME: %[[ARG_0:.*]]: tensor, +// CHECK-SAME: %[[ARG_1:.*]]: tensor<4xi32>) -> tensor { +// CHECK %[[CST_0:.*]] = "tf.Const"() {value = dense<2> : tensor} : () -> tensor +// CHECK %[[VAL_0:.*]] = "tf.ExpandDims"(%[[ARG_0]], %[[CST_0]]) : (tensor, tensor) -> tensor +// CHECK %[[VAL_1:.*]] = "tf.BroadcastTo"(%[[VAL_0]], %[[ARG_1]]) : (tensor, tensor<4xi32>) -> tensor +// CHECK return %[[VAL_1]] : tensor +func.func @dynamic_broadcast_in_dim_general_case_expand_middle_dim(%arg0: tensor, %arg1: tensor<4xi32>) -> tensor { + %0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %arg1) {broadcast_dimensions = dense<[0, 1, 3]> : tensor<3xi64>} : (tensor, tensor<4xi32>) -> tensor + func.return %0 : tensor +} + // CHECK-LABEL: func @convert_dot_general( // CHECK-SAME: %[[VAL_0:.*]]: tensor<3x2x6x5x1xf32>, // CHECK-SAME: %[[VAL_1:.*]]: tensor<3x2x4x6xf32>) -> tensor<3x5x1x4xf32> { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc index 87586a57bbef2e..a69aef8edc5695 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc @@ -32,6 +32,7 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLForwardCompat.h" #include "llvm/ADT/Sequence.h" +#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" @@ -3741,6 +3742,33 @@ arith::ConstantOp ExpandedShape(PatternRewriter& rewriter, Value input, return rewriter.create(output.getLoc(), attr_type, attr); } +Value ExpandedDynamicShape(PatternRewriter& rewriter, Value input, + DenseIntElementsAttr broadcast_dimensions, + Value output) { + assert(output.getType().cast() && + "output type must be of ShapedType"); + int64_t output_rank = output.getType().cast().getRank(); + llvm::SmallVector expanded_dimensions; + llvm::SmallSet broadcast_dimensions_values; + for (auto x : llvm::enumerate(broadcast_dimensions)) { + broadcast_dimensions_values.insert(x.value().getSExtValue()); + } + for (int64_t i = 0; i < output_rank; i++) { + if (!broadcast_dimensions_values.contains(i)) { + expanded_dimensions.push_back(i); + } + } + Value expanded_input = input; + for (int64_t i : expanded_dimensions) { + auto index_attr = DenseIntElementsAttr::get( + RankedTensorType::get({}, rewriter.getI64Type()), {i}); + Value index = rewriter.create(output.getLoc(), index_attr); + expanded_input = rewriter.create(output.getLoc(), + expanded_input, index); + } + return expanded_input; +} + #include "tensorflow/compiler/mlir/tensorflow/transforms/generated_legalize_hlo.inc" /// Performs the lowering to TF dialect. diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td index 0261783da7c33f..7ac8934e5bb915 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td @@ -39,6 +39,7 @@ def IsNotTFStyleBroadcast : Constraint>, // Return intermediate shape before broadcasting, wrapped in a constant op. def ExpandedShape : NativeCodeCall<"ExpandedShape($_builder, $0, $1, $2)">; +def ExpandedDynamicShape : NativeCodeCall<"ExpandedDynamicShape($_builder, $0, $1, $2)">; def : Pat<(MHLO_ConstantOp:$output $value), (TF_ConstOp $value), [(TF_Tensor $output)]>; @@ -183,6 +184,17 @@ def : Pat<(MHLO_BroadcastInDimOp:$output $input, $broadcast_dimensions), (ExpandedShape $input, $broadcast_dimensions, $output)), (ShapeToConst $output)), [(IsNotTFStyleBroadcast $broadcast_dimensions, $output)]>; +// Dynamism op +def : Pat<(MHLO_DynamicBroadcastInDimOp:$output $input, $output_dimensions, + $broadcast_dimensions, $expanding_dimensions_unused, $nonexpanding_dimensions_unused), + (TF_BroadcastToOp $input, $output_dimensions), + [(IsTFStyleBroadcast $broadcast_dimensions, $output)]>; +def : Pat<(MHLO_DynamicBroadcastInDimOp:$output $input, $output_dimensions, + $broadcast_dimensions, $expanding_dimensions_unused, $nonexpanding_dimensions_unused), + (TF_BroadcastToOp (ExpandedDynamicShape $input, $broadcast_dimensions, $output), $output_dimensions), + [(IsNotTFStyleBroadcast $broadcast_dimensions, $output)]>; + + def : Pat<(MHLO_TransposeOp $arg, $permutation), (TF_TransposeOp $arg, (TF_ConstOp $permutation))>; def : Pat<(MHLO_ReverseOp $op, $dims), (TF_ReverseV2Op $op, (TF_ConstOp $dims))>; From bf191945497badb732b0b50bc10b79ba9a4baf6c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 14 Jul 2023 18:30:28 -0700 Subject: [PATCH 342/376] Expose ShapeVerifier::CheckParameterCount() as a protected static function. PiperOrigin-RevId: 548268215 --- .../compiler/xla/service/hlo_verifier.cc | 51 +++++++++++++++++++ .../compiler/xla/service/hlo_verifier.h | 19 +++++++ tensorflow/compiler/xla/shape_util.cc | 22 ++++++++ tensorflow/compiler/xla/shape_util.h | 5 ++ 4 files changed, 97 insertions(+) diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index e8fb8526164f12..ec99f4b580f457 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -849,6 +849,26 @@ Status ShapeVerifier::CheckIsTokenOperand(const HloInstruction* instruction, return OkStatus(); } +Status ShapeVerifier::CheckShardedParameter( + const HloInstruction* operand, const HloInstruction* sharded_parameter, + int64_t num_shards) { + TF_RET_CHECK(num_shards > 0); + Shape unsharded_parameter_shape = + ShapeUtil::GetUnshardedShape(sharded_parameter->shape(), num_shards); + + if (!ShapesSame(operand->shape(), unsharded_parameter_shape)) { + return InternalError( + "Operand %s shape: %s does not match sharded parameter %s expected " + "shape: %s, actual shape: %s " + "num shards: %d", + operand->name(), operand->shape().ToString(), sharded_parameter->name(), + operand->shape().ToString(), unsharded_parameter_shape.ToString(), + num_shards); + } + + return OkStatus(); +} + Status ShapeVerifier::CheckOperandAndParameter( const HloInstruction* instruction, int64_t operand_number, const HloComputation* computation, int64_t parameter_number) { @@ -863,6 +883,19 @@ Status ShapeVerifier::CheckOperandAndParameter( return OkStatus(); } +Status ShapeVerifier::CheckOperandAndShardedParameter( + const HloInstruction* instruction, int64_t operand_number, + const HloComputation* computation, int64_t parameter_number, + int64_t num_shards) { + TF_RET_CHECK(num_shards > 0); + const HloInstruction* operand = instruction->operand(operand_number); + const HloInstruction* parameter = + computation->parameter_instruction(parameter_number); + // In the case of verifying a sharded called computation parameter, check that + // the parameter is correctly sharded amongst the specified number of shards. + return CheckShardedParameter(operand, parameter, num_shards); +} + Status ShapeVerifier::HandleInfeed(HloInstruction* instruction) { HloInfeedInstruction* infeed = Cast(instruction); TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 0)); @@ -1312,6 +1345,24 @@ Status ShapeVerifier::HandleCall(HloInstruction* call) { return CheckShape(call, call->to_apply()->root_instruction()->shape()); } +Status ShapeVerifier::VerifyShardedCall(const HloInstruction* call, + int64_t num_shards) { + TF_RET_CHECK(num_shards > 0); + TF_RETURN_IF_ERROR( + CheckParameterCount(call, call->to_apply(), call->operand_count())); + for (int64_t i = 0; i < call->to_apply()->num_parameters(); ++i) { + TF_RETURN_IF_ERROR(CheckOperandAndShardedParameter( + call, i, call->to_apply(), i, num_shards)); + } + // The shape of kCall should match the shape of the computation it calls. + // In the case of verifying a sharded called computation, check that the + // output is correctly sharded amongst the specified number of shards. + const HloComputation* to_apply_computation = call->to_apply(); + Shape unsharded_output_shape = ShapeUtil::GetUnshardedShape( + to_apply_computation->root_instruction()->shape(), num_shards); + return CheckShape(call, unsharded_output_shape); +} + Status ShapeVerifier::HandleCustomCall(HloInstruction* instruction) { const HloCustomCallInstruction* custom_call = DynCast(instruction); diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index bfcd46ef3d3d98..a55f10e10e3778 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -252,6 +252,14 @@ class ShapeVerifier : public DfsHloVisitor { Status CheckTernaryShape(const HloInstruction* instruction); Status CheckVariadicShape(const HloInstruction* instruction); + Status VerifyShardedCall(const HloInstruction* call, int64_t num_shards); + + Status CheckOperandAndShardedParameter(const HloInstruction* instruction, + int64_t operand_number, + const HloComputation* computation, + int64_t parameter_number, + int64_t num_shards); + private: bool ShapesSameIgnoringFpPrecision(const Shape& a, const Shape& b, bool minor_to_major_only = false) { @@ -289,6 +297,17 @@ class ShapeVerifier : public DfsHloVisitor { const HloComputation* computation, int64_t parameter_number); + // Checks that the shape of `operand` is compatible with `sharded_parameter` + // which resides within a "sharded" computation. An `operand` and + // `sharded_parameter` shape are compatible if for all of `operand` + // sub-shapes, the major dimension of the non-dynamic tensors in + // `sharded_parameter` are partitioned among `num_shards`. + // + // Precondition: `num_shards` > 1. + Status CheckShardedParameter(const HloInstruction* operand, + const HloInstruction* sharded_parameter, + int64_t num_shards); + // Checks that the shape of async op operands and results match the called // computation parameters and root. Status CheckAsyncOpComputationShapes(const HloInstruction* async_op, diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index d85a38bec85935..51d7892eee177f 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -2123,4 +2123,26 @@ int64_t ShapeUtil::ForEachState::CalculateNumSteps() const { return size; } +Shape ShapeUtil::GetUnshardedShape(const Shape& sharded_shape, + int64_t num_shards) { + if (ShapeUtil::IsScalar(sharded_shape)) { + return sharded_shape; + } + + Shape unsharded_shape = sharded_shape; + + ShapeUtil::ForEachMutableSubshape( + &unsharded_shape, + [sharded_shape, num_shards](Shape* subshape, const ShapeIndex& index) { + if (subshape->IsArray() && subshape->rank() >= 1 && + !subshape->is_dynamic()) { + const Shape& sharded_subshape = + ShapeUtil::GetSubshape(sharded_shape, index); + subshape->set_dimensions(0, + sharded_subshape.dimensions(0) * num_shards); + } + }); + return unsharded_shape; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index a499299b31c751..a08ab7703571ab 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -888,6 +888,11 @@ class ShapeUtil { // due to the tiling requirement. static int64_t ArrayDataSize(const Shape& shape); + // Returns the unsharded shape for an input `sharded_shape` that is + // partitioned among `num_shards`. + static Shape GetUnshardedShape(const Shape& sharded_shape, + int64_t num_shards); + private: // Fills *shape. Returns true on success. // REQUIRES: *shape is empty. From 46b98302ecd9e1fe63cb816ebd303a216eb07c37 Mon Sep 17 00:00:00 2001 From: Changhui Lin Date: Fri, 14 Jul 2023 19:12:48 -0700 Subject: [PATCH 343/376] Remove the flag use_bridge_for_gpu. PiperOrigin-RevId: 548273048 --- .../tests/saved_model/saved_model_test.cc | 2 -- .../mlir/tfrt/tests/xla_launch_fallback.mlir | 2 +- .../compiler/mlir/tfrt/transforms/passes.cc | 3 +-- .../mlir/tfrt/transforms/tf_to_tfrt.cc | 24 ++++++------------- .../tfrt/transforms/tfrt_pipeline_options.h | 5 ---- .../mlir/tfrt/translate/import_model.cc | 4 +--- .../tfrt/translate/tfrt_compile_options.h | 4 ---- .../core/tfrt/saved_model/saved_model.cc | 3 --- 8 files changed, 10 insertions(+), 37 deletions(-) diff --git a/tensorflow/compiler/mlir/tfrt/tests/saved_model/saved_model_test.cc b/tensorflow/compiler/mlir/tfrt/tests/saved_model/saved_model_test.cc index 3089afe93686b5..4cd8ce4b833ce3 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/saved_model/saved_model_test.cc +++ b/tensorflow/compiler/mlir/tfrt/tests/saved_model/saved_model_test.cc @@ -131,7 +131,6 @@ TEST(SavedModelTest, ConvertTfMlirToBefWithXlaFuncExport) { tfrt_stub::GraphExecutionOptions options(runtime.get()); options.compile_options.device_target = TfrtDeviceInfraTarget::kGpu; - options.compile_options.use_bridge_for_gpu = true; TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr fallback_state, @@ -170,7 +169,6 @@ TEST(SavedModelTest, ConvertTfMlirToBefExportingXlaReduceWindow) { tensorflow::tfrt_stub::Runtime::Create(/*num_inter_op_threads=*/1); tfrt_stub::GraphExecutionOptions options(runtime.get()); options.compile_options.device_target = TfrtDeviceInfraTarget::kGpu; - options.compile_options.use_bridge_for_gpu = true; TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr fallback_state, diff --git a/tensorflow/compiler/mlir/tfrt/tests/xla_launch_fallback.mlir b/tensorflow/compiler/mlir/tfrt/tests/xla_launch_fallback.mlir index 3905074bfd25b6..6fee8545cfe6c4 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/xla_launch_fallback.mlir +++ b/tensorflow/compiler/mlir/tfrt/tests/xla_launch_fallback.mlir @@ -1,4 +1,4 @@ -// RUN: tf-tfrt-opt -split-input-file -tf-executor-to-tfrt-pipeline="target-gpu=true use-bridge-for-gpu=true func-use-fallback-tensor=true" -tfrt-lower-tf-savedmodel=hoist-invariant-ops=true %s | FileCheck %s --dump-input=fail --dump-input-filter=all +// RUN: tf-tfrt-opt -split-input-file -tf-executor-to-tfrt-pipeline="target-gpu=true func-use-fallback-tensor=true" -tfrt-lower-tf-savedmodel=hoist-invariant-ops=true %s | FileCheck %s --dump-input=fail --dump-input-filter=all func.func private @xla_func_0(%arg0: tensor<1x3xf32>, %arg1: tensor<1x3xf32>) -> tensor<1x3xf32> attributes {tf._XlaMustCompile = true, tf._noinline = true, tf._original_func_name = "should_not_be_used"} { %1 = "tf.AddV2"(%arg0, %arg1) : (tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32> diff --git a/tensorflow/compiler/mlir/tfrt/transforms/passes.cc b/tensorflow/compiler/mlir/tfrt/transforms/passes.cc index 4375b78fc2497f..c812cf1c9f51ef 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/passes.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/passes.cc @@ -215,8 +215,7 @@ void CreateTFExecutorToTFInvariantOptimizationPipelineHelper( } Status ValidateTfrtPipelineOptions(const TfrtPipelineOptions &options) { - if (options.target_tpurt && - (options.target_gpu || options.use_bridge_for_gpu)) { + if (options.target_tpurt && options.target_gpu) { return tensorflow::errors::Internal( "Invalid pipeline options. Targeting both TPU and GPU is not " "supported."); diff --git a/tensorflow/compiler/mlir/tfrt/transforms/tf_to_tfrt.cc b/tensorflow/compiler/mlir/tfrt/transforms/tf_to_tfrt.cc index f6b0dbc3767c9a..9b57bf04156c06 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/tf_to_tfrt.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/tf_to_tfrt.cc @@ -154,7 +154,7 @@ class FallbackExecuteOpConversion : public mlir::ConversionPattern { tfrt_compiler::FallbackConverter *fallback_converter, const mlir::SymbolTable *symbol_table, const tfrt_compiler::CostAnalysis *cost_analysis, - bool tpu_lower_to_fallback, bool target_tpurt, bool use_bridge_for_gpu) + bool tpu_lower_to_fallback, bool target_tpurt) : mlir::ConversionPattern(mlir::Pattern::MatchAnyOpTypeTag(), kFallbackBenefit, context), corert_converter_(*corert_converter), @@ -162,8 +162,7 @@ class FallbackExecuteOpConversion : public mlir::ConversionPattern { symbol_table_(*symbol_table), cost_analysis_(*cost_analysis), tpu_lower_to_fallback_(tpu_lower_to_fallback), - target_tpurt_(target_tpurt), - use_bridge_for_gpu_(use_bridge_for_gpu) {} + target_tpurt_(target_tpurt) {} LogicalResult matchAndRewrite( mlir::Operation *op, ArrayRef operands, @@ -193,7 +192,7 @@ class FallbackExecuteOpConversion : public mlir::ConversionPattern { // e.g., variable lifting. The new MLIR function will need to be exported to // the function library for runtime to use. bool use_mlir_func_name = - parsed_device_name->device_type == DEVICE_GPU && use_bridge_for_gpu_ && + parsed_device_name->device_type == DEVICE_GPU && op->getName().getStringRef().str() == "tf.XlaLaunch"; mlir::ArrayAttr op_func_attrs = corert_converter_.CreateOpFuncAttrs( @@ -293,8 +292,6 @@ class FallbackExecuteOpConversion : public mlir::ConversionPattern { const tfrt_compiler::CostAnalysis &cost_analysis_; bool tpu_lower_to_fallback_; bool target_tpurt_; - // TODO(b/260915352): Remove the flag and default to using bridge. - bool use_bridge_for_gpu_; }; mlir::LogicalResult FallbackExecuteOpConversion::ConvertToFallbackExecuteOp( @@ -332,8 +329,7 @@ mlir::LogicalResult FallbackExecuteOpConversion::ConvertToFallbackExecuteOp( // For now, we only consider GPU XLA clusters in the form of XlaLaunch for // simplicity. We could extend to support other GPU ops that cann't be XLAed. bool is_xla_launch_on_gpu = - is_gpu_op && use_bridge_for_gpu_ && - op->getName().getStringRef().str() == "tf.XlaLaunch"; + is_gpu_op && op->getName().getStringRef().str() == "tf.XlaLaunch"; if (is_xla_launch_on_gpu) { new_operands = AddGpuVariableAndInputTensorTransferOps(op, new_operands, rewriter); @@ -1444,11 +1440,11 @@ void PopulateTFToTFRTConversionPatterns( const tfrt_compiler::TensorArraySideEffectAnalysis *tensor_array_side_effect_analysis, bool func_use_fallback_tensor, bool enable_while_parallel_iterations, - bool tpu_lower_to_fallback, bool target_tpurt, bool use_bridge_for_gpu) { + bool tpu_lower_to_fallback, bool target_tpurt) { // By default, we lower all TF ops to fallback ops. patterns->add( context, corert_converter, fallback_converter, symbol_table, - cost_analysis, tpu_lower_to_fallback, target_tpurt, use_bridge_for_gpu); + cost_analysis, tpu_lower_to_fallback, target_tpurt); patterns->add(context, corert_converter); @@ -1523,7 +1519,6 @@ class TfToTfrtConversionPass enable_while_parallel_iterations_ = options.enable_while_parallel_iterations; target_gpu_ = options.target_gpu; - use_bridge_for_gpu_ = options.use_bridge_for_gpu; } TfToTfrtConversionPass(const TfToTfrtConversionPass &) {} @@ -1566,7 +1561,7 @@ class TfToTfrtConversionPass &context, &patterns, &corert_converter, &fallback_converter, &symbol_table, &cost_analysis, &tensor_array_side_effect_analysis, func_use_fallback_tensor_, enable_while_parallel_iterations_, - tpu_lower_to_fallback_, target_tpurt_, use_bridge_for_gpu_); + tpu_lower_to_fallback_, target_tpurt_); return mlir::applyPartialConversion(func, target, std::move(patterns)); } @@ -1761,11 +1756,6 @@ class TfToTfrtConversionPass llvm::cl::desc("If true, target GPU compiler passes."), llvm::cl::init(false)}; - // TODO(b/260915352): Remove the flag and default to using bridge. - Option use_bridge_for_gpu_{ - *this, "use-bridge-for-gpu", - llvm::cl::desc("If true, GPU bridge is used."), llvm::cl::init(false)}; - Option cost_threshold_{ *this, "tfrt-cost-threshold", llvm::cl::desc( diff --git a/tensorflow/compiler/mlir/tfrt/transforms/tfrt_pipeline_options.h b/tensorflow/compiler/mlir/tfrt/transforms/tfrt_pipeline_options.h index f85a700b9dcf04..0a1209f457be7e 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/tfrt_pipeline_options.h +++ b/tensorflow/compiler/mlir/tfrt/transforms/tfrt_pipeline_options.h @@ -100,11 +100,6 @@ struct TfrtPipelineOptions llvm::cl::desc("If true, target GPU compiler passes."), llvm::cl::init(false)}; - // TODO(b/260915352): Remove the flag and default to using bridge. - Option use_bridge_for_gpu{ - *this, "use-bridge-for-gpu", - llvm::cl::desc("If true, GPU bridge is used."), llvm::cl::init(false)}; - Option func_use_fallback_tensor{ *this, "func-use-fallback-tensor", llvm::cl::desc( diff --git a/tensorflow/compiler/mlir/tfrt/translate/import_model.cc b/tensorflow/compiler/mlir/tfrt/translate/import_model.cc index 55e923d3b6f400..4ff4a47feaaa15 100644 --- a/tensorflow/compiler/mlir/tfrt/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tfrt/translate/import_model.cc @@ -204,8 +204,7 @@ Status ConvertTfMlirToRuntimeExecutable( return diag_handler.Combine(absl::InternalError( "Failed to process TPUPartitionedCallOp for fallback execution")); } - } else if (options.device_target == TfrtDeviceInfraTarget::kGpu && - options.use_bridge_for_gpu) { + } else if (options.device_target == TfrtDeviceInfraTarget::kGpu) { TF_RETURN_IF_ERROR(mlir::TF::RunTFXLABridge(module)); // GPU XLA clusters are wrapped in functions, which could be transformed by @@ -300,7 +299,6 @@ std::unique_ptr GetTfrtPipelineOptions( (options.device_target == TfrtDeviceInfraTarget::kTpurt); pipeline_options->target_gpu = (options.device_target == TfrtDeviceInfraTarget::kGpu); - pipeline_options->use_bridge_for_gpu = options.use_bridge_for_gpu; pipeline_options->tpu_fuse_ops = options.tpu_fuse_ops; pipeline_options->use_tpu_host_allocator_for_inputs = options.use_tpu_host_allocator_for_inputs; diff --git a/tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h b/tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h index 6a2d4a4a23d32a..619f89cfa83d71 100644 --- a/tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h +++ b/tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h @@ -147,10 +147,6 @@ struct TfrtCompileOptions { // Whether to compile to sync TFRT dialect. bool compile_to_sync_tfrt_dialect = false; - - // Whether to use bridge for GPU. - // TODO(b/260915352): Remove the flag and default to using bridge. - bool use_bridge_for_gpu = false; }; std::ostream& operator<<(std::ostream& os, const TfrtCompileOptions& options); diff --git a/tensorflow/core/tfrt/saved_model/saved_model.cc b/tensorflow/core/tfrt/saved_model/saved_model.cc index b4c7b32dcf97dd..488c27d8c1ae67 100644 --- a/tensorflow/core/tfrt/saved_model/saved_model.cc +++ b/tensorflow/core/tfrt/saved_model/saved_model.cc @@ -528,9 +528,6 @@ void UpdateCompileOptions(SavedModel::Options& options) { if (options.graph_execution_options.enable_tfrt_gpu) { options.graph_execution_options.compile_options.decompose_resource_ops = false; - // TODO(b/260915352): Remove this flag and use GPU bridge by default, and - // remove the obsolete TFRT GPU runtime as well. - options.graph_execution_options.compile_options.use_bridge_for_gpu = true; } options.graph_execution_options.compile_options From afb11fa292a99a1edbae0e53a95171954e5c17ea Mon Sep 17 00:00:00 2001 From: "ag.ramesh" Date: Fri, 14 Jul 2023 22:23:10 -0700 Subject: [PATCH 344/376] Moved code to get threadpool to a function and fixed a few typos --- .../core/kernels/mkl/mkl_avgpooling_op.cc | 10 +++---- .../core/kernels/mkl/mkl_batch_matmul_op.cc | 6 ++--- tensorflow/core/kernels/mkl/mkl_concat_op.cc | 6 ++--- .../kernels/mkl/mkl_conv_grad_filter_ops.cc | 6 ++--- .../kernels/mkl/mkl_conv_grad_input_ops.cc | 6 ++--- tensorflow/core/kernels/mkl/mkl_conv_ops.cc | 6 ++--- .../core/kernels/mkl/mkl_dequantize_op.cc | 6 ++--- tensorflow/core/kernels/mkl/mkl_einsum_op.cc | 6 ++--- .../kernels/mkl/mkl_fused_batch_norm_op.cc | 10 +++---- .../kernels/mkl/mkl_fused_instance_norm_op.cc | 6 ++--- .../core/kernels/mkl/mkl_layer_norm_op.cc | 6 ++--- tensorflow/core/kernels/mkl/mkl_matmul_op.cc | 6 ++--- .../core/kernels/mkl/mkl_matmul_op_fused.cc | 6 ++--- .../core/kernels/mkl/mkl_matmul_ops_common.h | 6 ++--- .../core/kernels/mkl/mkl_maxpooling_op.cc | 10 +++---- tensorflow/core/kernels/mkl/mkl_qmatmul_op.cc | 6 ++--- .../core/kernels/mkl/mkl_quantize_op.cc | 6 ++--- tensorflow/core/kernels/mkl/mkl_relu_op.cc | 10 +++---- .../mkl/mkl_requantize_per_channel_op.cc | 6 ++--- tensorflow/core/kernels/mkl/mkl_softmax_op.cc | 6 ++--- .../core/kernels/mkl/mkl_transpose_op.cc | 6 ++--- tensorflow/core/util/mkl_util.h | 26 +++++++++---------- 22 files changed, 59 insertions(+), 109 deletions(-) diff --git a/tensorflow/core/kernels/mkl/mkl_avgpooling_op.cc b/tensorflow/core/kernels/mkl/mkl_avgpooling_op.cc index 42c5ed61bd888d..c3ca762f54349e 100644 --- a/tensorflow/core/kernels/mkl/mkl_avgpooling_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_avgpooling_op.cc @@ -123,12 +123,10 @@ class MklAvgPoolingOp : public MklPoolingForwardOpBase { dnnl::algorithm::pooling_avg_exclude_padding, pooling_prop_kind, static_cast(this->data_format_mkldnn_), input_md, this->native_format_); - // Create the oneDNN wrapper over eigen threapool and set max threads + // Create the oneDNN wrapper over Eigen threadpool and set max threads // in oneDNN. Eigen::ThreadPoolInterface* eigen_interface = - context->device() - ->tensorflow_cpu_worker_threads() - ->workers->AsEigenThreadPool(); + EigenThreadPoolFromTfContext(context); tsl::OneDnnThreadPool eigen_tp(eigen_interface, ThreadPoolUseCallerThread()); pooling_fwd = MklPoolingFwdPrimitiveFactory::Get(fwdParams); @@ -348,9 +346,7 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase { static_cast(this->data_format_mkldnn_), src_md, this->native_format_); Eigen::ThreadPoolInterface* eigen_interface = - context->device() - ->tensorflow_cpu_worker_threads() - ->workers->AsEigenThreadPool(); + EigenThreadPoolFromTfContext(context); tsl::OneDnnThreadPool eigen_tp(eigen_interface, ThreadPoolUseCallerThread()); MklPoolingBwdPrimitive* pooling_bwd = diff --git a/tensorflow/core/kernels/mkl/mkl_batch_matmul_op.cc b/tensorflow/core/kernels/mkl/mkl_batch_matmul_op.cc index 62e1de035821e7..972ccf26875cbd 100644 --- a/tensorflow/core/kernels/mkl/mkl_batch_matmul_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_batch_matmul_op.cc @@ -160,12 +160,10 @@ class BatchMatMulMkl : public OpKernel { out_shape, adj_x_, adj_y_); this->ExtendMklMatMulParams(ctx, *params); - // Create the oneDNN wrapper over eigen threapool and set max threads + // Create the oneDNN wrapper over Eigen threadpool and set max threads // in oneDNN. Eigen::ThreadPoolInterface* eigen_interface = - ctx->device() - ->tensorflow_cpu_worker_threads() - ->workers->AsEigenThreadPool(); + EigenThreadPoolFromTfContext(ctx); tsl::OneDnnThreadPool eigen_tp(eigen_interface, ThreadPoolUseCallerThread()); // Create or retrieve matmul primitive from cache. diff --git a/tensorflow/core/kernels/mkl/mkl_concat_op.cc b/tensorflow/core/kernels/mkl/mkl_concat_op.cc index 319f21201d470f..d5c0260b178761 100644 --- a/tensorflow/core/kernels/mkl/mkl_concat_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_concat_op.cc @@ -760,12 +760,10 @@ class MklConcatOp : public OpKernel { // then since MklDnn order is NCHW, concat_dim needs to be 1. if (are_all_mkl_inputs) concat_dim = mkl_input_shapes[0].TfDimIdx(concat_dim); - // Create the oneDNN wrapper over eigen threapool and set max threads + // Create the oneDNN wrapper over Eigen threadpool and set max threads // in oneDNN. Eigen::ThreadPoolInterface* eigen_interface = - context->device() - ->tensorflow_cpu_worker_threads() - ->workers->AsEigenThreadPool(); + EigenThreadPoolFromTfContext(context); tsl::OneDnnThreadPool eigen_tp(eigen_interface, ThreadPoolUseCallerThread()); if (!inputs.empty()) { diff --git a/tensorflow/core/kernels/mkl/mkl_conv_grad_filter_ops.cc b/tensorflow/core/kernels/mkl/mkl_conv_grad_filter_ops.cc index 7728c9aaf53cee..8d8c9b40451f17 100644 --- a/tensorflow/core/kernels/mkl/mkl_conv_grad_filter_ops.cc +++ b/tensorflow/core/kernels/mkl/mkl_conv_grad_filter_ops.cc @@ -517,12 +517,10 @@ class MklConvCustomBackpropFilterOp // variable TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE is set to true. bool do_not_cache = MklPrimitiveFactory::IsPrimitiveMemOptEnabled(); - // Create the oneDNN wrapper over eigen threadpool and set max threads + // Create the oneDNN wrapper over Eigen threadpool and set max threads // in oneDNN. Eigen::ThreadPoolInterface* eigen_interface = - context->device() - ->tensorflow_cpu_worker_threads() - ->workers->AsEigenThreadPool(); + EigenThreadPoolFromTfContext(context); tsl::OneDnnThreadPool eigen_tp(eigen_interface, ThreadPoolUseCallerThread()); MklConvBwdFilterPrimitive* conv_bwd_filter = diff --git a/tensorflow/core/kernels/mkl/mkl_conv_grad_input_ops.cc b/tensorflow/core/kernels/mkl/mkl_conv_grad_input_ops.cc index c376b4e4ec6531..16a6db176843b1 100644 --- a/tensorflow/core/kernels/mkl/mkl_conv_grad_input_ops.cc +++ b/tensorflow/core/kernels/mkl/mkl_conv_grad_input_ops.cc @@ -470,12 +470,10 @@ class MklConvCustomBackpropInputOp (MklPrimitiveFactory::IsLegacyPlatform() || IsConv1x1StrideNot1(fwd_filter_dims, strides)); - // Create the oneDNN wrapper over eigen threadpool and set max threads + // Create the oneDNN wrapper over Eigen threadpool and set max threads // in oneDNN. Eigen::ThreadPoolInterface* eigen_interface = - context->device() - ->tensorflow_cpu_worker_threads() - ->workers->AsEigenThreadPool(); + EigenThreadPoolFromTfContext(context); tsl::OneDnnThreadPool eigen_tp(eigen_interface, ThreadPoolUseCallerThread()); MklConvBwdInputPrimitive* conv_bwd_input = diff --git a/tensorflow/core/kernels/mkl/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl/mkl_conv_ops.cc index 027c3df8a20434..f3125e28c81955 100644 --- a/tensorflow/core/kernels/mkl/mkl_conv_ops.cc +++ b/tensorflow/core/kernels/mkl/mkl_conv_ops.cc @@ -876,12 +876,10 @@ class MklConvOp : public OpKernel { // TODO(intel-tf): Extend the basic parameters for data types and fusions this->ExtendConvFwdParams(context, convFwdDims); - // Create the oneDNN wrapper over eigen threadpool and set max threads + // Create the oneDNN wrapper over Eigen threadpool and set max threads // in oneDNN. Eigen::ThreadPoolInterface* eigen_interface = - context->device() - ->tensorflow_cpu_worker_threads() - ->workers->AsEigenThreadPool(); + EigenThreadPoolFromTfContext(context); tsl::OneDnnThreadPool eigen_tp(eigen_interface, ThreadPoolUseCallerThread()); conv_fwd = diff --git a/tensorflow/core/kernels/mkl/mkl_dequantize_op.cc b/tensorflow/core/kernels/mkl/mkl_dequantize_op.cc index 2f07569fc36261..c12b516074da27 100644 --- a/tensorflow/core/kernels/mkl/mkl_dequantize_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_dequantize_op.cc @@ -72,12 +72,10 @@ class MklDequantizeOp : public OpKernel { MklDnnData dst(&cpu_engine); std::shared_ptr reorder_stream; - // Create the oneDNN wrapper over eigen threadpool and set max threads + // Create the oneDNN wrapper over Eigen threadpool and set max threads // in oneDNN. Eigen::ThreadPoolInterface* eigen_interface = - ctx->device() - ->tensorflow_cpu_worker_threads() - ->workers->AsEigenThreadPool(); + EigenThreadPoolFromTfContext(ctx); tsl::OneDnnThreadPool eigen_tp(eigen_interface, ThreadPoolUseCallerThread()); reorder_stream.reset(CreateStream(&eigen_tp, cpu_engine)); diff --git a/tensorflow/core/kernels/mkl/mkl_einsum_op.cc b/tensorflow/core/kernels/mkl/mkl_einsum_op.cc index c7ba2b46cd2bee..698dcdb12ec530 100644 --- a/tensorflow/core/kernels/mkl/mkl_einsum_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_einsum_op.cc @@ -112,12 +112,10 @@ struct MklEinsumHelper { auto params = bmm.CreateMatMulParams(prefix, lhs.shape(), rhs.shape(), out_shape, trans_x, trans_y); - // Create the oneDNN wrapper over eigen threadpool and set max threads + // Create the oneDNN wrapper over Eigen threadpool and set max threads // in oneDNN. Eigen::ThreadPoolInterface* eigen_interface = - ctx->device() - ->tensorflow_cpu_worker_threads() - ->workers->AsEigenThreadPool(); + EigenThreadPoolFromTfContext(ctx); tsl::OneDnnThreadPool eigen_tp(eigen_interface, ThreadPoolUseCallerThread()); // Create or retrieve matmul primitive from cache. diff --git a/tensorflow/core/kernels/mkl/mkl_fused_batch_norm_op.cc b/tensorflow/core/kernels/mkl/mkl_fused_batch_norm_op.cc index 83a3e525609370..fc030c77bec166 100644 --- a/tensorflow/core/kernels/mkl/mkl_fused_batch_norm_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_fused_batch_norm_op.cc @@ -837,12 +837,10 @@ class MklFusedBatchNormOp : public OpKernel { MklBatchNormFwdParams fwdParams(src_dims, depth_, epsilon_, is_training_, tensor_format_, src_md, activation_mode_); - // Create the oneDNN wrapper over eigen threadpool and set max threads + // Create the oneDNN wrapper over Eigen threadpool and set max threads // in oneDNN. Eigen::ThreadPoolInterface* eigen_interface = - context->device() - ->tensorflow_cpu_worker_threads() - ->workers->AsEigenThreadPool(); + EigenThreadPoolFromTfContext(context); tsl::OneDnnThreadPool eigen_tp(eigen_interface, ThreadPoolUseCallerThread()); // Get forward batch-normalization op from the primitive caching pool. @@ -1320,9 +1318,7 @@ class MklFusedBatchNormGradOp : public OpKernel { is_training_, tensor_format_, src_md, diff_dst_md); Eigen::ThreadPoolInterface* eigen_interface = - context->device() - ->tensorflow_cpu_worker_threads() - ->workers->AsEigenThreadPool(); + EigenThreadPoolFromTfContext(context); tsl::OneDnnThreadPool eigen_tp(eigen_interface, ThreadPoolUseCallerThread()); MklFusedBatchNormBwdPrimitive* bn_bwd = diff --git a/tensorflow/core/kernels/mkl/mkl_fused_instance_norm_op.cc b/tensorflow/core/kernels/mkl/mkl_fused_instance_norm_op.cc index c98fff64627e27..d0019e029ccccf 100644 --- a/tensorflow/core/kernels/mkl/mkl_fused_instance_norm_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_fused_instance_norm_op.cc @@ -71,12 +71,10 @@ class MklFusedInstanceNormOp : public OpKernel { OP_REQUIRES(ctx, FormatFromString(data_format_, &tensor_format), errors::InvalidArgument("Invalid data format")); - // Create the oneDNN wrapper over eigen threadpool and set max threads + // Create the oneDNN wrapper over Eigen threadpool and set max threads // in oneDNN. Eigen::ThreadPoolInterface* eigen_interface = - ctx->device() - ->tensorflow_cpu_worker_threads() - ->workers->AsEigenThreadPool(); + EigenThreadPoolFromTfContext(ctx); tsl::OneDnnThreadPool eigen_tp(eigen_interface, ThreadPoolUseCallerThread()); std::shared_ptr engine_stream_ptr; diff --git a/tensorflow/core/kernels/mkl/mkl_layer_norm_op.cc b/tensorflow/core/kernels/mkl/mkl_layer_norm_op.cc index ad78096ec11547..cba883c2a9d1dc 100644 --- a/tensorflow/core/kernels/mkl/mkl_layer_norm_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_layer_norm_op.cc @@ -61,12 +61,10 @@ class MklLayerNormOp : public OpKernel { "tensors are not same.")); auto cpu_engine = engine(engine::kind::cpu, 0); - // Create the oneDNN wrapper over eigen threadpool and set max threads + // Create the oneDNN wrapper over Eigen threadpool and set max threads // in oneDNN. Eigen::ThreadPoolInterface* eigen_interface = - ctx->device() - ->tensorflow_cpu_worker_threads() - ->workers->AsEigenThreadPool(); + EigenThreadPoolFromTfContext(ctx); tsl::OneDnnThreadPool eigen_tp(eigen_interface, ThreadPoolUseCallerThread()); auto cpu_stream = diff --git a/tensorflow/core/kernels/mkl/mkl_matmul_op.cc b/tensorflow/core/kernels/mkl/mkl_matmul_op.cc index b7d157a58bedd2..e4122c83b4042d 100644 --- a/tensorflow/core/kernels/mkl/mkl_matmul_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_matmul_op.cc @@ -162,12 +162,10 @@ class MklMatMulOp : public OpKernel { char char_transb = transb ? 'T' : 'N'; VLOG(2) << "MKL DNN SGEMM called"; #ifndef ENABLE_ONEDNN_OPENMP - // Create the oneDNN wrapper over eigen threadpool and set max threads + // Create the oneDNN wrapper over Eigen threadpool and set max threads // in oneDNN. Eigen::ThreadPoolInterface* eigen_interface = - ctx->device() - ->tensorflow_cpu_worker_threads() - ->workers->AsEigenThreadPool(); + EigenThreadPoolFromTfContext(ctx); tsl::OneDnnThreadPool eigen_tp(eigen_interface, ThreadPoolUseCallerThread()); // With threadpool , the runtime overhead is comparable to the kernel diff --git a/tensorflow/core/kernels/mkl/mkl_matmul_op_fused.cc b/tensorflow/core/kernels/mkl/mkl_matmul_op_fused.cc index 3d9d24e358bf86..1d388705e7ddd0 100644 --- a/tensorflow/core/kernels/mkl/mkl_matmul_op_fused.cc +++ b/tensorflow/core/kernels/mkl/mkl_matmul_op_fused.cc @@ -135,12 +135,10 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase { // Extend the basic parameters for data types and fusions. ExtendMklDnnMatMulFwdParams(ctx, matmul_params); auto st = ExecuteSingleThreadedGemm(batch, channel, k, sizeof(T)); - // Create the oneDNN wrapper over eigen threadpool and set max threads + // Create the oneDNN wrapper over Eigen threadpool and set max threads // in oneDNN. Eigen::ThreadPoolInterface* eigen_interface = - ctx->device() - ->tensorflow_cpu_worker_threads() - ->workers->AsEigenThreadPool(); + EigenThreadPoolFromTfContext(ctx); tsl::OneDnnThreadPool eigen_tp(eigen_interface, ThreadPoolUseCallerThread(), st ? 1 : -1); MklDnnMatMulFwdPrimitive* matmul_prim = diff --git a/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h b/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h index e74ad037944bb1..a9dec55e9b535b 100644 --- a/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h +++ b/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h @@ -1016,12 +1016,10 @@ void dnnl_gemm(char transa, char transb, int64_t m, int64_t n, int64_t k, MklMatMulParams params("dnnl_gemm", a_dims, b_dims, c_dims, a_strides, b_strides, c_strides); auto st = ExecuteSingleThreadedGemm(m, n, k, sizeof(T)); - // Create the oneDNN wrapper over eigen threadpool and set max threads + // Create the oneDNN wrapper over Eigen threadpool and set max threads // in oneDNN. Eigen::ThreadPoolInterface* eigen_interface = - ctx->device() - ->tensorflow_cpu_worker_threads() - ->workers->AsEigenThreadPool(); + EigenThreadPoolFromTfContext(ctx); tsl::OneDnnThreadPool eigen_tp(eigen_interface, ThreadPoolUseCallerThread(), st ? 1 : -1); MklMatMulPrimitive* matmul_prim = diff --git a/tensorflow/core/kernels/mkl/mkl_maxpooling_op.cc b/tensorflow/core/kernels/mkl/mkl_maxpooling_op.cc index 050ca1190380cc..2360cdefba407c 100644 --- a/tensorflow/core/kernels/mkl/mkl_maxpooling_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_maxpooling_op.cc @@ -143,12 +143,10 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase { pooling_prop_kind, static_cast(this->data_format_mkldnn_), input_md, this->native_format_); - // Create the oneDNN wrapper over eigen threadpool and set max threads + // Create the oneDNN wrapper over Eigen threadpool and set max threads // in oneDNN. Eigen::ThreadPoolInterface* eigen_interface = - context->device() - ->tensorflow_cpu_worker_threads() - ->workers->AsEigenThreadPool(); + EigenThreadPoolFromTfContext(context); tsl::OneDnnThreadPool eigen_tp(eigen_interface, ThreadPoolUseCallerThread()); pooling_fwd = MklPoolingFwdPrimitiveFactory::Get(fwdParams); @@ -345,9 +343,7 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase { static_cast(this->data_format_mkldnn_), src_md, this->native_format_); Eigen::ThreadPoolInterface* eigen_interface = - context->device() - ->tensorflow_cpu_worker_threads() - ->workers->AsEigenThreadPool(); + EigenThreadPoolFromTfContext(context); tsl::OneDnnThreadPool eigen_tp(eigen_interface, ThreadPoolUseCallerThread()); MklPoolingBwdPrimitive* pooling_bwd = diff --git a/tensorflow/core/kernels/mkl/mkl_qmatmul_op.cc b/tensorflow/core/kernels/mkl/mkl_qmatmul_op.cc index 28f27dab0c2a34..259dfacc0bf51b 100644 --- a/tensorflow/core/kernels/mkl/mkl_qmatmul_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_qmatmul_op.cc @@ -245,12 +245,10 @@ class MklDnnQuantizedMatMulOp // Extend the basic parameters for data types and fusions. this->ExtendMklDnnMatMulFwdParams(context, matmul_fwd_dims); - // Create the oneDNN wrapper over eigen threadpool and set max threads + // Create the oneDNN wrapper over Eigen threadpool and set max threads // in oneDNN. Eigen::ThreadPoolInterface* eigen_interface = - context->device() - ->tensorflow_cpu_worker_threads() - ->workers->AsEigenThreadPool(); + EigenThreadPoolFromTfContext(context); tsl::OneDnnThreadPool eigen_tp(eigen_interface, ThreadPoolUseCallerThread()); // Get a MatMul fwd from primitive pool. diff --git a/tensorflow/core/kernels/mkl/mkl_quantize_op.cc b/tensorflow/core/kernels/mkl/mkl_quantize_op.cc index ff6ed33df0674a..b8190118a04e93 100644 --- a/tensorflow/core/kernels/mkl/mkl_quantize_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_quantize_op.cc @@ -560,12 +560,10 @@ class MklQuantizeV2Op : public OpKernel { fwdParams.post_op_params.param.push_back(scale_factor); #endif // ENABLE_ONEDNN_V3 - // Create the oneDNN wrapper over eigen threadpool and set max threads + // Create the oneDNN wrapper over Eigen threadpool and set max threads // in oneDNN. Eigen::ThreadPoolInterface* eigen_interface = - ctx->device() - ->tensorflow_cpu_worker_threads() - ->workers->AsEigenThreadPool(); + EigenThreadPoolFromTfContext(ctx); tsl::OneDnnThreadPool eigen_tp(eigen_interface, ThreadPoolUseCallerThread()); MklReorderWithScalePrimitive* reorder_prim = diff --git a/tensorflow/core/kernels/mkl/mkl_relu_op.cc b/tensorflow/core/kernels/mkl/mkl_relu_op.cc index 1d07848cc15cd8..2db0b812b8917a 100644 --- a/tensorflow/core/kernels/mkl/mkl_relu_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_relu_op.cc @@ -476,12 +476,10 @@ class MklReluOpBase : public OpKernel { // Try to get an eltwise forward primitive from caching pool MklEltwiseFwdParams fwdParams(src_dims, src_md, alg_kind, alpha_, beta_); - // Create the oneDNN wrapper over eigen threadpool and set max threads + // Create the oneDNN wrapper over Eigen threadpool and set max threads // in oneDNN. Eigen::ThreadPoolInterface* eigen_interface = - context->device() - ->tensorflow_cpu_worker_threads() - ->workers->AsEigenThreadPool(); + EigenThreadPoolFromTfContext(context); tsl::OneDnnThreadPool eigen_tp(eigen_interface, ThreadPoolUseCallerThread()); MklEltwiseFwdPrimitive* eltwise_fwd = @@ -691,9 +689,7 @@ class MklReluGradOpBase : public OpKernel { beta_, GetTypeOfInputTensorFromFwdOp()); Eigen::ThreadPoolInterface* eigen_interface = - context->device() - ->tensorflow_cpu_worker_threads() - ->workers->AsEigenThreadPool(); + EigenThreadPoolFromTfContext(context); tsl::OneDnnThreadPool eigen_tp(eigen_interface, ThreadPoolUseCallerThread()); MklEltwiseBwdPrimitive* eltwise_bwd = diff --git a/tensorflow/core/kernels/mkl/mkl_requantize_per_channel_op.cc b/tensorflow/core/kernels/mkl/mkl_requantize_per_channel_op.cc index 5d2da88a5b3313..8be29e797e9184 100644 --- a/tensorflow/core/kernels/mkl/mkl_requantize_per_channel_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_requantize_per_channel_op.cc @@ -115,12 +115,10 @@ class MklRequantizePerChannelOp : public OpKernel { cpu_engine_, scales.data()); #endif // !ENABLE_ONEDNN_V3 - // Create the oneDNN wrapper over eigen threadpool and set max threads + // Create the oneDNN wrapper over Eigen threadpool and set max threads // in oneDNN. Eigen::ThreadPoolInterface* eigen_interface = - ctx->device() - ->tensorflow_cpu_worker_threads() - ->workers->AsEigenThreadPool(); + EigenThreadPoolFromTfContext(ctx); tsl::OneDnnThreadPool eigen_tp(eigen_interface, ThreadPoolUseCallerThread()); memory::dims dims_mkl_order = diff --git a/tensorflow/core/kernels/mkl/mkl_softmax_op.cc b/tensorflow/core/kernels/mkl/mkl_softmax_op.cc index 2fd9e16f1e25a9..60624f2b7d110f 100644 --- a/tensorflow/core/kernels/mkl/mkl_softmax_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_softmax_op.cc @@ -266,12 +266,10 @@ class MklSoftmaxOp : public OpKernel { fwdParams.aarch64_counter = MklSoftmaxPrimitiveFactory::IncrementCounter(); #endif - // Create the oneDNN wrapper over eigen threadpool and set max threads + // Create the oneDNN wrapper over Eigen threadpool and set max threads // in oneDNN. Eigen::ThreadPoolInterface* eigen_interface = - context->device() - ->tensorflow_cpu_worker_threads() - ->workers->AsEigenThreadPool(); + EigenThreadPoolFromTfContext(context); tsl::OneDnnThreadPool eigen_tp(eigen_interface, ThreadPoolUseCallerThread()); MklSoftmaxPrimitive* softmax_fwd = diff --git a/tensorflow/core/kernels/mkl/mkl_transpose_op.cc b/tensorflow/core/kernels/mkl/mkl_transpose_op.cc index c7ee23e508221f..b26879dd51556a 100644 --- a/tensorflow/core/kernels/mkl/mkl_transpose_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_transpose_op.cc @@ -83,12 +83,10 @@ Status MKLTransposeND(OpKernelContext* context, const Tensor& in_tensor, out.SetUsrMem(in_dims, out_strides, out_tensor); std::vector net; - // Create the oneDNN wrapper over eigen threadpool and set max threads + // Create the oneDNN wrapper over Eigen threadpool and set max threads // in oneDNN. Eigen::ThreadPoolInterface* eigen_interface = - context->device() - ->tensorflow_cpu_worker_threads() - ->workers->AsEigenThreadPool(); + EigenThreadPoolFromTfContext(context); tsl::OneDnnThreadPool eigen_tp(eigen_interface, ThreadPoolUseCallerThread()); auto* prim = FindOrCreateReorder(in.GetUsrMem(), out.GetUsrMem()); diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h index 7dffa8e347e0ac..95564fe2c15b20 100644 --- a/tensorflow/core/util/mkl_util.h +++ b/tensorflow/core/util/mkl_util.h @@ -649,6 +649,13 @@ class MklDnnShape { } }; +inline Eigen::ThreadPoolInterface* EigenThreadPoolFromTfContext( + OpKernelContext* context) { + return context->device() + ->tensorflow_cpu_worker_threads() + ->workers->AsEigenThreadPool(); +} + // List of MklShape objects. Used in Concat/Split layers. typedef std::vector MklDnnShapeList; @@ -663,14 +670,12 @@ inline void ExecutePrimitive(const std::vector& net, DCHECK(net_args); DCHECK_EQ(net.size(), net_args->size()); std::unique_ptr cpu_stream; - // Create the oneDNN wrapper over eigen threadpool and set max threads + // Create the oneDNN wrapper over Eigen threadpool and set max threads // in oneDNN. tsl::OneDnnThreadPool eigen_tp; if (context != nullptr) { Eigen::ThreadPoolInterface* eigen_interface = - context->device() - ->tensorflow_cpu_worker_threads() - ->workers->AsEigenThreadPool(); + EigenThreadPoolFromTfContext(context); eigen_tp = tsl::OneDnnThreadPool(eigen_interface, ThreadPoolUseCallerThread()); cpu_stream.reset(CreateStream(&eigen_tp, cpu_engine)); @@ -1606,9 +1611,7 @@ class MklDnnData { tsl::OneDnnThreadPool eigen_tp; if (context != nullptr) { Eigen::ThreadPoolInterface* eigen_interface = - context->device() - ->tensorflow_cpu_worker_threads() - ->workers->AsEigenThreadPool(); + EigenThreadPoolFromTfContext(context); eigen_tp = tsl::OneDnnThreadPool(eigen_interface, ThreadPoolUseCallerThread()); cpu_stream.reset(CreateStream(&eigen_tp, prim->GetEngine())); @@ -1678,9 +1681,7 @@ class MklDnnData { tsl::OneDnnThreadPool eigen_tp; if (context != nullptr) { Eigen::ThreadPoolInterface* eigen_interface = - context->device() - ->tensorflow_cpu_worker_threads() - ->workers->AsEigenThreadPool(); + EigenThreadPoolFromTfContext(context); eigen_tp = tsl::OneDnnThreadPool(eigen_interface, ThreadPoolUseCallerThread()); cpu_stream.reset(CreateStream(&eigen_tp, prim->GetEngine())); @@ -1794,9 +1795,7 @@ class MklDnnData { tsl::OneDnnThreadPool eigen_tp; if (ctx != nullptr) { Eigen::ThreadPoolInterface* eigen_interface = - ctx->device() - ->tensorflow_cpu_worker_threads() - ->workers->AsEigenThreadPool(); + EigenThreadPoolFromTfContext(ctx); eigen_tp = tsl::OneDnnThreadPool(eigen_interface, ThreadPoolUseCallerThread()); cpu_stream.reset(CreateStream(&eigen_tp, prim->GetEngine())); @@ -2273,6 +2272,7 @@ inline bool IsConv1x1StrideNot1(memory::dims filter_dims, ((strides[0] != 1) || (strides[1] != 1))); } + #undef ARE_MEMORY_DESCS_EQUAL #undef CREATE_MEMORY_DESC_USING_STRIDES #undef GET_DATA_TYPE From b2b41f4bd62593feed5942616abe1d1f2ea12d5a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 14 Jul 2023 22:22:36 -0700 Subject: [PATCH 345/376] Internal Code Change PiperOrigin-RevId: 548295593 --- tensorflow/compiler/mlir/lite/stablehlo/BUILD | 1 - 1 file changed, 1 deletion(-) diff --git a/tensorflow/compiler/mlir/lite/stablehlo/BUILD b/tensorflow/compiler/mlir/lite/stablehlo/BUILD index aaeb0d4dff4836..9fd1317d132c9d 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/lite/stablehlo/BUILD @@ -291,7 +291,6 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", - "@llvm-project//mlir:ShapeDialect", "@llvm-project//mlir:Transforms", ], alwayslink = 1, From b259da3d0163c4a1fa46adbfb518fd69caf1de95 Mon Sep 17 00:00:00 2001 From: Hyeontaek Lim Date: Fri, 14 Jul 2023 22:24:00 -0700 Subject: [PATCH 346/376] [IFRT] Clean up sharding tests to use a mock client Sharding tests use a test fixture that wraps a mock client. This mock client exports mock device objects that allow sharding construction and serialization/deserialization with memory safety. PiperOrigin-RevId: 548295734 --- tensorflow/compiler/xla/python/ifrt/BUILD | 17 ++- .../xla/python/ifrt/array_impl_test_lib.cc | 1 - .../xla/python/ifrt/sharding_serdes_test.cc | 61 ++------ .../compiler/xla/python/ifrt/sharding_test.cc | 130 ++++++++++-------- .../xla/python/ifrt/sharding_test_util.cc | 96 +++++++++++++ .../xla/python/ifrt/sharding_test_util.h | 55 ++++++++ .../compiler/xla/python/ifrt/support/BUILD | 1 + .../sharding_param_to_op_sharding_test.cc | 64 ++++----- .../compiler/xla/python/ifrt/test_util.cc | 18 ++- .../compiler/xla/python/ifrt/test_util.h | 5 + .../compiler/xla/python/pjrt_ifrt/BUILD | 4 +- .../pjrt_ifrt/xla_sharding_serdes_test.cc | 51 ++----- .../xla/python/pjrt_ifrt/xla_sharding_test.cc | 71 +++++----- 13 files changed, 353 insertions(+), 221 deletions(-) create mode 100644 tensorflow/compiler/xla/python/ifrt/sharding_test_util.cc create mode 100644 tensorflow/compiler/xla/python/ifrt/sharding_test_util.h diff --git a/tensorflow/compiler/xla/python/ifrt/BUILD b/tensorflow/compiler/xla/python/ifrt/BUILD index a16497218b4744..39dbf46fef0527 100644 --- a/tensorflow/compiler/xla/python/ifrt/BUILD +++ b/tensorflow/compiler/xla/python/ifrt/BUILD @@ -149,6 +149,7 @@ xla_cc_test( srcs = ["sharding_test.cc"], deps = [ ":ifrt", + ":sharding_test_util", "//tensorflow/compiler/xla/python/ifrt/ir", "//tensorflow/tsl/platform:errors", "//tensorflow/tsl/platform:status_matchers", @@ -176,6 +177,19 @@ cc_library( ], ) +cc_library( + name = "sharding_test_util", + testonly = True, + srcs = ["sharding_test_util.cc"], + hdrs = ["sharding_test_util.h"], + deps = [ + ":ifrt", + ":mock", + ":test_util", + "//tensorflow/tsl/platform:test", + ], +) + cc_library( name = "no_impl_test_main", testonly = 1, @@ -333,10 +347,9 @@ xla_cc_test( srcs = ["sharding_serdes_test.cc"], deps = [ ":ifrt", - ":mock", ":serdes", ":sharding_serdes", - "@com_google_absl//absl/container:flat_hash_map", + ":sharding_test_util", "@com_google_googletest//:gtest_main", ], ) diff --git a/tensorflow/compiler/xla/python/ifrt/array_impl_test_lib.cc b/tensorflow/compiler/xla/python/ifrt/array_impl_test_lib.cc index adbcf74e6f80c2..9d1768fdc8da80 100644 --- a/tensorflow/compiler/xla/python/ifrt/array_impl_test_lib.cc +++ b/tensorflow/compiler/xla/python/ifrt/array_impl_test_lib.cc @@ -355,7 +355,6 @@ TEST(ArrayImplTest, AssembleAndDisassembleArray) { /*on_done_with_host_buffer=*/{})); std::vector> arrays({array0, array1}); - std::vector single_device_shapes({shape, shape}); Shape assembled_shape({4, 3}); ShardingParam sharding_param( /*dim_shards=*/{2, 1}, {/*permutation=*/{0, 1}, /*axis_sizes=*/{2, 1}}); diff --git a/tensorflow/compiler/xla/python/ifrt/sharding_serdes_test.cc b/tensorflow/compiler/xla/python/ifrt/sharding_serdes_test.cc index 90efc6d9667167..3ea3aad6478fa6 100644 --- a/tensorflow/compiler/xla/python/ifrt/sharding_serdes_test.cc +++ b/tensorflow/compiler/xla/python/ifrt/sharding_serdes_test.cc @@ -21,10 +21,9 @@ limitations under the License. #include #include -#include "absl/container/flat_hash_map.h" -#include "tensorflow/compiler/xla/python/ifrt/mock.h" #include "tensorflow/compiler/xla/python/ifrt/serdes.h" #include "tensorflow/compiler/xla/python/ifrt/sharding.h" +#include "tensorflow/compiler/xla/python/ifrt/sharding_test_util.h" namespace xla { namespace ifrt { @@ -32,45 +31,11 @@ namespace { using ::testing::ElementsAreArray; -// Test fixture for sharding serialization and deserialization. It makes a mock -// client with a number of fake devices. Client implements `devices()` and -// `LookupDevice()`, and Device implements `id()`, with an arbitrary device ids -// assigned. -class ShardingSerDesTest : public ::testing::TestWithParam { - public: - void SetUp() override { - const int num_devices = GetParam(); - device_map_.reserve(num_devices); - devices_.reserve(num_devices); - for (int i = 0; i < num_devices; ++i) { - auto device = std::make_unique(); - ON_CALL(*device, id).WillByDefault([i]() { return i + 10; }); - devices_.push_back(device.get()); - device_map_.insert({i + 10, std::move(device)}); - } - client_ = std::make_unique(); - ON_CALL(*client_, devices) - .WillByDefault( - [this]() -> absl::Span { return devices_; }); - ON_CALL(*client_, LookupDevice) - .WillByDefault([this](int device_id) -> StatusOr { - auto it = device_map_.find(device_id); - if (it == device_map_.end()) { - return InvalidArgument("Unexpected device id: %d", device_id); - } - return it->second.get(); - }); - } - Client* client() { return client_.get(); } - - private: - std::unique_ptr client_; - absl::flat_hash_map> device_map_; - std::vector devices_; -}; +class ShardingSerDesTest : public test_util::ShardingTest {}; TEST_P(ShardingSerDesTest, SingleDeviceShardingRoundTrip) { - auto sharding = SingleDeviceSharding::Create(client()->devices().front()); + auto sharding = + SingleDeviceSharding::Create(GetDevices({0}).devices().front()); TF_ASSERT_OK_AND_ASSIGN(auto serialized, Serialize(*sharding)); @@ -87,8 +52,7 @@ TEST_P(ShardingSerDesTest, SingleDeviceShardingRoundTrip) { } TEST_P(ShardingSerDesTest, OpaqueShardingRoundTrip) { - auto sharding = OpaqueSharding::Create(DeviceList(DeviceList::Devices( - client()->devices().begin(), client()->devices().end()))); + auto sharding = OpaqueSharding::Create(GetDevices({0, 1})); TF_ASSERT_OK_AND_ASSIGN(auto serialized, Serialize(*sharding)); @@ -105,8 +69,7 @@ TEST_P(ShardingSerDesTest, OpaqueShardingRoundTrip) { TEST_P(ShardingSerDesTest, ConcreteShardingRoundTrip) { auto sharding = ConcreteSharding::Create( - DeviceList(DeviceList::Devices(client()->devices().begin(), - client()->devices().end())), + GetDevices({0, 1}), /*shape=*/Shape({10, 20}), /*shard_shapes=*/{Shape({3, 20}), Shape({7, 20})}); @@ -128,11 +91,9 @@ TEST_P(ShardingSerDesTest, ConcreteShardingRoundTrip) { } TEST_P(ShardingSerDesTest, ConcreteEvenShardingRoundTrip) { - auto sharding = ConcreteEvenSharding::Create( - DeviceList(DeviceList::Devices(client()->devices().begin(), - client()->devices().end())), - /*shape=*/Shape({10, 20}), - /*shard_shape=*/Shape({5, 20})); + auto sharding = ConcreteEvenSharding::Create(GetDevices({0, 1}), + /*shape=*/Shape({10, 20}), + /*shard_shape=*/Shape({5, 20})); TF_ASSERT_OK_AND_ASSIGN(auto serialized, Serialize(*sharding)); @@ -150,7 +111,9 @@ TEST_P(ShardingSerDesTest, ConcreteEvenShardingRoundTrip) { EXPECT_THAT(out_sharding->shard_shape(), sharding->shard_shape()); } -INSTANTIATE_TEST_SUITE_P(NumDevices, ShardingSerDesTest, testing::Values(2)); +INSTANTIATE_TEST_SUITE_P(NumDevices, ShardingSerDesTest, + testing::Values(test_util::ShardingTestParam{ + .num_devices = 2, .num_addressable_devices = 2})); } // namespace } // namespace ifrt diff --git a/tensorflow/compiler/xla/python/ifrt/sharding_test.cc b/tensorflow/compiler/xla/python/ifrt/sharding_test.cc index e1b842fde75517..be95647993278a 100644 --- a/tensorflow/compiler/xla/python/ifrt/sharding_test.cc +++ b/tensorflow/compiler/xla/python/ifrt/sharding_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "llvm/Support/Casting.h" #include "tensorflow/compiler/xla/python/ifrt/device.h" #include "tensorflow/compiler/xla/python/ifrt/ir/sharding_param.h" +#include "tensorflow/compiler/xla/python/ifrt/sharding_test_util.h" #include "tensorflow/tsl/platform/errors.h" #include "tensorflow/tsl/platform/status_matchers.h" #include "tensorflow/tsl/platform/statusor.h" @@ -33,32 +34,31 @@ namespace ifrt { namespace { using ::testing::ElementsAre; +using ::testing::ElementsAreArray; using ::testing::HasSubstr; using ::testing::SizeIs; using ::tsl::testing::StatusIs; -DeviceList CreateDummyDevices(int count) { - DeviceList::Devices devices; - devices.reserve(count); - for (int i = 0; i < count; ++i) { - devices.push_back(reinterpret_cast(i + 1)); - } - return DeviceList(std::move(devices)); -} +class SingleDeviceShardingTest : public test_util::ShardingTest {}; +class OpaqueShardingTest : public test_util::ShardingTest {}; +class ConcreteShardingTest : public test_util::ShardingTest {}; +class ConcreteEvenShardingTest : public test_util::ShardingTest {}; +class ShardingParamShardingTest : public test_util::ShardingTest {}; -TEST(SingleDeviceShardingTest, IndexDomains) { +TEST_P(SingleDeviceShardingTest, IndexDomains) { + auto device_list = GetDevices({0}); std::shared_ptr sharding = - SingleDeviceSharding::Create(reinterpret_cast(1)); + SingleDeviceSharding::Create(device_list.devices().front()); Shape shape({10, 20}); TF_ASSERT_OK_AND_ASSIGN(auto index_domains, sharding->IndexDomains(shape)); EXPECT_THAT(index_domains, ElementsAre(IndexDomain(shape))); } -TEST(SingleDeviceShardingTest, Disassemble) { - auto device = reinterpret_cast(1); +TEST_P(SingleDeviceShardingTest, Disassemble) { + auto device_list = GetDevices({0}); std::shared_ptr sharding = - SingleDeviceSharding::Create(device); + SingleDeviceSharding::Create(device_list.devices().front()); Shape shape({10, 20}); TF_ASSERT_OK_AND_ASSIGN(auto disassembled, sharding->Disassemble(shape)); @@ -67,11 +67,12 @@ TEST(SingleDeviceShardingTest, Disassemble) { const auto& [result_shape, result_sharding] = disassembled[0]; ASSERT_EQ(shape, result_shape); ASSERT_TRUE(llvm::isa(*result_sharding)); - EXPECT_THAT(result_sharding->devices().devices(), ElementsAre(device)); + EXPECT_THAT(result_sharding->devices().devices(), + ElementsAreArray(device_list.devices())); } -TEST(OpaqueShardingTest, FailedToDisassemble) { - DeviceList device_list = CreateDummyDevices(2); +TEST_P(OpaqueShardingTest, FailedToDisassemble) { + auto device_list = GetDevices({0, 1}); std::shared_ptr sharding = OpaqueSharding::Create(device_list); @@ -82,8 +83,8 @@ TEST(OpaqueShardingTest, FailedToDisassemble) { HasSubstr("OpaqueSharding does not have shard shape information"))); } -TEST(OpaqueShardingTest, IndexDomainsFails) { - DeviceList device_list = CreateDummyDevices(2); +TEST_P(OpaqueShardingTest, IndexDomainsFails) { + auto device_list = GetDevices({0, 1}); std::shared_ptr sharding = OpaqueSharding::Create(device_list); @@ -94,8 +95,8 @@ TEST(OpaqueShardingTest, IndexDomainsFails) { HasSubstr("OpaqueSharding does not have index domain information"))); } -TEST(ConcreteShardingTest, Disassemble) { - DeviceList device_list = CreateDummyDevices(2); +TEST_P(ConcreteShardingTest, Disassemble) { + auto device_list = GetDevices({0, 1}); std::vector shard_shapes; shard_shapes.reserve(2); shard_shapes.push_back(Shape({10})); @@ -115,8 +116,8 @@ TEST(ConcreteShardingTest, Disassemble) { } } -TEST(ConcreteShardingTest, DisassembleFailsForUnexpectedShape) { - DeviceList device_list = CreateDummyDevices(2); +TEST_P(ConcreteShardingTest, DisassembleFailsForUnexpectedShape) { + auto device_list = GetDevices({0, 1}); std::vector shard_shapes; shard_shapes.reserve(2); shard_shapes.push_back(Shape({10})); @@ -129,8 +130,8 @@ TEST(ConcreteShardingTest, DisassembleFailsForUnexpectedShape) { HasSubstr("ConcreteSharding can only disassemble"))); } -TEST(ConcreteShardingTest, IndexDomainsFails) { - DeviceList device_list = CreateDummyDevices(2); +TEST_P(ConcreteShardingTest, IndexDomainsFails) { + auto device_list = GetDevices({0, 1}); std::vector shard_shapes; shard_shapes.reserve(2); shard_shapes.push_back(Shape({10})); @@ -144,8 +145,8 @@ TEST(ConcreteShardingTest, IndexDomainsFails) { "domain information"))); } -TEST(ConcreteEvenShardingTest, Disassemble) { - DeviceList device_list = CreateDummyDevices(2); +TEST_P(ConcreteEvenShardingTest, Disassemble) { + auto device_list = GetDevices({0, 1}); std::shared_ptr sharding = ConcreteEvenSharding::Create(device_list, Shape({30}), Shape({15})); @@ -161,8 +162,8 @@ TEST(ConcreteEvenShardingTest, Disassemble) { } } -TEST(ConcreteEvenShardingTest, DisassembleFailsForUnexpectedShape) { - DeviceList device_list = CreateDummyDevices(2); +TEST_P(ConcreteEvenShardingTest, DisassembleFailsForUnexpectedShape) { + auto device_list = GetDevices({0, 1}); std::shared_ptr sharding = ConcreteEvenSharding::Create(device_list, Shape({30}), Shape({15})); @@ -171,8 +172,8 @@ TEST(ConcreteEvenShardingTest, DisassembleFailsForUnexpectedShape) { HasSubstr("ConcreteEvenSharding can only disassemble"))); } -TEST(ConcreteEvenShardingTest, IndexDomainsFails) { - DeviceList device_list = CreateDummyDevices(2); +TEST_P(ConcreteEvenShardingTest, IndexDomainsFails) { + auto device_list = GetDevices({0, 1}); std::vector shard_shapes; std::shared_ptr sharding = ConcreteEvenSharding::Create(device_list, Shape({30}), Shape({15})); @@ -185,8 +186,8 @@ TEST(ConcreteEvenShardingTest, IndexDomainsFails) { "ConcreteEvenSharding does not have index domain information"))); } -TEST(ShardingParamShardingTest, CreateFailsWhenDeviceCountNotMatch) { - DeviceList device_list = CreateDummyDevices(2); +TEST_P(ShardingParamShardingTest, CreateFailsWhenDeviceCountNotMatch) { + auto device_list = GetDevices({0, 1}); ShardingParam param{/*dim_shards=*/{2, 3}, {/*permutation=*/{1, 0}, /*axis_sizes=*/{3, 2}}}; @@ -196,13 +197,12 @@ TEST(ShardingParamShardingTest, CreateFailsWhenDeviceCountNotMatch) { "ShardingParam 6 vs from DeviceList 2"))); } -TEST(ShardingParamShardingTest, Disassemble) { - DeviceList device_list = CreateDummyDevices(6); +TEST_P(ShardingParamShardingTest, Disassemble) { + auto device_list = GetDevices({0, 1, 2, 3, 4, 5}); ShardingParam param{/*dim_shards=*/{2, 3}, {/*permutation=*/{1, 0}, /*axis_sizes=*/{3, 2}}}; - TF_ASSERT_OK_AND_ASSIGN( - std::shared_ptr param_sharding, - ShardingParamSharding::Create(param, CreateDummyDevices(6))); + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr param_sharding, + ShardingParamSharding::Create(param, device_list)); TF_ASSERT_OK_AND_ASSIGN(auto disassembled, param_sharding->Disassemble(Shape({6, 6}))); @@ -216,12 +216,12 @@ TEST(ShardingParamShardingTest, Disassemble) { } } -TEST(ShardingParamShardingTest, DisassembleFailsWhenRankNotMatch) { +TEST_P(ShardingParamShardingTest, DisassembleFailsWhenRankNotMatch) { + auto device_list = GetDevices({0, 1, 2, 3, 4, 5}); ShardingParam param{/*dim_shards=*/{2, 3}, {/*permutation=*/{1, 0}, /*axis_sizes=*/{3, 2}}}; - TF_ASSERT_OK_AND_ASSIGN( - std::shared_ptr param_sharding, - ShardingParamSharding::Create(param, CreateDummyDevices(6))); + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr param_sharding, + ShardingParamSharding::Create(param, device_list)); EXPECT_THAT( param_sharding->Disassemble(Shape({6, 6, 6})), @@ -230,12 +230,12 @@ TEST(ShardingParamShardingTest, DisassembleFailsWhenRankNotMatch) { "Ranks don't match. From Shape 3 vs from ShardingParam 2"))); } -TEST(ShardingParamShardingTest, DisassembleFailsForUnevenSharding) { +TEST_P(ShardingParamShardingTest, DisassembleFailsForUnevenSharding) { + auto device_list = GetDevices({0, 1, 2, 3, 4, 5}); ShardingParam param{/*dim_shards=*/{2, 3}, {/*permutation=*/{1, 0}, /*axis_sizes=*/{3, 2}}}; - TF_ASSERT_OK_AND_ASSIGN( - std::shared_ptr param_sharding, - ShardingParamSharding::Create(param, CreateDummyDevices(6))); + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr param_sharding, + ShardingParamSharding::Create(param, device_list)); EXPECT_THAT( param_sharding->Disassemble(Shape({7, 6})), @@ -244,12 +244,12 @@ TEST(ShardingParamShardingTest, DisassembleFailsForUnevenSharding) { HasSubstr("Uneven shard is not supported. dim: 7, dim_shards: 2"))); } -TEST(ShardingParamShardingTest, IndexDomain) { +TEST_P(ShardingParamShardingTest, IndexDomain) { + auto device_list = GetDevices({0, 1, 2, 3, 4, 5}); ShardingParam param{/*dim_shards=*/{2, 3}, {/*permutation=*/{0, 1}, /*axis_sizes=*/{2, 3}}}; - TF_ASSERT_OK_AND_ASSIGN( - std::shared_ptr param_sharding, - ShardingParamSharding::Create(param, CreateDummyDevices(6))); + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr param_sharding, + ShardingParamSharding::Create(param, device_list)); TF_ASSERT_OK_AND_ASSIGN(auto index_domains, param_sharding->IndexDomains(Shape({6, 6}))); @@ -262,12 +262,12 @@ TEST(ShardingParamShardingTest, IndexDomain) { IndexDomain(Index({3, 4}), Shape({3, 2})))); } -TEST(ShardingParamShardingTest, IndexDomainWithPermutation) { +TEST_P(ShardingParamShardingTest, IndexDomainWithPermutation) { + auto device_list = GetDevices({0, 1, 2, 3, 4, 5}); ShardingParam param{/*dim_shards=*/{2, 3}, {/*permutation=*/{1, 0}, /*axis_sizes=*/{3, 2}}}; - TF_ASSERT_OK_AND_ASSIGN( - std::shared_ptr param_sharding, - ShardingParamSharding::Create(param, CreateDummyDevices(6))); + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr param_sharding, + ShardingParamSharding::Create(param, device_list)); TF_ASSERT_OK_AND_ASSIGN(auto index_domains, param_sharding->IndexDomains(Shape({6, 6}))); @@ -280,12 +280,12 @@ TEST(ShardingParamShardingTest, IndexDomainWithPermutation) { IndexDomain(Index({3, 4}), Shape({3, 2})))); } -TEST(ShardingParamShardingTest, IndexDomainWithReplication) { +TEST_P(ShardingParamShardingTest, IndexDomainWithReplication) { + auto device_list = GetDevices({0, 1, 2, 3, 4, 5}); ShardingParam param{/*dim_shards=*/{2, 1}, {/*permutation=*/{0, 1}, /*axis_sizes=*/{2, 3}}}; - TF_ASSERT_OK_AND_ASSIGN( - std::shared_ptr param_sharding, - ShardingParamSharding::Create(param, CreateDummyDevices(6))); + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr param_sharding, + ShardingParamSharding::Create(param, device_list)); TF_ASSERT_OK_AND_ASSIGN(auto index_domains, param_sharding->IndexDomains(Shape({6, 6}))); @@ -298,6 +298,22 @@ TEST(ShardingParamShardingTest, IndexDomainWithReplication) { IndexDomain(Index({3, 0}), Shape({3, 6})))); } +INSTANTIATE_TEST_SUITE_P(NumDevices, SingleDeviceShardingTest, + testing::Values(test_util::ShardingTestParam{ + .num_devices = 6, .num_addressable_devices = 6})); +INSTANTIATE_TEST_SUITE_P(NumDevices, OpaqueShardingTest, + testing::Values(test_util::ShardingTestParam{ + .num_devices = 6, .num_addressable_devices = 6})); +INSTANTIATE_TEST_SUITE_P(NumDevices, ConcreteShardingTest, + testing::Values(test_util::ShardingTestParam{ + .num_devices = 6, .num_addressable_devices = 6})); +INSTANTIATE_TEST_SUITE_P(NumDevices, ConcreteEvenShardingTest, + testing::Values(test_util::ShardingTestParam{ + .num_devices = 6, .num_addressable_devices = 6})); +INSTANTIATE_TEST_SUITE_P(NumDevices, ShardingParamShardingTest, + testing::Values(test_util::ShardingTestParam{ + .num_devices = 6, .num_addressable_devices = 4})); + } // namespace } // namespace ifrt } // namespace xla diff --git a/tensorflow/compiler/xla/python/ifrt/sharding_test_util.cc b/tensorflow/compiler/xla/python/ifrt/sharding_test_util.cc new file mode 100644 index 00000000000000..e43c363eff7f41 --- /dev/null +++ b/tensorflow/compiler/xla/python/ifrt/sharding_test_util.cc @@ -0,0 +1,96 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/python/ifrt/sharding_test_util.h" + +#include +#include +#include + +#include "tensorflow/compiler/xla/python/ifrt/device.h" +#include "tensorflow/compiler/xla/python/ifrt/mock.h" +#include "tensorflow/compiler/xla/python/ifrt/test_util.h" +#include "tensorflow/tsl/platform/test.h" + +namespace xla { +namespace ifrt { +namespace test_util { + +namespace { + +using ::testing::Return; + +// Internal state of a client for sharding tests. +struct ShardingTestClientState { + // Mapping from a device ID to the mock device object. + absl::flat_hash_map> device_map; + // Raw pointers to mock devices. + std::vector devices; +}; + +// Creates a mock client for sharding tests. The client will have a specified +// number of fake addressable and non-addressable devices. Client implements +// `devices()` and `LookupDevice()`. Device implements `id()`, with an arbitrary +// deterministic device ids assigned. +std::shared_ptr MakeShardingTestClient( + int num_devices, int num_addressable_devices) { + auto state = std::make_shared(); + state->device_map.reserve(num_devices); + state->devices.reserve(num_devices); + + for (int i = 0; i < num_addressable_devices; ++i) { + auto device = std::make_unique(); + ON_CALL(*device, id).WillByDefault(Return(i + 10)); + ON_CALL(*device, IsAddressable).WillByDefault(Return(true)); + state->devices.push_back(device.get()); + state->device_map.insert({i + 10, std::move(device)}); + } + for (int i = num_addressable_devices; i < num_devices; ++i) { + auto device = std::make_unique(); + ON_CALL(*device, id).WillByDefault(Return(i + 10)); + ON_CALL(*device, IsAddressable).WillByDefault(Return(false)); + state->devices.push_back(device.get()); + state->device_map.insert({i + 10, std::move(device)}); + } + + auto client = std::make_shared(); + ON_CALL(*client, devices) + .WillByDefault( + [state]() -> absl::Span { return state->devices; }); + ON_CALL(*client, LookupDevice) + .WillByDefault([state](int device_id) -> StatusOr { + auto it = state->device_map.find(device_id); + if (it == state->device_map.end()) { + return InvalidArgument("Unexpected device id: %d", device_id); + } + return it->second.get(); + }); + return client; +} + +} // namespace + +void ShardingTest::SetUp() { + const auto [num_devices, num_addressable_devices] = GetParam(); + client_ = MakeShardingTestClient(num_devices, num_addressable_devices); +} + +DeviceList ShardingTest::GetDevices(absl::Span device_indices) { + return test_util::GetDevices(client_.get(), device_indices).value(); +} + +} // namespace test_util +} // namespace ifrt +} // namespace xla diff --git a/tensorflow/compiler/xla/python/ifrt/sharding_test_util.h b/tensorflow/compiler/xla/python/ifrt/sharding_test_util.h new file mode 100644 index 00000000000000..b7ba399c1689e8 --- /dev/null +++ b/tensorflow/compiler/xla/python/ifrt/sharding_test_util.h @@ -0,0 +1,55 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_IFRT_SHARDING_TEST_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_PYTHON_IFRT_SHARDING_TEST_UTIL_H_ + +#include + +#include "tensorflow/compiler/xla/python/ifrt/client.h" +#include "tensorflow/tsl/platform/test.h" + +namespace xla { +namespace ifrt { +namespace test_util { + +// Parameters for ShardingTest. +// Requests `num_devices` total devices, where `num_addressable_devices` of them +// are addressable, and the rest of devices are non-addressable. +struct ShardingTestParam { + int num_devices; + int num_addressable_devices; +}; + +// Test fixture for sharding tests. +class ShardingTest : public testing::TestWithParam { + public: + void SetUp() override; + Client* client() { return client_.get(); } + + // Returns `DeviceList` containing devices at given indexes (not ids) within + // `client.devices()`. + // REQUIRES: 0 <= device_indices[i] < num_devices + DeviceList GetDevices(absl::Span device_indices); + + private: + std::shared_ptr client_; +}; + +} // namespace test_util +} // namespace ifrt +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_PYTHON_IFRT_SHARDING_TEST_UTIL_H_ diff --git a/tensorflow/compiler/xla/python/ifrt/support/BUILD b/tensorflow/compiler/xla/python/ifrt/support/BUILD index 448dca40f6791e..08ecda728a4614 100644 --- a/tensorflow/compiler/xla/python/ifrt/support/BUILD +++ b/tensorflow/compiler/xla/python/ifrt/support/BUILD @@ -29,6 +29,7 @@ xla_cc_test( "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/hlo/ir:hlo", "//tensorflow/compiler/xla/python/ifrt", + "//tensorflow/compiler/xla/python/ifrt:sharding_test_util", "//tensorflow/compiler/xla/python/ifrt/ir", "//tensorflow/tsl/platform:errors", "//tensorflow/tsl/platform:status_matchers", diff --git a/tensorflow/compiler/xla/python/ifrt/support/sharding_param_to_op_sharding_test.cc b/tensorflow/compiler/xla/python/ifrt/support/sharding_param_to_op_sharding_test.cc index ad61143c31cc4f..94d06804e167aa 100644 --- a/tensorflow/compiler/xla/python/ifrt/support/sharding_param_to_op_sharding_test.cc +++ b/tensorflow/compiler/xla/python/ifrt/support/sharding_param_to_op_sharding_test.cc @@ -15,7 +15,6 @@ limitations under the License. #include "tensorflow/compiler/xla/python/ifrt/support/sharding_param_to_op_sharding.h" -#include #include #include #include @@ -29,6 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/python/ifrt/ir/sharding_param.h" #include "tensorflow/compiler/xla/python/ifrt/shape.h" #include "tensorflow/compiler/xla/python/ifrt/sharding.h" +#include "tensorflow/compiler/xla/python/ifrt/sharding_test_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/tsl/platform/errors.h" @@ -49,15 +49,6 @@ StatusOr ToHloSharding(const ShardingParam& sharding_param, return xla::HloSharding::FromProto(op_sharding); } -DeviceList CreateDummyDevices(int count) { - DeviceList::Devices devices; - devices.reserve(count); - for (int i = 0; i < count; ++i) { - devices.push_back(reinterpret_cast(i + 1)); - } - return DeviceList(std::move(devices)); -} - TEST(ShardingParamToOpShardingTest, Replicated) { ShardingParam sharding_param{/*dim_shards=*/{1, 1, 1}, {/*permutation=*/{0, 1}, /*axis_sizes=*/{2, 3}}}; @@ -114,27 +105,34 @@ TEST(ShardingParamToOpShardingTest, ErrorOnDeviceAssignment) { StatusIs(tsl::error::OUT_OF_RANGE, "Can't map device 5")); } -void AssertSameTiling(const ShardingParam& sharding_param, - const HloSharding& hlo_sharding, const Shape& shape) { - TF_ASSERT_OK_AND_ASSIGN( - std::shared_ptr sharding, - ShardingParamSharding::Create(sharding_param, CreateDummyDevices(6))); - const xla::Shape xla_shape(PrimitiveType::F16, shape.dims(), {}, {}); - - TF_ASSERT_OK_AND_ASSIGN(const std::vector index_domains, - sharding->IndexDomains(shape)); - ASSERT_EQ(index_domains.size(), - hlo_sharding.tile_assignment().num_elements()); - const xla::Shape xla_tile_shape = hlo_sharding.TileShape(xla_shape); - for (int i = 0; i < index_domains.size(); ++i) { - SCOPED_TRACE(absl::StrCat("on device ", i)); - EXPECT_EQ(index_domains[i].origin().elements(), - hlo_sharding.TileOffsetForDevice(xla_shape, i)); - EXPECT_EQ(index_domains[i].shape().dims(), xla_tile_shape.dimensions()); +class ShardingParamToOpShardingEquivalentTest : public test_util::ShardingTest { + public: + void AssertSameTiling(const ShardingParam& sharding_param, + const HloSharding& hlo_sharding, const Shape& shape) { + auto device_list = GetDevices({0, 1, 2, 3, 4, 5}); + TF_ASSERT_OK_AND_ASSIGN( + std::shared_ptr sharding, + ShardingParamSharding::Create(sharding_param, device_list)); + const xla::Shape xla_shape(PrimitiveType::F16, shape.dims(), {}, {}); + + TF_ASSERT_OK_AND_ASSIGN(const std::vector index_domains, + sharding->IndexDomains(shape)); + ASSERT_EQ(index_domains.size(), + hlo_sharding.tile_assignment().num_elements()); + const xla::Shape xla_tile_shape = hlo_sharding.TileShape(xla_shape); + for (int i = 0; i < index_domains.size(); ++i) { + SCOPED_TRACE(absl::StrCat("on device ", i)); + EXPECT_EQ(index_domains[i].origin().elements(), + hlo_sharding.TileOffsetForDevice(xla_shape, i)); + EXPECT_EQ(index_domains[i].shape().dims(), xla_tile_shape.dimensions()); + } } -} -TEST(ShardingParamToOpShardingEquivalentTest, FullySharded) { + private: + std::shared_ptr client_; +}; + +TEST_P(ShardingParamToOpShardingEquivalentTest, FullySharded) { ShardingParam sharding_param{/*dim_shards=*/{2, 3}, {/*permutation=*/{0, 1}, /*axis_sizes=*/{2, 3}}}; TF_ASSERT_OK_AND_ASSIGN(const xla::HloSharding hlo_sharding, @@ -142,7 +140,7 @@ TEST(ShardingParamToOpShardingEquivalentTest, FullySharded) { AssertSameTiling(sharding_param, hlo_sharding, Shape({6, 6})); } -TEST(ShardingParamToOpShardingEquivalentTest, WithPermutation) { +TEST_P(ShardingParamToOpShardingEquivalentTest, WithPermutation) { ShardingParam sharding_param{/*dim_shards=*/{2, 3}, {/*permutation=*/{1, 0}, /*axis_sizes=*/{3, 2}}}; TF_ASSERT_OK_AND_ASSIGN(const xla::HloSharding hlo_sharding, @@ -150,7 +148,7 @@ TEST(ShardingParamToOpShardingEquivalentTest, WithPermutation) { AssertSameTiling(sharding_param, hlo_sharding, Shape({6, 6})); } -TEST(ShardingParamToOpShardingEquivalentTest, WithReplication) { +TEST_P(ShardingParamToOpShardingEquivalentTest, WithReplication) { ShardingParam sharding_param{/*dim_shards=*/{2, 1}, {/*permutation=*/{0, 1}, /*axis_sizes=*/{2, 3}}}; TF_ASSERT_OK_AND_ASSIGN(const xla::HloSharding hlo_sharding, @@ -158,6 +156,10 @@ TEST(ShardingParamToOpShardingEquivalentTest, WithReplication) { AssertSameTiling(sharding_param, hlo_sharding, Shape({6, 6})); } +INSTANTIATE_TEST_SUITE_P(NumDevices, ShardingParamToOpShardingEquivalentTest, + testing::Values(test_util::ShardingTestParam{ + .num_devices = 6, .num_addressable_devices = 4})); + } // namespace } // namespace support } // namespace ifrt diff --git a/tensorflow/compiler/xla/python/ifrt/test_util.cc b/tensorflow/compiler/xla/python/ifrt/test_util.cc index ac0e3b53bf9bdc..4e73e6d834884d 100644 --- a/tensorflow/compiler/xla/python/ifrt/test_util.cc +++ b/tensorflow/compiler/xla/python/ifrt/test_util.cc @@ -19,10 +19,12 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" #include "absl/synchronization/mutex.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/python/ifrt/client.h" +#include "tensorflow/compiler/xla/python/ifrt/device.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/tsl/platform/test.h" namespace xla { namespace ifrt { @@ -80,6 +82,20 @@ void SetTestFilterIfNotUserSpecified(absl::string_view custom_filter) { #endif } +absl::StatusOr GetDevices(Client* client, + absl::Span device_indices) { + DeviceList::Devices devices; + devices.reserve(device_indices.size()); + for (int device_index : device_indices) { + if (device_index < 0 || device_index >= client->devices().size()) { + return absl::InvalidArgumentError( + absl::StrCat("Out of range device index: ", device_index)); + } + devices.push_back(client->devices()[device_index]); + } + return DeviceList(std::move(devices)); +} + } // namespace test_util } // namespace ifrt } // namespace xla diff --git a/tensorflow/compiler/xla/python/ifrt/test_util.h b/tensorflow/compiler/xla/python/ifrt/test_util.h index d25a00cef294d3..b65d6c1cf948e4 100644 --- a/tensorflow/compiler/xla/python/ifrt/test_util.h +++ b/tensorflow/compiler/xla/python/ifrt/test_util.h @@ -80,6 +80,11 @@ void AssertPerShardData( } } +// Helper function that makes `DeviceList` containing devices at given +// indexes (not ids) within `client.devices()`. +absl::StatusOr GetDevices(Client* client, + absl::Span device_indices); + } // namespace test_util } // namespace ifrt } // namespace xla diff --git a/tensorflow/compiler/xla/python/pjrt_ifrt/BUILD b/tensorflow/compiler/xla/python/pjrt_ifrt/BUILD index 060be5ab05d5c9..dec1e1a9e611c7 100644 --- a/tensorflow/compiler/xla/python/pjrt_ifrt/BUILD +++ b/tensorflow/compiler/xla/python/pjrt_ifrt/BUILD @@ -140,8 +140,8 @@ xla_cc_test( ":xla_ifrt", ":xla_sharding_serdes", "//tensorflow/compiler/xla/hlo/ir:hlo", - "//tensorflow/compiler/xla/python/ifrt:mock", "//tensorflow/compiler/xla/python/ifrt:sharding_serdes", + "//tensorflow/compiler/xla/python/ifrt:sharding_test_util", "@com_google_googletest//:gtest_main", ], ) @@ -182,6 +182,8 @@ xla_cc_test( deps = [ ":tfrt_cpu_client_test_lib", ":xla_ifrt", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/python/ifrt:sharding_test_util", "//tensorflow/compiler/xla/python/ifrt:tuple_impl_test_lib", "//tensorflow/tsl/platform:errors", "//tensorflow/tsl/platform:status_matchers", diff --git a/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_serdes_test.cc b/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_serdes_test.cc index e043fb7e575f2b..95b90c21fbba9a 100644 --- a/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_serdes_test.cc +++ b/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_serdes_test.cc @@ -15,13 +15,12 @@ limitations under the License. #include #include -#include #include #include #include "tensorflow/compiler/xla/hlo/ir/hlo_sharding.h" -#include "tensorflow/compiler/xla/python/ifrt/mock.h" #include "tensorflow/compiler/xla/python/ifrt/sharding_serdes.h" +#include "tensorflow/compiler/xla/python/ifrt/sharding_test_util.h" #include "tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding.h" namespace xla { @@ -30,49 +29,13 @@ namespace { using ::testing::ElementsAreArray; -// Test fixture for sharding serialization and deserialization. It makes a mock -// client with a number of fake devices. Client implements `devices()` and -// `LookupDevice()`, and Device implements `id()`, with an arbitrary device ids -// assigned. -class XlaShardingSerDesTest : public ::testing::TestWithParam { - public: - void SetUp() override { - const int num_devices = GetParam(); - device_map_.reserve(num_devices); - devices_.reserve(num_devices); - for (int i = 0; i < num_devices; ++i) { - auto device = std::make_unique(); - ON_CALL(*device, id).WillByDefault([i]() { return i + 10; }); - devices_.push_back(device.get()); - device_map_.insert({i + 10, std::move(device)}); - } - client_ = std::make_unique(); - ON_CALL(*client_, devices) - .WillByDefault( - [this]() -> absl::Span { return devices_; }); - ON_CALL(*client_, LookupDevice) - .WillByDefault([this](int device_id) -> StatusOr { - auto it = device_map_.find(device_id); - if (it == device_map_.end()) { - return InvalidArgument("Unexpected device id: %d", device_id); - } - return it->second.get(); - }); - } - Client* client() { return client_.get(); } - - private: - std::unique_ptr client_; - absl::flat_hash_map> device_map_; - std::vector devices_; -}; +class XlaShardingSerDesTest : public test_util::ShardingTest {}; TEST_P(XlaShardingSerDesTest, HloShardingRoundTrip) { + auto device_list = GetDevices({0, 1}); auto xla_hlo_sharding = xla::HloSharding::Tile(xla::TileAssignment({2, 1})); - auto sharding = HloSharding::Create( - DeviceList(DeviceList::Devices(client()->devices().begin(), - client()->devices().end())), - /*xla_hlo_sharding=*/xla_hlo_sharding); + auto sharding = HloSharding::Create(device_list, + /*xla_hlo_sharding=*/xla_hlo_sharding); TF_ASSERT_OK_AND_ASSIGN(auto serialized, Serialize(*sharding)); @@ -88,7 +51,9 @@ TEST_P(XlaShardingSerDesTest, HloShardingRoundTrip) { EXPECT_EQ(out_sharding->xla_hlo_sharding(), sharding->xla_hlo_sharding()); } -INSTANTIATE_TEST_SUITE_P(NumDevices, XlaShardingSerDesTest, testing::Values(2)); +INSTANTIATE_TEST_SUITE_P(NumDevices, XlaShardingSerDesTest, + testing::Values(test_util::ShardingTestParam{ + .num_devices = 2, .num_addressable_devices = 2})); } // namespace } // namespace ifrt diff --git a/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_test.cc b/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_test.cc index a9104a4d77ff94..304b22247a0b3c 100644 --- a/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_test.cc +++ b/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_test.cc @@ -20,6 +20,8 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/python/ifrt/sharding_test_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/tsl/platform/errors.h" #include "tensorflow/tsl/platform/status_matchers.h" @@ -33,17 +35,10 @@ using ::testing::HasSubstr; using ::testing::SizeIs; using ::tsl::testing::StatusIs; -DeviceList CreateDummyDevices(int count) { - DeviceList::Devices devices; - devices.reserve(count); - for (int i = 0; i < count; ++i) { - devices.push_back(reinterpret_cast(i + 1)); - } - return DeviceList(std::move(devices)); -} +class HloShardingTest : public test_util::ShardingTest {}; -TEST(HloShardingTest, IndexDomainsWithReplication) { - auto device_list = CreateDummyDevices(2); +TEST_P(HloShardingTest, IndexDomainsWithReplication) { + auto device_list = GetDevices({0, 1}); // Fully replicated. auto xla_hlo_sharding = xla::HloSharding::Replicate(); std::shared_ptr sharding = @@ -59,8 +54,8 @@ TEST(HloShardingTest, IndexDomainsWithReplication) { ElementsAreArray(TEST_HloShardingIndexDomainsSlowPath(*sharding, shape))); } -TEST(HloShardingTest, DisassembleWithReplication) { - auto device_list = CreateDummyDevices(2); +TEST_P(HloShardingTest, DisassembleWithReplication) { + auto device_list = GetDevices({0, 1}); // Fully replicated. auto xla_hlo_sharding = xla::HloSharding::Replicate(); std::shared_ptr sharding = @@ -79,8 +74,8 @@ TEST(HloShardingTest, DisassembleWithReplication) { } } -TEST(HloShardingTest, IndexDomainsWithTile) { - auto device_list = CreateDummyDevices(2); +TEST_P(HloShardingTest, IndexDomainsWithTile) { + auto device_list = GetDevices({0, 1}); // 2-way sharded along axis 0, 1-way sharded along axis 1. auto xla_hlo_sharding = xla::HloSharding::Tile( xla::TileAssignment((absl::Span){2, 1})); @@ -98,8 +93,8 @@ TEST(HloShardingTest, IndexDomainsWithTile) { ElementsAreArray(TEST_HloShardingIndexDomainsSlowPath(*sharding, shape))); } -TEST(HloShardingTest, DisassembleWithTile) { - auto device_list = CreateDummyDevices(2); +TEST_P(HloShardingTest, DisassembleWithTile) { + auto device_list = GetDevices({0, 1}); // 2-way sharded along axis 0, 1-way sharded along axis 1. auto xla_hlo_sharding = xla::HloSharding::Tile( xla::TileAssignment((absl::Span){2, 1})); @@ -119,8 +114,8 @@ TEST(HloShardingTest, DisassembleWithTile) { } } -TEST(HloShardingTest, IndexDomainsWithUnevenTile) { - auto device_list = CreateDummyDevices(2); +TEST_P(HloShardingTest, IndexDomainsWithUnevenTile) { + auto device_list = GetDevices({0, 1}); // 2-way sharded along axis 0, 1-way sharded along axis 1. auto xla_hlo_sharding = xla::HloSharding::Tile( xla::TileAssignment((absl::Span){2, 1})); @@ -138,8 +133,8 @@ TEST(HloShardingTest, IndexDomainsWithUnevenTile) { ElementsAreArray(TEST_HloShardingIndexDomainsSlowPath(*sharding, shape))); } -TEST(HloShardingTest, DisassembleWithUnevenTile) { - auto device_list = CreateDummyDevices(2); +TEST_P(HloShardingTest, DisassembleWithUnevenTile) { + auto device_list = GetDevices({0, 1}); // 2-way sharded along axis 0, 1-way sharded along axis 1. auto xla_hlo_sharding = xla::HloSharding::Tile( xla::TileAssignment((absl::Span){2, 1})); @@ -163,8 +158,8 @@ TEST(HloShardingTest, DisassembleWithUnevenTile) { } } -TEST(HloShardingTest, IndexDomainsWithPartialTile) { - auto device_list = CreateDummyDevices(6); +TEST_P(HloShardingTest, IndexDomainsWithPartialTile) { + auto device_list = GetDevices({0, 1, 2, 3, 4, 5}); // 2-way sharded along axis 0, 1-way sharded along axis 1, each shard // replicated by 3 times. auto xla_hlo_sharding = @@ -187,8 +182,8 @@ TEST(HloShardingTest, IndexDomainsWithPartialTile) { ElementsAreArray(TEST_HloShardingIndexDomainsSlowPath(*sharding, shape))); } -TEST(HloShardingTest, DisassembleWithPartialTile) { - auto device_list = CreateDummyDevices(6); +TEST_P(HloShardingTest, DisassembleWithPartialTile) { + auto device_list = GetDevices({0, 1, 2, 3, 4, 5}); // 2-way sharded along axis 0, 1-way sharded along axis 1, each shard // replicated by 3 times. auto xla_hlo_sharding = @@ -209,8 +204,8 @@ TEST(HloShardingTest, DisassembleWithPartialTile) { } } -TEST(HloShardingTest, IndexDomainsWithSubgroupReplicated) { - auto device_list = CreateDummyDevices(6); +TEST_P(HloShardingTest, IndexDomainsWithSubgroupReplicated) { + auto device_list = GetDevices({0, 1, 2, 3, 4, 5}); // 2-way sharded along axis 0, 1-way sharded along axis 1, each shard // replicated by 3 times. auto xla_hlo_sharding = xla::HloSharding::Subgroup( @@ -233,8 +228,8 @@ TEST(HloShardingTest, IndexDomainsWithSubgroupReplicated) { ElementsAreArray(TEST_HloShardingIndexDomainsSlowPath(*sharding, shape))); } -TEST(HloShardingTest, DisassembleWithSubgroupReplicated) { - auto device_list = CreateDummyDevices(6); +TEST_P(HloShardingTest, DisassembleWithSubgroupReplicated) { + auto device_list = GetDevices({0, 1, 2, 3, 4, 5}); // 2-way sharded along axis 0, 1-way sharded along axis 1, each shard // replicated by 3 times. auto xla_hlo_sharding = xla::HloSharding::Subgroup( @@ -255,8 +250,8 @@ TEST(HloShardingTest, DisassembleWithSubgroupReplicated) { } } -TEST(HloShardingTest, IndexDomainsWithSubgroupMaximalSlowPath) { - auto device_list = CreateDummyDevices(6); +TEST_P(HloShardingTest, IndexDomainsWithSubgroupMaximalSlowPath) { + auto device_list = GetDevices({0, 1, 2, 3, 4, 5}); // 2-way sharded along axis 0, 1-way sharded along axis 1, each shard // maximal-replicated by 3 times, device#0 in each replication is maximal. auto xla_hlo_sharding = xla::HloSharding::Subgroup( @@ -279,8 +274,8 @@ TEST(HloShardingTest, IndexDomainsWithSubgroupMaximalSlowPath) { ElementsAreArray(TEST_HloShardingIndexDomainsSlowPath(*sharding, shape))); } -TEST(HloShardingTest, DisassembleWithSubgroupMaximalSlowPath) { - auto device_list = CreateDummyDevices(6); +TEST_P(HloShardingTest, DisassembleWithSubgroupMaximalSlowPath) { + auto device_list = GetDevices({0, 1, 2, 3, 4, 5}); // 2-way sharded along axis 0, 1-way sharded along axis 1, each shard // maximal-replicated by 3 times, device#0 in each replication is maximal. auto xla_hlo_sharding = xla::HloSharding::Subgroup( @@ -301,8 +296,8 @@ TEST(HloShardingTest, DisassembleWithSubgroupMaximalSlowPath) { } } -TEST(HloShardingTest, DisassembleFailsWithInvalidDeviceCount) { - auto device_list = CreateDummyDevices(1); +TEST_P(HloShardingTest, DisassembleFailsWithInvalidDeviceCount) { + auto device_list = GetDevices({0}); // 2-way sharded along axis 0, 1-way sharded along axis 1. auto xla_hlo_sharding = xla::HloSharding::Tile( xla::TileAssignment((absl::Span){2, 1})); @@ -316,8 +311,8 @@ TEST(HloShardingTest, DisassembleFailsWithInvalidDeviceCount) { "device count does not match: 2 vs. 1"))); } -TEST(HloShardingTest, DisassembleFailsWithMismatchingShapeDimsSize) { - auto device_list = CreateDummyDevices(2); +TEST_P(HloShardingTest, DisassembleFailsWithMismatchingShapeDimsSize) { + auto device_list = GetDevices({0, 1}); // 2-way sharded along axis 0, 1-way sharded along axis 1. auto xla_hlo_sharding = xla::HloSharding::Tile( xla::TileAssignment((absl::Span){2, 1})); @@ -332,6 +327,10 @@ TEST(HloShardingTest, DisassembleFailsWithMismatchingShapeDimsSize) { HasSubstr("shape must have 2 dimensions, but has 1 dimensions"))); } +INSTANTIATE_TEST_SUITE_P(NumDevices, HloShardingTest, + testing::Values(test_util::ShardingTestParam{ + .num_devices = 6, .num_addressable_devices = 4})); + } // namespace } // namespace ifrt } // namespace xla From 46896eaa08a1ee41ec5d08fed9b9359e4a278e00 Mon Sep 17 00:00:00 2001 From: Hyeontaek Lim Date: Fri, 14 Jul 2023 23:31:38 -0700 Subject: [PATCH 347/376] [IFRT] Use `= True` instead of `= 1` in BUILD Update IFRT BUILD files based on https://bazel.build/build/style-guide#other-conventions. PiperOrigin-RevId: 548302672 --- tensorflow/compiler/xla/python/ifrt/BUILD | 18 +++++++++--------- tensorflow/compiler/xla/python/pjrt_ifrt/BUILD | 12 ++++++------ 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/tensorflow/compiler/xla/python/ifrt/BUILD b/tensorflow/compiler/xla/python/ifrt/BUILD index 39dbf46fef0527..9e6d600f4105cf 100644 --- a/tensorflow/compiler/xla/python/ifrt/BUILD +++ b/tensorflow/compiler/xla/python/ifrt/BUILD @@ -161,7 +161,7 @@ xla_cc_test( cc_library( name = "test_util", - testonly = 1, + testonly = True, srcs = ["test_util.cc"], hdrs = ["test_util.h"], deps = [ @@ -192,7 +192,7 @@ cc_library( cc_library( name = "no_impl_test_main", - testonly = 1, + testonly = True, srcs = ["no_impl_test_main.cc"], deps = [ "@com_google_googletest//:gtest", @@ -201,7 +201,7 @@ cc_library( cc_library( name = "array_impl_test_lib", - testonly = 1, + testonly = True, srcs = ["array_impl_test_lib.cc"], deps = [ ":ifrt", @@ -211,7 +211,7 @@ cc_library( "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", ], - alwayslink = 1, + alwayslink = True, ) xla_cc_test( @@ -225,14 +225,14 @@ xla_cc_test( cc_library( name = "client_impl_test_lib", - testonly = 1, + testonly = True, srcs = ["client_impl_test_lib.cc"], deps = [ ":ifrt", ":test_util", "//tensorflow/tsl/platform:test", ], - alwayslink = 1, + alwayslink = True, ) xla_cc_test( @@ -247,7 +247,7 @@ xla_cc_test( cc_library( name = "tuple_impl_test_lib", - testonly = 1, + testonly = True, srcs = ["tuple_impl_test_lib.cc"], deps = [ ":ifrt", @@ -258,7 +258,7 @@ cc_library( "@com_google_absl//absl/types:span", "@tf_runtime//:ref_count", ], - alwayslink = 1, + alwayslink = True, ) xla_cc_test( @@ -339,7 +339,7 @@ cc_library( "//tensorflow/tsl/platform:statusor", "@llvm-project//llvm:Support", ], - alwayslink = 1, + alwayslink = True, ) xla_cc_test( diff --git a/tensorflow/compiler/xla/python/pjrt_ifrt/BUILD b/tensorflow/compiler/xla/python/pjrt_ifrt/BUILD index dec1e1a9e611c7..d9f0338ae29a35 100644 --- a/tensorflow/compiler/xla/python/pjrt_ifrt/BUILD +++ b/tensorflow/compiler/xla/python/pjrt_ifrt/BUILD @@ -90,7 +90,7 @@ cc_library( "@stablehlo//:stablehlo_portable_api", "@stablehlo//:stablehlo_serialization", ], - alwayslink = 1, + alwayslink = True, ) xla_cc_test( @@ -130,7 +130,7 @@ cc_library( "//tensorflow/compiler/xla/python/ifrt:serdes", "//tensorflow/compiler/xla/python/ifrt:sharding_serdes", ], - alwayslink = 1, + alwayslink = True, ) xla_cc_test( @@ -149,7 +149,7 @@ xla_cc_test( # TODO(hyeontaek): Move this target out of pjrt_ifrt. cc_library( name = "xla_executable_impl_test_lib", - testonly = 1, + testonly = True, srcs = ["xla_executable_impl_test_lib.cc"], deps = [ ":xla_ifrt", @@ -160,7 +160,7 @@ cc_library( "//tensorflow/tsl/platform:statusor", "@com_google_absl//absl/strings", ], - alwayslink = 1, + alwayslink = True, ) # TODO(hyeontaek): Move this target out of pjrt_ifrt. @@ -238,14 +238,14 @@ cc_library( cc_library( name = "tfrt_cpu_client_test_lib", - testonly = 1, + testonly = True, srcs = ["tfrt_cpu_client_test_lib.cc"], deps = [ ":pjrt_ifrt", "//tensorflow/compiler/xla/pjrt:tfrt_cpu_pjrt_client", "//tensorflow/compiler/xla/python/ifrt:test_util", ], - alwayslink = 1, + alwayslink = True, ) xla_cc_test( From dddee9e2c223315eb363ac643f03c5249c97d4e8 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Sat, 15 Jul 2023 01:58:49 -0700 Subject: [PATCH 348/376] [xla:gpu] NFC: Prepare graph instances cache for adding eviction policy Improve logging and move graph instances cache to a heap allocated storage in preparation for gpu graph eviction policy implementation PiperOrigin-RevId: 548322917 --- .../xla/service/gpu/gpu_executable.cc | 13 +-- .../compiler/xla/service/gpu/runtime/BUILD | 1 - .../xla/service/gpu/runtime/executable.cc | 32 +++++-- .../xla/service/gpu/runtime/executable.h | 16 ++-- .../xla/service/gpu/runtime/graph_launch.cc | 89 +++++++++++++++---- .../xla/service/gpu/runtime/graph_launch.h | 39 ++++++-- 6 files changed, 148 insertions(+), 42 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index a475ec153b6878..10595fc9bbb5b7 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -102,8 +102,9 @@ StatusOr> GpuExecutable::Create(Params params) { if (std::holds_alternative(executable)) { auto& program = std::get(executable); - TF_ASSIGN_OR_RETURN(result->gpu_runtime_executable_, - GpuRuntimeExecutable::Create(std::move(program))); + TF_ASSIGN_OR_RETURN( + result->gpu_runtime_executable_, + GpuRuntimeExecutable::Create(result->module_name_, std::move(program))); return result; } @@ -1058,10 +1059,10 @@ StatusOr> GpuExecutable::LoadFromObjFile( executable.status().message()); // Move runtime::Executable ownership to the GpuRuntimeExecutable. - TF_ASSIGN_OR_RETURN( - auto gpu_runtime_executable, - GpuRuntimeExecutable::Create(buffer_sizes, std::move(*executable), - std::move(debug_options))); + TF_ASSIGN_OR_RETURN(auto gpu_runtime_executable, + GpuRuntimeExecutable::Create( + hlo_module->name(), buffer_sizes, + std::move(*executable), std::move(debug_options))); // Construct GpuExecutable for the loaded XLA Runtime executable. std::string name = hlo_module->name(); diff --git a/tensorflow/compiler/xla/service/gpu/runtime/BUILD b/tensorflow/compiler/xla/service/gpu/runtime/BUILD index 8224b840eec663..97101796437b60 100644 --- a/tensorflow/compiler/xla/service/gpu/runtime/BUILD +++ b/tensorflow/compiler/xla/service/gpu/runtime/BUILD @@ -349,7 +349,6 @@ cc_library( "//tensorflow/tsl/profiler/lib:scoped_annotation_stack", "//tensorflow/tsl/profiler/lib:traceme", "//tensorflow/tsl/profiler/lib:traceme_encode", - "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", diff --git a/tensorflow/compiler/xla/service/gpu/runtime/executable.cc b/tensorflow/compiler/xla/service/gpu/runtime/executable.cc index 96b7309136645b..bbe53057345cd7 100644 --- a/tensorflow/compiler/xla/service/gpu/runtime/executable.cc +++ b/tensorflow/compiler/xla/service/gpu/runtime/executable.cc @@ -138,13 +138,22 @@ void RegisterXlaGpuAttrEncoding(CustomCallAttrEncodingSet& encoding) { //===----------------------------------------------------------------------===// +// Executable can have only one "main" function and only graph capture function. +static int64_t GetNumGraphs(const runtime::Executable& executable) { + return executable.num_functions() - 1; +} + GpuRuntimeExecutable::GpuRuntimeExecutable( - std::vector buffer_sizes, + std::string module_name, std::vector buffer_sizes, std::unique_ptr jit_executable, DebugOptions debug_options, ModulesState modules_state, FfiModulesState ffi_modules_state) - : buffer_sizes_(std::move(buffer_sizes)), + : module_name_(std::move(module_name)), + buffer_sizes_(std::move(buffer_sizes)), executable_(std::move(jit_executable)), debug_options_(std::move(debug_options)), +#if GOOGLE_CUDA + graph_instances_(module_name_, GetNumGraphs(executable())), +#endif // GOOGLE_CUDA modules_state_(std::move(modules_state)), ffi_modules_state_(std::move(ffi_modules_state)) { ExportModules(dynamic_custom_calls_); // export runtime modules @@ -152,12 +161,16 @@ GpuRuntimeExecutable::GpuRuntimeExecutable( } GpuRuntimeExecutable::GpuRuntimeExecutable( - std::vector buffer_sizes, + std::string module_name, std::vector buffer_sizes, std::unique_ptr aot_executable, DebugOptions debug_options, ModulesState modules_state, FfiModulesState ffi_modules_state) - : buffer_sizes_(std::move(buffer_sizes)), + : module_name_(std::move(module_name)), + buffer_sizes_(std::move(buffer_sizes)), executable_(std::move(aot_executable)), debug_options_(std::move(debug_options)), +#if GOOGLE_CUDA + graph_instances_(module_name_, GetNumGraphs(executable())), +#endif // GOOGL_CUDA modules_state_(std::move(modules_state)), ffi_modules_state_(std::move(ffi_modules_state)) { ExportModules(dynamic_custom_calls_); // export runtime modules @@ -169,7 +182,8 @@ GpuRuntimeExecutable::GpuRuntimeExecutable( //===----------------------------------------------------------------------===// /*static*/ StatusOr> -GpuRuntimeExecutable::Create(std::unique_ptr program) { +GpuRuntimeExecutable::Create(std::string module_name, + std::unique_ptr program) { // Options for the default XLA Runtime compilation pipeline. runtime::CompilationPipelineOptions copts; @@ -223,7 +237,7 @@ GpuRuntimeExecutable::Create(std::unique_ptr program) { ffi_modules_state.status().message()); return std::unique_ptr(new GpuRuntimeExecutable( - std::move(program->buffer_sizes), + std::move(module_name), std::move(program->buffer_sizes), std::make_unique(std::move(*jit_executable)), std::move(program->debug_options), std::move(*modules_state), std::move(*ffi_modules_state))); @@ -234,7 +248,8 @@ GpuRuntimeExecutable::Create(std::unique_ptr program) { //===----------------------------------------------------------------------===// /*static*/ StatusOr> -GpuRuntimeExecutable::Create(absl::Span buffer_sizes, +GpuRuntimeExecutable::Create(std::string module_name, + absl::Span buffer_sizes, Executable executable, DebugOptions debug_options) { // Instantiate state for all registered runtime modules. @@ -250,6 +265,7 @@ GpuRuntimeExecutable::Create(absl::Span buffer_sizes, ffi_modules_state.status().message()); return std::unique_ptr(new GpuRuntimeExecutable( + std::move(module_name), std::vector(buffer_sizes.begin(), buffer_sizes.end()), std::make_unique(std::move(executable)), std::move(debug_options), std::move(*modules_state), @@ -275,7 +291,7 @@ static void InitializeCallFrame(runtime::Executable::CallFrame& call_frame, assert(ptrs.empty() && "pointers storage must be empty"); ptrs.resize_for_overwrite(num_allocations); - // Each buffer allocation pased as 1d memref to the compiled function: + // Each buffer allocation passed as 1d memref to the compiled function: // {basePtr, dataPtr, offset, [sizes, ...], [strides, ...]} size_t num_args_ptrs = 1 + num_allocations * 5; call_frame.args.resize_for_overwrite(num_args_ptrs); diff --git a/tensorflow/compiler/xla/service/gpu/runtime/executable.h b/tensorflow/compiler/xla/service/gpu/runtime/executable.h index 114c3655711ccd..0405c86db315b6 100644 --- a/tensorflow/compiler/xla/service/gpu/runtime/executable.h +++ b/tensorflow/compiler/xla/service/gpu/runtime/executable.h @@ -93,12 +93,12 @@ class GpuRuntimeExecutable { public: // Creates GpuRuntimeExecutable from the Xla Gpu Program. static StatusOr> Create( - std::unique_ptr program); + std::string module_name, std::unique_ptr program); // Creates GpuRuntimeExecutable from the AOT compiled binary. static StatusOr> Create( - absl::Span buffer_sizes, runtime::Executable executable, - DebugOptions debug_options); + std::string module_name, absl::Span buffer_sizes, + runtime::Executable executable, DebugOptions debug_options); // Executes entry function with the given buffer arguments. Status Execute(const ServiceExecutableRunOptions* run_options, @@ -115,17 +115,23 @@ class GpuRuntimeExecutable { // Returns MLIR module behind this executable if it is available. StatusOr GetMlirModule() const; + std::string_view module_name() const { return module_name_; } + private: - GpuRuntimeExecutable(std::vector buffer_sizes, + GpuRuntimeExecutable(std::string module_name, + std::vector buffer_sizes, std::unique_ptr jit_executable, DebugOptions debug_options, ModulesState modules_state, FfiModulesState ffi_modules_state); - GpuRuntimeExecutable(std::vector buffer_sizes, + GpuRuntimeExecutable(std::string module_name, + std::vector buffer_sizes, std::unique_ptr aot_executable, DebugOptions debug_options, ModulesState modules_state, FfiModulesState ffi_modules_state); + std::string module_name_; + // Depending on the state of `executable_` returns a reference to active // Xla runtime executable. runtime::Executable& executable(); diff --git a/tensorflow/compiler/xla/service/gpu/runtime/graph_launch.cc b/tensorflow/compiler/xla/service/gpu/runtime/graph_launch.cc index 4325fe67881d94..f9df7a72503163 100644 --- a/tensorflow/compiler/xla/service/gpu/runtime/graph_launch.cc +++ b/tensorflow/compiler/xla/service/gpu/runtime/graph_launch.cc @@ -73,16 +73,68 @@ static absl::StatusOr CaptureGraph( // CUDA graphs caching. //===----------------------------------------------------------------------===// -StreamExecutorGraphInstances* GraphInstances::operator()( - se::StreamExecutor* executor) { - absl::MutexLock lock(&mutex_); - return &graphs_[executor]; +static absl::Mutex* GetGraphInstancesMutex() { + static auto* mu = new absl::Mutex(); + return mu; } -CapturedFunctionExecutionCount* CapturedFunctionExecutionCounts::operator()( +// Keep track of instantiated graphs on each StreamExecutor, we use this +// information in the graph eviction policy. +using GraphInstancesState = absl::flat_hash_map; + +static GraphInstancesState& GetGraphInstancesState() { + static auto* state = new GraphInstancesState(); + return *state; +} + +static int64_t NotifyGraphInstancesCreated(se::StreamExecutor* executor, + int64_t num_graphs) { + absl::MutexLock lock(GetGraphInstancesMutex()); + return GetGraphInstancesState()[executor] += num_graphs; +} + +static int64_t NotifyGraphInstancesDestroyed(se::StreamExecutor* executor, + int64_t num_graphs) { + absl::MutexLock lock(GetGraphInstancesMutex()); + return GetGraphInstancesState()[executor] -= num_graphs; +} + +GraphInstances::GraphInstances(std::string module_name, int64_t num_graphs) + : impl_(std::make_shared()) { + impl_->module_name = std::move(module_name); + impl_->num_graphs = num_graphs; + VLOG(3) << "Construct graph instances cache for: @" << impl_->module_name + << " (num_graphs = " << impl_->num_graphs << ")"; +} + +GraphInstances::~GraphInstances() { + VLOG(3) << "Destroy graph instances cache for: @" << impl_->module_name + << " (num_graphs = " << impl_->num_graphs << ")"; + + absl::MutexLock lock(&impl_->mu); + for (auto& [executor, state] : impl_->graphs) { + VLOG(3) << "Destroy " << impl_->num_graphs << " graphs for: @" + << impl_->module_name << " at executor: " << executor + << ". Total remaining graphs at given executor: " + << NotifyGraphInstancesDestroyed(executor, impl_->num_graphs); + } +} + +StreamExecutorGraphInstances* GraphInstances::operator()( se::StreamExecutor* executor) { - absl::MutexLock lock(&mutex_); - return &counts_[executor]; + absl::MutexLock lock(&impl_->mu); + + auto it = impl_->graphs.try_emplace(executor); + if (it.second && impl_->num_graphs > 0) { + VLOG(3) << "Instantiate " << impl_->num_graphs << " graphs for: @" + << impl_->module_name << " at executor: " << executor + << ". Total graphs at given executor: " + << NotifyGraphInstancesCreated(executor, impl_->num_graphs); + } + + State& state = it.first->second; + state.last_use_micros = tsl::Env::Default()->NowMicros(); + return &state.instances; } bool GraphInstances::InstantiatedAllGraphs( @@ -90,8 +142,8 @@ bool GraphInstances::InstantiatedAllGraphs( const Executable& executable) { if (executable.num_functions() == 1) return true; - absl::MutexLock lock(&mutex_); - return instantiated_.contains(run_options->stream()->parent()); + absl::MutexLock lock(&impl_->mu); + return impl_->graphs[run_options->stream()->parent()].instantiated; } Status GraphInstances::InstantiateAllGraphs( @@ -101,19 +153,18 @@ Status GraphInstances::InstantiateAllGraphs( // We have only "main" function in the executable. if (executable.num_functions() == 1) return OkStatus(); - absl::MutexLock lock(&mutex_); + absl::MutexLock lock(&impl_->mu); se::StreamExecutor* executor = run_options->stream()->parent(); - // All Gpu graphs are already instantiated for a given executor. - if (instantiated_.contains(executor)) return OkStatus(); + State& state = impl_->graphs[executor]; - VLOG(3) << "Instantate all Gpu graphs in executable " << executable.name(); + // All Gpu graphs are already instantiated for a given executor. + if (state.instantiated) return OkStatus(); TraceMe trace("cuda.graph.instantiate_all"); // Initialize graph instances snapshot for a given executor. - StreamExecutorGraphInstances::Snapshot instances = - graphs_[executor].snapshot(); + StreamExecutorGraphInstances::Snapshot instances = state.instances.snapshot(); // Instantiate all Gpu graphs by calling graph capture functions with fake // arguments. Once we'll execute them first time for real, they'll be updated @@ -172,10 +223,16 @@ Status GraphInstances::InstantiateAllGraphs( #endif // GOOGLE_CUDA } - instantiated_.insert(executor); + state.instantiated = true; return OkStatus(); } +CapturedFunctionExecutionCount* CapturedFunctionExecutionCounts::operator()( + se::StreamExecutor* executor) { + absl::MutexLock lock(&mutex_); + return &counts_[executor]; +} + //===----------------------------------------------------------------------===// // Helper structure to hash the remaining arguments' memref pointers. //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/xla/service/gpu/runtime/graph_launch.h b/tensorflow/compiler/xla/service/gpu/runtime/graph_launch.h index 5c3bc4b4867450..6a183409f17855 100644 --- a/tensorflow/compiler/xla/service/gpu/runtime/graph_launch.h +++ b/tensorflow/compiler/xla/service/gpu/runtime/graph_launch.h @@ -18,9 +18,9 @@ limitations under the License. #include #include +#include #include -#include "absl/container/flat_hash_set.h" #include "absl/container/node_hash_map.h" #include "tensorflow/compiler/xla/runtime/custom_call_registry.h" #include "tensorflow/compiler/xla/runtime/executable.h" @@ -80,8 +80,16 @@ class StreamExecutorGraphInstances #endif // #if GOOGLE_CUDA // Xla executable keeps a mapping from stream executors to graph instances. +// +// Graph instances allocate on-device memory, so we periodically destroy +// them to free up some space on device. JAX for example keeps all XLA +// executables alive, and destroys them when the process shuts down, so we can +// end up with thousands of unused (or rarely used) graphs in device memory. class GraphInstances { public: + GraphInstances(std::string module_name, int64_t num_graphs); + ~GraphInstances(); + StreamExecutorGraphInstances* operator()(se::StreamExecutor* executor); // Instantiates all Gpu graphs defined by the given executable using user @@ -98,11 +106,30 @@ class GraphInstances { const runtime::Executable& executable); private: - mutable absl::Mutex mutex_; - absl::node_hash_map graphs_ - ABSL_GUARDED_BY(mutex_); - absl::flat_hash_set instantiated_ - ABSL_GUARDED_BY(mutex_); + struct State { + // A flag signalling if `InstantiateAllGraphs` was already called and we + // have all Gpu graph instantiated ahead of time. + bool instantiated = false; + + // Last time graph instances were used by a particular stream executor. + uint64_t last_use_micros = 0; + + StreamExecutorGraphInstances instances; + }; + + struct Impl { + // XLA module name that owns graph instances. We use it only to produce logs + // that can be attributed back to XLA executables. + std::string module_name; + + // Number of graphs in the parent module. + int64_t num_graphs; + + mutable absl::Mutex mu; + absl::node_hash_map graphs ABSL_GUARDED_BY(mu); + }; + + std::shared_ptr impl_; }; // Xla executable keeps a mapping from stream executors to execution counts. From 00fa3ece8763fa4fbe6093ece2250c474ff639dc Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 15 Jul 2023 02:02:08 -0700 Subject: [PATCH 349/376] compat: Update forward compatibility horizon to 2023-07-15 PiperOrigin-RevId: 548323344 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 764155e5591961..70b6e1971d498e 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 7, 14) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 7, 15) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From 69533b1b2603eff7d3c45213b9995a96546636a2 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 15 Jul 2023 02:02:08 -0700 Subject: [PATCH 350/376] Update GraphDef version to 1558. PiperOrigin-RevId: 548323348 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 2a64ed66331723..d319fe9e9a52cc 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1557 // Updated: 2023/7/14 +#define TF_GRAPH_DEF_VERSION 1558 // Updated: 2023/7/15 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From 03a0795ab9535102a97cbf30e3c131b84facf541 Mon Sep 17 00:00:00 2001 From: Penporn Koanantakool Date: Sat, 15 Jul 2023 14:33:11 -0700 Subject: [PATCH 351/376] Rollback of PR #61123 Rollback of PR #61123: Update to ACL 23.05.1, add ACL reorders The original PR author asked to revert because it had caused some unit test build issues downstream. https://github.com/tensorflow/tensorflow/pull/61123#issuecomment-1636013571 PiperOrigin-RevId: 548393592 --- tensorflow/workspace2.bzl | 16 +- third_party/compute_library/BUILD | 172 + .../acl_fixed_format_kernels_striding.patch | 70 + .../compute_library/acl_openmp_fix.patch | 46 + .../compute_library/compute_library.patch | 8 + .../onednn_acl_depthwise_convolution.patch | 312 +- .../onednn_acl_fixed_format_kernels.patch | 1370 ++---- .../mkl_dnn/onednn_acl_remove_winograd.patch | 326 -- third_party/mkl_dnn/onednn_acl_reorder.patch | 352 -- .../mkl_dnn/onednn_acl_reorder_padded.patch | 858 ---- .../mkl_dnn/onednn_acl_reorder_update.patch | 4193 ----------------- .../onednn_acl_threadpool_scheduler.patch | 17 - 12 files changed, 835 insertions(+), 6905 deletions(-) create mode 100644 third_party/compute_library/acl_fixed_format_kernels_striding.patch create mode 100644 third_party/compute_library/acl_openmp_fix.patch create mode 100644 third_party/compute_library/compute_library.patch delete mode 100644 third_party/mkl_dnn/onednn_acl_remove_winograd.patch delete mode 100644 third_party/mkl_dnn/onednn_acl_reorder.patch delete mode 100644 third_party/mkl_dnn/onednn_acl_reorder_padded.patch delete mode 100644 third_party/mkl_dnn/onednn_acl_reorder_update.patch diff --git a/tensorflow/workspace2.bzl b/tensorflow/workspace2.bzl index 30c0daa23dc4b3..c4e64dbfa66d25 100644 --- a/tensorflow/workspace2.bzl +++ b/tensorflow/workspace2.bzl @@ -204,13 +204,9 @@ def _tf_repositories(): build_file = "//third_party/mkl_dnn:mkldnn_acl.BUILD", patch_file = [ "//third_party/mkl_dnn:onednn_acl_threadcap.patch", - "//third_party/mkl_dnn:onednn_acl_remove_winograd.patch", "//third_party/mkl_dnn:onednn_acl_fixed_format_kernels.patch", "//third_party/mkl_dnn:onednn_acl_depthwise_convolution.patch", "//third_party/mkl_dnn:onednn_acl_threadpool_scheduler.patch", - "//third_party/mkl_dnn:onednn_acl_reorder_padded.patch", - "//third_party/mkl_dnn:onednn_acl_reorder_update.patch", - "//third_party/mkl_dnn:onednn_acl_reorder.patch", ], sha256 = "a50993aa6265b799b040fe745e0010502f9f7103cc53a9525d59646aef006633", strip_prefix = "oneDNN-2.7.3", @@ -219,9 +215,15 @@ def _tf_repositories(): tf_http_archive( name = "compute_library", - sha256 = "c4ca329a78da380163b2d86e91ba728349b6f0ee97d66e260a694ef37f0b0d93", - strip_prefix = "ComputeLibrary-23.05.1", - urls = tf_mirror_urls("https://github.com/ARM-software/ComputeLibrary/archive/v23.05.1.tar.gz"), + sha256 = "e20a060d3c4f803889d96c2f0b865004ba3ef4e228299a44339ea1c1ba827c85", + strip_prefix = "ComputeLibrary-22.11", + build_file = "//third_party/compute_library:BUILD", + patch_file = [ + "//third_party/compute_library:compute_library.patch", + "//third_party/compute_library:acl_fixed_format_kernels_striding.patch", + "//third_party/compute_library:acl_openmp_fix.patch", + ], + urls = tf_mirror_urls("https://github.com/ARM-software/ComputeLibrary/archive/v22.11.tar.gz"), ) tf_http_archive( diff --git a/third_party/compute_library/BUILD b/third_party/compute_library/BUILD index 4fc694c50a43cf..14bde5ac345c80 100644 --- a/third_party/compute_library/BUILD +++ b/third_party/compute_library/BUILD @@ -2,6 +2,178 @@ load("@bazel_skylib//:bzl_library.bzl", "bzl_library") exports_files(["LICENSE"]) +cc_library( + name = "include", + hdrs = glob([ + "include/**/*.h", + "include/**/*.hpp", + ]), + includes = ["include"], + strip_include_prefix = "include", +) + +_COMPUTE_LIBRARY_DEFINES = [ + "ARM_COMPUTE_OPENMP_SCHEDULER", + "ARM_COMPUTE_CPU_ENABLED", + "ENABLE_NEON", + "ARM_COMPUTE_ENABLE_NEON", + "ENABLE_SVE", + "ARM_COMPUTE_ENABLE_SVE", + "ARM_COMPUTE_ENABLE_BF16", + "ARM_COMPUTE_ENABLE_I8MM", + "ARM_COMPUTE_ENABLE_SVEF32MM", + "ENABLE_FP32_KERNELS", + "ENABLE_QASYMM8_KERNELS", + "ENABLE_QASYMM8_SIGNED_KERNELS", + "ENABLE_QSYMM16_KERNELS", + "ENABLE_INTEGER_KERNELS", + "ENABLE_NHWC_KERNELS", + "ENABLE_NCHW_KERNELS", + "ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS", +] + +cc_library( + name = "arm_compute_sve2", + srcs = glob( + [ + "src/cpu/kernels/**/sve2/*.cpp", + "**/*.h", + "**/*.hpp", + "**/*.inl", + ], + ), + copts = [ + "-march=armv8.6-a+sve2", + "-fopenmp", + ], + defines = _COMPUTE_LIBRARY_DEFINES + ["ARM_COMPUTE_ENABLE_SVE2"], + includes = [ + "src/core/NEON/kernels/arm_conv", + "src/core/NEON/kernels/arm_gemm", + "src/core/NEON/kernels/assembly", + "src/core/cpu/kernels/assembly", + "src/cpu/kernels/assembly", + ], + linkopts = ["-fopenmp"], + deps = ["include"], +) + +cc_library( + name = "arm_compute_sve", + srcs = glob( + [ + "src/core/NEON/kernels/arm_gemm/kernels/sve_*/*.cpp", + "src/core/NEON/kernels/arm_conv/**/kernels/sve_*/*.cpp", + "src/core/NEON/kernels/arm_conv/depthwise/interleaves/sve_*.cpp", + "src/core/NEON/kernels/batchnormalization/impl/SVE/*.cpp", + "src/core/NEON/kernels/convolution/winograd/input_transforms/sve_fp32_6x6.cpp", + "src/cpu/kernels/**/sve/*.cpp", + "**/*.h", + "**/*.hpp", + "**/*.inl", + ], + ) + [ + "src/core/NEON/kernels/arm_gemm/mergeresults-sve.cpp", + "src/core/NEON/kernels/arm_gemm/transform-sve.cpp", + ], + copts = [ + "-march=armv8.2-a+sve", + "-fopenmp", + ], + defines = _COMPUTE_LIBRARY_DEFINES, + includes = [ + "src/core/NEON/kernels/arm_conv", + "src/core/NEON/kernels/arm_gemm", + "src/core/NEON/kernels/assembly", + "src/core/cpu/kernels/assembly", + "src/cpu/kernels/assembly", + ], + linkopts = ["-fopenmp"], + deps = ["include"], +) + +cc_library( + name = "arm_compute", + srcs = glob( + [ + "src/common/**/*.cpp", + "src/core/*.cpp", + "src/core/CPP/kernels/*.cpp", + "src/core/helpers/*.cpp", + "src/core/utils/**/*.cpp", + "src/runtime/**/*.cpp", + "src/c/*.cpp", + "src/core/NEON/kernels/*.cpp", + "src/core/NEON/kernels/convolution/**/*.cpp", + "src/core/NEON/kernels/arm_gemm/kernels/a64_*/*.cpp", + "src/core/NEON/kernels/arm_conv/pooling/*.cpp", + "src/core/NEON/kernels/arm_conv/**/kernels/a64_*/*.cpp", + "src/core/NEON/kernels/arm_conv/depthwise/*.cpp", + "src/core/NEON/kernels/arm_conv/depthwise/interleaves/a64_*.cpp", + "src/core/NEON/kernels/arm_conv/depthwise/interleaves/generic*.cpp", + "src/core/NEON/kernels/batchnormalization/impl/NEON/*.cpp", + "src/cpu/*.cpp", + "src/cpu/kernels/*.cpp", + "src/cpu/kernels/fuse_batch_normalization/**/*.cpp", + "src/cpu/kernels/*/generic/*.cpp", + "src/cpu/operators/**/*.cpp", + "src/cpu/utils/*.cpp", + "src/cpu/kernels/internal/*.cpp", + "src/cpu/kernels/**/neon/*.cpp", + "src/cpu/kernels/**/nchw/*.cpp", + "src/core/NEON/kernels/arm_gemm/*.cpp", + "**/*.h", + "**/*.hpp", + "**/*.inl", + ], + exclude = [ + "src/core/utils/logging/**", + "src/core/TracePoint.cpp", + "src/core/NEON/kernels/arm_gemm/mergeresults-sve.cpp", + "src/core/NEON/kernels/arm_gemm/transform-sve.cpp", + "src/core/NEON/kernels/convolution/winograd/input_transforms/sve_fp32_6x6.cpp", + "src/runtime/CL/**", + "src/gpu/**", + ], + ) + [ + "src/c/operators/AclActivation.cpp", + "src/core/CPP/CPPTypes.cpp", + "src/core/NEON/kernels/arm_conv/addressing.cpp", + "src/core/NEON/kernels/arm_conv/depthwise/interleaves/8b_mla.cpp", + "src/core/NEON/kernels/arm_conv/pooling/kernels/cpp_nhwc_1x1_stride_any_depthfirst/generic.cpp", + ], + hdrs = glob([ + "src/core/NEON/kernels/**/*.h", + "src/core/NEON/kernels/**/*.hpp", + "arm_compute/runtime/**/*.h", + "arm_compute/runtime/*.h", + "arm_compute/core/**/*.h", + "**/*.inl", + ]) + [ + "arm_compute_version.embed", + ], + copts = [ + "-march=armv8-a", + "-fopenmp", + ], + defines = _COMPUTE_LIBRARY_DEFINES, + includes = [ + "arm_compute/runtime", + "src/core/NEON/kernels/assembly", + "src/core/NEON/kernels/convolution/common", + "src/core/NEON/kernels/convolution/winograd", + "src/core/cpu/kernels/assembly", + "src/cpu/kernels/assembly", + ], + linkopts = ["-fopenmp"], + visibility = ["//visibility:public"], + deps = [ + "arm_compute_sve", + "arm_compute_sve2", + "include", + ], +) + config_setting( name = "build_with_acl", define_values = { diff --git a/third_party/compute_library/acl_fixed_format_kernels_striding.patch b/third_party/compute_library/acl_fixed_format_kernels_striding.patch new file mode 100644 index 00000000000000..8e501a1d6d9c79 --- /dev/null +++ b/third_party/compute_library/acl_fixed_format_kernels_striding.patch @@ -0,0 +1,70 @@ + ******************************************************************************* + Copyright 2022 Arm Limited and affiliates. + SPDX-License-Identifier: Apache-2.0 + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ******************************************************************************* + +diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp +index 77da83070..985f96761 100644 +--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp ++++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp +@@ -495,48 +495,6 @@ void Fallback::run(ITensorPack &tensors) + { + ldb = b->info()->strides_in_bytes().y() / sizeof(TypeInput); + multi_stride_b = b->info()->strides_in_bytes().z() / sizeof(TypeInput); +- const arm_compute::WeightFormat wf = assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format); +- if(is_fixed_format(wf)) +- { +- // The 4D tensor of dimension O'HWI' created for the +- // OHWIoi format is in reality seen +- // as a 2D tensor at arm_gemm level, where the rows are +- // O'/ and the columns are * +- // H * W * I'. +- ITensorInfo *tensor_info = b->info(); +- const DataLayout data_layout = tensor_info->data_layout(); +- const TensorShape tensor_shape = tensor_info->tensor_shape(); +- const int tensor_height = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT)]; +- const int tensor_width = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH)]; +- int tensor_channels = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL)]; +- const int interleave_by = arm_compute::interleave_by(wf); +- const int blocked_by = arm_compute::block_by(wf); +- // We need to find a new stride that is distance from the data for one +- // set of output channels to the next +- if(ldb == tensor_channels && multi_stride_b == tensor_channels * tensor_width) +- { +- // In this case dimensions that are packed are height, width and channel +- // so we need to stride it by interleave_by +- if(tensor_channels % blocked_by != 0) +- { +- // We need to pad +- tensor_channels = arm_gemm::iceildiv(tensor_channels, blocked_by) * blocked_by; +- } +- ldb = interleave_by * tensor_height * tensor_width * tensor_channels; +- } +- else if(multi_stride_b == 0 || (ldb == tensor_width && multi_stride_b == tensor_height * tensor_width)) +- { +- // In this case dimension that is packed is only height +- // so we need to stride only height by interleave_by +- ldb = interleave_by * tensor_height; +- } +- else +- { +- // If dimensions are not packed as above error is thrown +- // as at the moment other forms of packing are not supported +- ARM_COMPUTE_ERROR("Unsupported packing for fixed format kernel"); +- } +- } + in1_ptr = reinterpret_cast(b->buffer() + b->info()->offset_first_element_in_bytes()); + } + diff --git a/third_party/compute_library/acl_openmp_fix.patch b/third_party/compute_library/acl_openmp_fix.patch new file mode 100644 index 00000000000000..512148c8eca114 --- /dev/null +++ b/third_party/compute_library/acl_openmp_fix.patch @@ -0,0 +1,46 @@ + ******************************************************************************* + Copyright 2022 Arm Limited and affiliates. + SPDX-License-Identifier: Apache-2.0 + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ******************************************************************************* + +diff --git a/src/runtime/OMP/OMPScheduler.cpp b/src/runtime/OMP/OMPScheduler.cpp +index aad24b4f0..78d1523af 100644 +--- a/src/runtime/OMP/OMPScheduler.cpp ++++ b/src/runtime/OMP/OMPScheduler.cpp +@@ -90,18 +116,21 @@ void OMPScheduler::schedule_op(ICPPKernel *kernel, const Hints &hints, const Win + void OMPScheduler::run_workloads(std::vector &workloads) + { + const unsigned int amount_of_work = static_cast(workloads.size()); +- if(amount_of_work < 1 || _num_threads == 1) ++ const unsigned int num_threads_to_use = std::min(_num_threads, amount_of_work ); ++ ++ if(amount_of_work < 1 || num_threads_to_use == 1) + { + return; + } + + ThreadInfo info; + info.cpu_info = &cpu_info(); +- info.num_threads = _num_threads; +- #pragma omp parallel for firstprivate(info) num_threads(_num_threads) default(shared) proc_bind(close) schedule(static, 1) ++ info.num_threads = num_threads_to_use; ++ #pragma omp parallel for firstprivate(info) num_threads(num_threads_to_use) default(shared) proc_bind(close) schedule(static, 1) + for(unsigned int wid = 0; wid < amount_of_work; ++wid) + { + const int tid = omp_get_thread_num(); ++ + info.thread_id = tid; + workloads[wid](info); + } diff --git a/third_party/compute_library/compute_library.patch b/third_party/compute_library/compute_library.patch new file mode 100644 index 00000000000000..2b9619dd03503f --- /dev/null +++ b/third_party/compute_library/compute_library.patch @@ -0,0 +1,8 @@ +diff --git a/arm_compute_version.embed b/arm_compute_version.embed +new file mode 100644 +index 000000000..c986ad52a +--- /dev/null ++++ b/arm_compute_version.embed +@@ -0,0 +1,1 @@ ++"arm_compute_version=v22.11 Build options: {} Git hash=b'1b3192e8a23513031163dc14d248f47671986121'" +\ No newline at end of file diff --git a/third_party/mkl_dnn/onednn_acl_depthwise_convolution.patch b/third_party/mkl_dnn/onednn_acl_depthwise_convolution.patch index 950077665fb4b7..95f0374ec4ddd3 100644 --- a/third_party/mkl_dnn/onednn_acl_depthwise_convolution.patch +++ b/third_party/mkl_dnn/onednn_acl_depthwise_convolution.patch @@ -1,5 +1,5 @@ ******************************************************************************* - Copyright 2023 Arm Limited and affiliates. + Copyright 2022 Arm Limited and affiliates. SPDX-License-Identifier: Apache-2.0 Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,93 +14,87 @@ See the License for the specific language governing permissions and limitations under the License. ******************************************************************************* + diff --git a/src/cpu/aarch64/acl_convolution_utils.cpp b/src/cpu/aarch64/acl_convolution_utils.cpp -index 6b57374643..85e45ace9d 100644 +index fc93d2aa9..6ebac0d17 100644 --- a/src/cpu/aarch64/acl_convolution_utils.cpp +++ b/src/cpu/aarch64/acl_convolution_utils.cpp -@@ -48,11 +48,14 @@ status_t acl_init_conf(acl_conv_conf_t &acp, memory_desc_t &src_md, - if (!is_fwd) return status::unimplemented; - +@@ -54,10 +54,12 @@ status_t acl_init_conf(acl_conv_conf_t &acp, memory_desc_t &src_md, const int ndims = src_d.ndims(); -+ const bool is_depthwise = wei_d.ndims() == 5 && wei_d.dims()[1] == 1 -+ && wei_d.dims()[2] == 1; - -- ACL_CHECK_SUPPORT(ndims != 4, " only supports 2 spatial dimensions"); -+ ACL_CHECK_SUPPORT( -+ ndims != 4 && !is_depthwise, " only supports 2 spatial dimensions"); - - const int with_groups = wei_d.ndims() == src_d.ndims() + 1; -- ACL_CHECK_SUPPORT(with_groups, " does not support groups"); -+ ACL_CHECK_SUPPORT(with_groups && !is_depthwise, " does not support groups"); - - ACL_CHECK_SUPPORT(src_d.data_type() != data_type::f32 - || wei_d.data_type() != data_type::f32 -@@ -108,7 +111,8 @@ status_t acl_init_conf(acl_conv_conf_t &acp, memory_desc_t &src_md, + const bool is_1d = ndims == 3; + const bool is_3d = ndims == 5; ++ const bool is_depthwise = wei_d.ndims() == 5 && wei_d.dims()[1] == 1 && wei_d.dims()[2] == 1; ++ + bool is_nspc; - acp.with_bias = cd.bias_desc.format_kind != format_kind::undef; + // Compute Library unsupported shape scenarios +- if (one_of(true, is_3d, is_1d, with_groups)) { ++ if (one_of(true, is_3d, is_1d, (with_groups && !is_depthwise))) { + return status::unimplemented; + } -- if (wei_d.format_kind() != format_kind::any) return status::unimplemented; -+ if (wei_d.format_kind() != format_kind::any && !is_depthwise) -+ return status::unimplemented; +@@ -135,11 +137,11 @@ status_t acl_init_conf(acl_conv_conf_t &acp, memory_desc_t &src_md, + is_nspc = utils::one_of(src_tag, nhwc); - auto src_tag = memory_desc_matches_one_of_tag( - src_md, format_tag::nhwc, format_tag::nchw); -@@ -138,8 +142,12 @@ status_t acl_init_conf(acl_conv_conf_t &acp, memory_desc_t &src_md, - || src_tag != dst_tag) - return status::unimplemented; + memory_desc_t want_wei_md = weights_md; +- auto wei_tag = is_nspc ? ohwi : oihw; ++ auto wei_tag = is_depthwise ? hwigo : (is_nspc ? ohwi : oihw); + CHECK(memory_desc_init_by_tag(want_wei_md, wei_tag)); -- // Set weights to initially be the same as src -- CHECK(memory_desc_init_by_tag(weights_md, src_tag)); -+ if (is_depthwise) { -+ CHECK(memory_desc_init_by_tag(weights_md, format_tag::hwigo)); -+ } else { -+ // Set weights to initially be the same as src -+ CHECK(memory_desc_init_by_tag(weights_md, src_tag)); -+ } + // Compute Library does not support mismatching layouts +- if ((src_tag != wei_tag) || (src_tag != dst_tag)) ++ if (!is_depthwise && ((src_tag != wei_tag) || (src_tag != dst_tag))) + return status::unimplemented; - // Bias is just 1D, set to be the obvious format - if (acp.with_bias && bias_md.format_kind == format_kind::any) -@@ -166,6 +174,11 @@ status_t acl_init_conf(acl_conv_conf_t &acp, memory_desc_t &src_md, - 1, - acl_data_type, + if (weights_md.format_kind == format_kind::any) { +@@ -187,6 +189,12 @@ status_t acl_init_conf(acl_conv_conf_t &acp, memory_desc_t &src_md, + acl_wei_data_t, acl_layout); + + if(is_depthwise) { -+ // We need to set that values are not constant so that we -+ // we can update them in-place in ACL -+ acp.wei_tensor_info.set_are_values_constant(false); ++ // We need to set that values are not constant so that we ++ // we can update them in-place in ACL ++ acp.wei_info.set_are_values_constant(false); + } ++ + acp.dst_info = arm_compute::TensorInfo( + is_nspc ? arm_compute::TensorShape(oc, ow, oh, mb) : + arm_compute::TensorShape(ow, oh, oc, mb), +@@ -212,6 +220,12 @@ status_t acl_init_conf(acl_conv_conf_t &acp, memory_desc_t &src_md, + arm_compute::QuantizationInfo(1.0f / scales[0], 0)); + } - acp.dst_tensor_info = arm_compute::TensorInfo( - is_nhwc ? arm_compute::TensorShape(oc, ow, oh, mb) : -@@ -185,6 +198,11 @@ status_t acl_init_conf(acl_conv_conf_t &acp, memory_desc_t &src_md, - // Are we allowed to cast down to bf16 or not? - acp.fast_math - = one_of(attr.fpmath_mode_, fpmath_mode::bf16, fpmath_mode::any); -+ if (is_depthwise) { ++ if(is_depthwise) { + // There is no support for fixed format kernels for depthwise convolution + // in ACL so we are going to use weight format that we set up earlier + return status::success; + } - - // WeightFormat::ANY tells ACL we can handle any format ++ acp.weights_info = arm_compute::WeightsInfo( -@@ -252,6 +270,7 @@ status_t init_conf_gemm(acl_conv_conf_t &acp, memory_desc_t &src_md, - memory_desc_t &weights_md, memory_desc_t &dst_md, - memory_desc_t &bias_md, const convolution_desc_t &cd, + false, + kw, +@@ -302,6 +316,10 @@ status_t init_conf_gemm(acl_conv_conf_t &acp, memory_desc_t &src_md, const primitive_attr_t &attr) { -+ if (weights_md.ndims != 4) return status::unimplemented; + acp.is_indirect = false; ++ if(weights_md.ndims != 4) { ++ return status::unimplemented; ++ } ++ // General Compute Library checks, memory tags are also set there CHECK(acl_init_conf(acp, src_md, weights_md, dst_md, bias_md, cd, attr)); -@@ -277,6 +296,7 @@ status_t init_conf_indirect_gemm(acl_conv_conf_t &acp, memory_desc_t &src_md, - memory_desc_t &weights_md, memory_desc_t &dst_md, - memory_desc_t &bias_md, const convolution_desc_t &cd, - const primitive_attr_t &attr) { -+ if (weights_md.ndims != 4) return status::unimplemented; - // Indirect is slower for small convolution kernels - if (weights_md.dims[2] == 1 && weights_md.dims[3] == 1) -@@ -314,6 +334,22 @@ status_t init_conf_indirect_gemm(acl_conv_conf_t &acp, memory_desc_t &src_md, +@@ -330,7 +348,8 @@ status_t init_conf_indirect_gemm(acl_conv_conf_t &acp, memory_desc_t &src_md, + auto math_mode = get_fpmath_mode(); + // Indirect convolution results in slowdown for low thread count or 1x1 + // kernels, so fall back to GEMM-based convolution in these cases +- if (one_of(true, weights_md.dims[2] == 1, // kh ++ if (one_of(true, weights_md.ndims != 4, ++ weights_md.dims[2] == 1, // kh + weights_md.dims[3] == 1, // kw + (!math_mode && dnnl_get_max_threads() < 28))) { + return status::unimplemented; +@@ -355,6 +374,27 @@ status_t init_conf_indirect_gemm(acl_conv_conf_t &acp, memory_desc_t &src_md, return status::success; } @@ -108,26 +102,41 @@ index 6b57374643..85e45ace9d 100644 + memory_desc_t &weights_md, memory_desc_t &dst_md, + memory_desc_t &bias_md, const convolution_desc_t &cd, + const primitive_attr_t &attr) { -+ if (weights_md.ndims != 5) return status::unimplemented; ++ acp.is_indirect = false; ++ // We need to make sure that number of dimensions for weights is either 5 or 3 ++ if(weights_md.ndims != 5) ++ return status::unimplemented; + + CHECK(acl_init_conf(acp, src_md, weights_md, dst_md, bias_md, cd, attr)); + + ACL_CHECK_VALID(arm_compute::NEDepthwiseConvolutionLayer::validate( -+ &acp.src_tensor_info, &acp.wei_tensor_info, -+ acp.with_bias ? &acp.bia_tensor_info : nullptr, -+ &acp.dst_tensor_info, acp.padstride_info)); ++ &acp.src_info, ++ &acp.wei_info, ++ acp.with_bias ? &acp.bia_info : nullptr, ++ &acp.dst_info, ++ acp.padstride_info)); + + return status::success; +} + - } // namespace acl_convolution_utils - - } // namespace aarch64 + status_t init_conf_wino(acl_conv_conf_t &acp, memory_desc_t &src_md, + memory_desc_t &weights_md, memory_desc_t &dst_md, + memory_desc_t &bias_md, const convolution_desc_t &cd, +@@ -364,7 +404,8 @@ status_t init_conf_wino(acl_conv_conf_t &acp, memory_desc_t &src_md, + // Under these conditions, fallback to faster GEMM-based convolution + // unless the user explicitly specifies Winograd algorithm + // clang-format off +- if (one_of(true, src_md.dims[2] > 112, // ih ++ if (one_of(true, weights_md.ndims != 4, ++ src_md.dims[2] > 112, // ih + src_md.dims[3] > 112, // iw + src_md.dims[1] < 64, // ic + dst_md.dims[1] < 64, // oc diff --git a/src/cpu/aarch64/acl_convolution_utils.hpp b/src/cpu/aarch64/acl_convolution_utils.hpp -index e3d40a5e75..1ded5826c4 100644 +index 44dc8eecb..7eae5cbb1 100644 --- a/src/cpu/aarch64/acl_convolution_utils.hpp +++ b/src/cpu/aarch64/acl_convolution_utils.hpp -@@ -66,6 +66,11 @@ status_t init_conf_indirect_gemm(acl_conv_conf_t &acp, memory_desc_t &src_md, +@@ -67,6 +67,11 @@ status_t init_conf_indirect_gemm(acl_conv_conf_t &acp, memory_desc_t &src_md, memory_desc_t &bias_md, const convolution_desc_t &cd, const primitive_attr_t &attr); @@ -136,17 +145,37 @@ index e3d40a5e75..1ded5826c4 100644 + memory_desc_t &bias_md, const convolution_desc_t &cd, + const primitive_attr_t &attr); + - } // namespace acl_convolution_utils - - template > &impl_list_map() + CPU_INSTANCE_AARCH64(jit_sve_512_dw_convolution_fwd_t) + CPU_INSTANCE_AARCH64(jit_sve_512_1x1_convolution_fwd_f32_t) + CPU_INSTANCE_AARCH64(jit_sve_512_convolution_fwd_t) ++ CPU_INSTANCE_AARCH64_ACL(acl_depthwise_convolution_fwd_t) + CPU_INSTANCE_AARCH64_ACL(acl_indirect_gemm_convolution_fwd_t) + CPU_INSTANCE_AARCH64_ACL(acl_gemm_convolution_fwd_t) + CPU_INSTANCE(gemm_convolution_fwd_t) diff --git a/src/cpu/aarch64/acl_depthwise_convolution.cpp b/src/cpu/aarch64/acl_depthwise_convolution.cpp new file mode 100644 -index 0000000000..70ae6bceea +index 000000000..1beb8b8af --- /dev/null +++ b/src/cpu/aarch64/acl_depthwise_convolution.cpp -@@ -0,0 +1,42 @@ +@@ -0,0 +1,41 @@ +/******************************************************************************* -+* Copyright 2023 Arm Ltd. and affiliates ++* Copyright 2022 Arm Ltd. and affiliates +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. @@ -169,32 +198,31 @@ index 0000000000..70ae6bceea +namespace aarch64 { + +status_t acl_depthwise_convolution_fwd_t::execute_forward( -+ const exec_ctx_t &ctx) const { -+ std::lock_guard _lock {this->mtx}; -+ -+ auto *acl_resource -+ = ctx.get_resource_mapper() -+ ->get(this); -+ acl_obj_t &acl_depthwise_obj -+ = acl_resource->get_acl_obj(); -+ -+ return execute_forward_conv_acl< -+ acl_obj_t, pd_t, data_t>( -+ ctx, acl_depthwise_obj, pd()); -+} ++ const exec_ctx_t &ctx) const { ++ std::lock_guard _lock {this->mtx}; + -+} // namespace aarch64 -+} // namespace cpu -+} // namespace impl -+} // namespace dnnl ++ auto *acl_resource ++ = ctx.get_resource_mapper()->get( ++ this); ++ acl_obj_t &acl_depthwise_obj ++ = acl_resource->get_acl_obj(); ++ ++ return execute_forward_conv_acl, pd_t, ++ data_t>(ctx, acl_depthwise_obj, pd()); ++ } ++ ++} ++} ++} ++} diff --git a/src/cpu/aarch64/acl_depthwise_convolution.hpp b/src/cpu/aarch64/acl_depthwise_convolution.hpp new file mode 100644 -index 0000000000..3e3d02cf41 +index 000000000..d84fc4fb5 --- /dev/null +++ b/src/cpu/aarch64/acl_depthwise_convolution.hpp -@@ -0,0 +1,141 @@ +@@ -0,0 +1,139 @@ +/******************************************************************************* -+* Copyright 2023 Arm Ltd. and affiliates ++* Copyright 2022 Arm Ltd. and affiliates +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. @@ -212,8 +240,8 @@ index 0000000000..3e3d02cf41 +#ifndef CPU_AARCH64_ACL_DEPTHWISE_CONVOLUTION_HPP +#define CPU_AARCH64_ACL_DEPTHWISE_CONVOLUTION_HPP + -+#include "cpu/aarch64/acl_convolution_utils.hpp" +#include "cpu/cpu_convolution_pd.hpp" ++#include "cpu/aarch64/acl_convolution_utils.hpp" + +namespace dnnl { +namespace impl { @@ -222,16 +250,15 @@ index 0000000000..3e3d02cf41 + +struct acl_depthwise_convolution_resource_t : public resource_t { + acl_depthwise_convolution_resource_t() -+ : acl_obj_(utils::make_unique< -+ acl_obj_t>()) {} ++ : acl_obj_(utils::make_unique>()) {} + + status_t configure(const acl_conv_conf_t &acp) { -+ if (!acl_obj_) return status::out_of_memory; ++ if(!acl_obj_) return status::out_of_memory; + -+ acl_obj_->src_tensor.allocator()->init(acp.src_tensor_info); -+ acl_obj_->wei_tensor.allocator()->init(acp.wei_tensor_info); -+ acl_obj_->dst_tensor.allocator()->init(acp.dst_tensor_info); -+ acl_obj_->bia_tensor.allocator()->init(acp.bia_tensor_info); ++ acl_obj_->src_tensor.allocator()->init(acp.src_info); ++ acl_obj_->wei_tensor.allocator()->init(acp.wei_info); ++ acl_obj_->dst_tensor.allocator()->init(acp.dst_info); ++ acl_obj_->bia_tensor.allocator()->init(acp.bia_info); + + // clang-format off + acl_obj_->conv.configure( @@ -254,14 +281,14 @@ index 0000000000..3e3d02cf41 + DNNL_DISALLOW_COPY_AND_ASSIGN(acl_depthwise_convolution_resource_t); + +private: -+ std::unique_ptr> -+ acl_obj_; ++ std::unique_ptr> acl_obj_; ++ +}; + +struct acl_depthwise_convolution_fwd_t : public primitive_t { + + struct pd_t : public cpu_convolution_fwd_pd_t { -+ pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, ++ pd_t(const convolution_desc_t* adesc, const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd), acp_() {} + @@ -270,18 +297,16 @@ index 0000000000..3e3d02cf41 + + status_t init(engine_t *engine) { + using namespace data_type; ++ using smask_t = primitive_attr_t::skip_mask_t; + -+ const bool is_fp16_ok = expect_data_types(f16, f16, f16, f16, undef) -+ && attr()->has_default_values( -+ primitive_attr_t::skip_mask_t::post_ops, f16); -+ const bool is_fp32_ok = expect_data_types(f32, f32, f32, f32, undef) -+ && attr()->has_default_values( -+ primitive_attr_t::skip_mask_t::post_ops, f32); + bool ok = is_fwd() + && set_default_alg_kind(alg_kind::convolution_direct) -+ && utils::one_of(true, is_fp16_ok, is_fp32_ok) -+ && !has_zero_dim_memory(); -+ if (!ok) return status::unimplemented; ++ && expect_data_types(data_type::f32, data_type::f32, ++ data_type::f32, data_type::f32, undef) ++ && !has_zero_dim_memory() ++ && attr()->has_default_values( ++ smask_t::post_ops, data_type::f32); ++ if(!ok) return status::unimplemented; + + CHECK(acl_convolution_utils::init_conf_depthwise(acp_, src_md_, + weights_md_, dst_md_, bias_md_, *desc(), *attr())); @@ -301,31 +326,32 @@ index 0000000000..3e3d02cf41 + acl_depthwise_convolution_fwd_t(const pd_t *apd) : primitive_t(apd) {} + + status_t create_resource( -+ engine_t *engine, resource_mapper_t &mapper) const override { -+ if (mapper.has_resource(this)) return status::success; ++ engine_t *engine, resource_mapper_t &mapper) const override { ++ if(mapper.has_resource(this)) return status::success; + -+ auto r = utils::make_unique(); -+ if (!r) return status::out_of_memory; ++ auto r = utils::make_unique(); ++ if(!r) return status::out_of_memory; + -+ CHECK(r->configure(pd()->acp_)); -+ mapper.add(this, std::move(r)); ++ CHECK(r->configure(pd()->acp_)); ++ mapper.add(this, std::move(r)); + -+ CHECK(pd()->post_ops.create_resource(engine, mapper)); ++ CHECK(pd()->post_ops.create_resource(engine, mapper)); + -+ return status::success; -+ } ++ return status::success; ++ } + -+ typedef typename prec_traits::type data_t; ++ typedef typename prec_traits::type data_t; + -+ status_t execute(const exec_ctx_t &ctx) const override { -+ return execute_forward(ctx); -+ } ++ status_t execute(const exec_ctx_t &ctx) const override { ++ return execute_forward(ctx); ++ } + +private: + mutable std::mutex mtx; + status_t execute_forward(const exec_ctx_t &ctx) const; + + const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } ++ +}; + +} // namespace aarch64 @@ -334,23 +360,3 @@ index 0000000000..3e3d02cf41 +} // namespace dnnl + +#endif // CPU_AARCH64_ACL_DEPTHWISE_CONVOLUTION_HPP -diff --git a/src/cpu/cpu_convolution_list.cpp b/src/cpu/cpu_convolution_list.cpp -index 094c73aa36..80385432d8 100644 ---- a/src/cpu/cpu_convolution_list.cpp -+++ b/src/cpu/cpu_convolution_list.cpp -@@ -63,6 +63,7 @@ using namespace dnnl::impl::cpu::x64; - #include "cpu/aarch64/jit_sve_512_x8s8s32x_convolution.hpp" - #include "cpu/aarch64/jit_uni_dw_convolution.hpp" - #if DNNL_AARCH64 && DNNL_AARCH64_USE_ACL -+#include "cpu/aarch64/acl_depthwise_convolution.hpp" - #include "cpu/aarch64/acl_gemm_convolution.hpp" - #include "cpu/aarch64/acl_indirect_gemm_convolution.hpp" - #endif -@@ -102,6 +103,7 @@ const std::map> &impl_list_map() - CPU_INSTANCE_AARCH64(jit_sve_512_dw_convolution_fwd_t) - CPU_INSTANCE_AARCH64(jit_sve_512_1x1_convolution_fwd_f32_t) - CPU_INSTANCE_AARCH64(jit_sve_512_convolution_fwd_t) -+ CPU_INSTANCE_AARCH64_ACL(acl_depthwise_convolution_fwd_t) - CPU_INSTANCE_AARCH64_ACL(acl_indirect_gemm_convolution_fwd_t) - CPU_INSTANCE_AARCH64_ACL(acl_gemm_convolution_fwd_t) - CPU_INSTANCE(gemm_convolution_fwd_t) diff --git a/third_party/mkl_dnn/onednn_acl_fixed_format_kernels.patch b/third_party/mkl_dnn/onednn_acl_fixed_format_kernels.patch index 282e839bf1eb36..2c8af08ab8a4ff 100644 --- a/third_party/mkl_dnn/onednn_acl_fixed_format_kernels.patch +++ b/third_party/mkl_dnn/onednn_acl_fixed_format_kernels.patch @@ -1,5 +1,5 @@ ******************************************************************************* - Copyright 2023 Arm Limited and affiliates. + Copyright 2022 Arm Limited and affiliates. SPDX-License-Identifier: Apache-2.0 Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,479 +14,178 @@ See the License for the specific language governing permissions and limitations under the License. ******************************************************************************* -diff --git a/src/common/matmul_pd.hpp b/src/common/matmul_pd.hpp -index 4330ad938b..df16c5fcca 100644 ---- a/src/common/matmul_pd.hpp -+++ b/src/common/matmul_pd.hpp -@@ -159,6 +159,19 @@ protected: - - return true; - } -+ -+ // All implementations that do not support sparse inputs/outputs should -+ // call this function. -+ bool is_dense_data() { -+#ifdef DNNL_EXPERIMENTAL_SPARSE -+ for (auto md : {&src_md_, &weights_md_, &bias_md_, &dst_md_}) { -+ if (memory_desc_wrapper(md).format_kind() == format_kind::sparse) -+ return false; -+ } -+#endif -+ return true; -+ } -+ - }; - - } // namespace impl + diff --git a/src/cpu/aarch64/acl_convolution_utils.cpp b/src/cpu/aarch64/acl_convolution_utils.cpp -index 37f8ecbc06..6b57374643 100644 +index c46d69757..fc93d2aa9 100644 --- a/src/cpu/aarch64/acl_convolution_utils.cpp +++ b/src/cpu/aarch64/acl_convolution_utils.cpp -@@ -41,25 +41,23 @@ status_t acl_init_conf(acl_conv_conf_t &acp, memory_desc_t &src_md, - const memory_desc_wrapper dst_d(&dst_md); - const memory_desc_wrapper bia_d(&bias_md); - -- auto math_mode = get_fpmath_mode(); -- acp.fast_math = one_of(math_mode, fpmath_mode::bf16, fpmath_mode::any); -- - // Compute Library currently supports forward propagation only - const prop_kind_t prop_kind = cd.prop_kind; - const bool is_fwd = (prop_kind == dnnl_forward_training) - || (prop_kind == dnnl_forward_inference); - if (!is_fwd) return status::unimplemented; - -- const int with_groups = wei_d.ndims() == src_d.ndims() + 1; - const int ndims = src_d.ndims(); -- const bool is_1d = ndims == 3; -- const bool is_3d = ndims == 5; -- bool is_nspc; - -- // Compute Library unsupported shape scenarios -- if (one_of(true, is_3d, is_1d, with_groups)) { -- return status::unimplemented; -- } -+ ACL_CHECK_SUPPORT(ndims != 4, " only supports 2 spatial dimensions"); -+ -+ const int with_groups = wei_d.ndims() == src_d.ndims() + 1; -+ ACL_CHECK_SUPPORT(with_groups, " does not support groups"); -+ -+ ACL_CHECK_SUPPORT(src_d.data_type() != data_type::f32 -+ || wei_d.data_type() != data_type::f32 -+ || dst_d.data_type() != data_type::f32, -+ " src, dst and wei must be fp32"); - - // batch size - const int mb = src_d.dims()[0]; -@@ -110,108 +108,143 @@ status_t acl_init_conf(acl_conv_conf_t &acp, memory_desc_t &src_md, - - acp.with_bias = cd.bias_desc.format_kind != format_kind::undef; - -- auto set_or_check_tags = [&](format_tag_t desired_src_tag, -- format_tag_t desired_dst_tag) -> status_t { -- using namespace format_tag; -- auto src_tag = any, dst_tag = any; -- -- if (src_d.format_kind() == format_kind::any) { -- CHECK(memory_desc_init_by_tag(src_md, desired_src_tag)); -- src_tag = desired_src_tag; -- } else { -- src_tag = memory_desc_matches_one_of_tag(src_md, nhwc, nchw); -- } -- -- if (dst_d.format_kind() == format_kind::any) { -- CHECK(memory_desc_init_by_tag(dst_md, desired_dst_tag)); -- dst_tag = desired_dst_tag; -- } else { -- dst_tag = memory_desc_matches_one_of_tag(dst_md, nhwc, nchw); -- } -- -- if (acp.with_bias && bias_md.format_kind == format_kind::any) -- CHECK(memory_desc_init_by_tag(bias_md, x)); -- -- is_nspc = utils::one_of(src_tag, nhwc); -- -- memory_desc_t want_wei_md = weights_md; -- auto wei_tag = is_nspc ? ohwi : oihw; -- CHECK(memory_desc_init_by_tag(want_wei_md, wei_tag)); -- -- // Compute Library does not support mismatching layouts -- if ((src_tag != wei_tag) || (src_tag != dst_tag)) -- return status::unimplemented; -+ if (wei_d.format_kind() != format_kind::any) return status::unimplemented; -+ -+ auto src_tag = memory_desc_matches_one_of_tag( -+ src_md, format_tag::nhwc, format_tag::nchw); -+ auto dst_tag = memory_desc_matches_one_of_tag( -+ dst_md, format_tag::nhwc, format_tag::nchw); -+ -+ // We want src and dst to match, preferrably both to be NHWC -+ if (src_d.format_kind() == format_kind::any -+ && dst_d.format_kind() == format_kind::any) { -+ CHECK(memory_desc_init_by_tag(src_md, format_tag::nhwc)); -+ CHECK(memory_desc_init_by_tag(dst_md, format_tag::nhwc)); -+ } else if (src_d.format_kind() == format_kind::any -+ && dst_tag != format_tag::undef) { -+ CHECK(memory_desc_init_by_tag(src_md, dst_tag)); -+ } else if (dst_d.format_kind() == format_kind::any -+ && src_tag != format_tag::undef) { -+ CHECK(memory_desc_init_by_tag(dst_md, src_tag)); -+ } - -- if (weights_md.format_kind == format_kind::any) { -- weights_md = want_wei_md; -- } -- return (want_wei_md == weights_md) ? status::success -- : status::unimplemented; -- }; -+ // Recompute tags after potentially running memory desc init -+ src_tag = memory_desc_matches_one_of_tag( -+ src_md, format_tag::nhwc, format_tag::nchw); -+ dst_tag = memory_desc_matches_one_of_tag( -+ dst_md, format_tag::nhwc, format_tag::nchw); - -- auto default_dat_tag = format_tag::nhwc; -- if (set_or_check_tags(default_dat_tag, default_dat_tag) != status::success) -+ if (src_tag == format_tag::undef || dst_tag == format_tag::undef -+ || src_tag != dst_tag) - return status::unimplemented; - -- const auto acl_layout = is_nspc ? arm_compute::DataLayout::NHWC -- : arm_compute::DataLayout::NCHW; -+ // Set weights to initially be the same as src -+ CHECK(memory_desc_init_by_tag(weights_md, src_tag)); - -- // For convolutions, int8 datatypes imply quantized types in ACL -- acp.is_int8 = utils::one_of(src_d.data_type(), s8, u8) -- && wei_d.data_type() == s8; -+ // Bias is just 1D, set to be the obvious format -+ if (acp.with_bias && bias_md.format_kind == format_kind::any) -+ CHECK(memory_desc_init_by_tag(bias_md, format_tag::x)); - -- auto acl_src_data_t -- = acl_utils::get_acl_data_t(src_d.data_type(), acp.is_int8); -- auto acl_wei_data_t -- = acl_utils::get_acl_data_t(wei_d.data_type(), acp.is_int8); -- auto acl_dst_data_t -- = acl_utils::get_acl_data_t(dst_d.data_type(), acp.is_int8); -- auto acl_bia_data_t -- = acl_utils::get_acl_data_t(bia_d.data_type(), acp.is_int8); -+ bool is_nhwc = src_tag == format_tag::nhwc; -+ // The layouts have to match (although we may later modify the weights) -+ const auto acl_layout = is_nhwc ? arm_compute::DataLayout::NHWC -+ : arm_compute::DataLayout::NCHW; - -- if (acl_bia_data_t == arm_compute::DataType::UNKNOWN) -- acl_bia_data_t = arm_compute::DataType::F32; -+ auto acl_data_type = arm_compute::DataType::F32; - - // clang-format off -- acp.src_info = arm_compute::TensorInfo( -- is_nspc ? arm_compute::TensorShape(ic, iw, ih, mb) : -+ acp.src_tensor_info = arm_compute::TensorInfo( -+ is_nhwc ? arm_compute::TensorShape(ic, iw, ih, mb) : - arm_compute::TensorShape(iw, ih, ic, mb), - 1, -- acl_src_data_t, -+ acl_data_type, - acl_layout); - -- acp.wei_info = arm_compute::TensorInfo( -- is_nspc ? arm_compute::TensorShape(ic, kw, kh, oc) : -+ acp.wei_tensor_info = arm_compute::TensorInfo( -+ is_nhwc ? arm_compute::TensorShape(ic, kw, kh, oc) : - arm_compute::TensorShape(kw, kh, ic, oc), - 1, -- acl_wei_data_t, -+ acl_data_type, - acl_layout); - -- acp.dst_info = arm_compute::TensorInfo( -- is_nspc ? arm_compute::TensorShape(oc, ow, oh, mb) : -+ acp.dst_tensor_info = arm_compute::TensorInfo( -+ is_nhwc ? arm_compute::TensorShape(oc, ow, oh, mb) : - arm_compute::TensorShape(ow, oh, oc, mb), - 1, -- acl_dst_data_t, -+ acl_data_type, - acl_layout); - -- acp.bia_info = arm_compute::TensorInfo( -+ acp.bia_tensor_info = arm_compute::TensorInfo( - acp.with_bias ? arm_compute::TensorShape(oc) - : arm_compute::TensorShape(), - 1, -- acl_bia_data_t, -+ acl_data_type, - acl_layout); - // clang-format on +@@ -212,6 +212,87 @@ status_t acl_init_conf(acl_conv_conf_t &acp, memory_desc_t &src_md, + arm_compute::QuantizationInfo(1.0f / scales[0], 0)); + } -- // Add quantization info to tensors -- if (acp.is_int8) { -- const float *scales = attr.output_scales_.scales_; -- acp.src_info.set_quantization_info(arm_compute::QuantizationInfo(1, 0)); -- acp.bia_info.set_quantization_info(arm_compute::QuantizationInfo(1, 0)); -- acp.wei_info.set_quantization_info(arm_compute::QuantizationInfo(1, 0)); -- acp.dst_info.set_quantization_info( -- arm_compute::QuantizationInfo(1.0f / scales[0], 0)); -+ // Are we allowed to cast down to bf16 or not? -+ acp.fast_math -+ = one_of(attr.fpmath_mode_, fpmath_mode::bf16, fpmath_mode::any); -+ -+ // WeightFormat::ANY tells ACL we can handle any format + acp.weights_info = arm_compute::WeightsInfo( -+ false, kw, kh, oc, false, arm_compute::WeightFormat::ANY); -+ -+ // Get the format that the ACL kernel will expect the weights to be -+ // in (if a kernel exists). Note that these are referred to as fixed format -+ // kernels, because they require one specific weights format ++ false, ++ kw, ++ kh, ++ oc, ++ false, ++ arm_compute::WeightFormat::ANY); + arm_compute::WeightFormat expected_weight_format; -+ ACL_CHECK_VALID(arm_compute::NEGEMMConvolutionLayer::has_opt_impl( -+ expected_weight_format, &acp.src_tensor_info, &acp.wei_tensor_info, -+ acp.with_bias ? &acp.bia_tensor_info : nullptr, -+ &acp.dst_tensor_info, acp.padstride_info, acp.weights_info, -+ acp.dilation_info, acp.act_info, acp.fast_math)); -+ -+ // Set weights info to the one returned by has_opt_impl ++ auto acl_st = arm_compute::NEGEMMConvolutionLayer::has_opt_impl( ++ expected_weight_format, ++ &acp.src_info, ++ &acp.wei_info, ++ acp.with_bias ? &acp.bia_info : nullptr, ++ &acp.dst_info, ++ acp.padstride_info, ++ acp.weights_info, ++ acp.dilation_info, ++ acp.act_info, ++ acp.fast_math); ++ if(acl_st.error_code() != arm_compute::ErrorCode::OK) { ++ return status::unimplemented; ++ } + acp.weights_info.set_weight_format(expected_weight_format); + -+ // has_opt_impl may return a non fast math kernel, even if we requested one -+ acp.fast_math -+ = arm_compute::is_fixed_format_fast_math(expected_weight_format); ++ int interleaved_by = arm_compute::interleave_by(expected_weight_format); ++ int block_by = arm_compute::block_by(expected_weight_format); + -+ // Map OIHW used in ACL WeightFormat to the logical dimensions of the memory descriptor -+ dim_t O_dim = 0; -+ dim_t I_dim = 1; -+ dim_t H_dim = 2; -+ dim_t W_dim = 3; -+ -+ if (!is_nhwc) { -+ // We can try to support NCHW by swapping IHW around, note that this -+ // requires weights_md.dims[I_dim] % block_by != 0 (see next block) -+ O_dim = 0; -+ I_dim = 3; -+ H_dim = 1; -+ W_dim = 2; - } - -+ // We can't currently support nchw and block_by != 1. If this is the case, -+ // try a non fast math kernel, which currently have no blocking -+ int block_by = arm_compute::block_by(acp.weights_info.weight_format()); -+ if (!is_nhwc && weights_md.dims[I_dim] % block_by != 0 && acp.fast_math) { ++ bool is_fast_math_kernel = arm_compute::is_fixed_format_fast_math(expected_weight_format); ++ if(!is_fast_math_kernel) { ++ // FP32 kernel is faster then BF16 + acp.fast_math = false; -+ acp.weights_info.set_weight_format(arm_compute::WeightFormat::ANY); -+ ACL_CHECK_VALID(arm_compute::NEGEMMConvolutionLayer::has_opt_impl( -+ expected_weight_format, &acp.src_tensor_info, -+ &acp.wei_tensor_info, -+ acp.with_bias ? &acp.bia_tensor_info : nullptr, -+ &acp.dst_tensor_info, acp.padstride_info, acp.weights_info, -+ acp.dilation_info, acp.act_info, acp.fast_math)); -+ acp.weights_info.set_weight_format(expected_weight_format); -+ block_by = arm_compute::block_by(expected_weight_format); -+ // This shouldn't happen, because non-fastmath have no blocking, but -+ // guard against it because it would silently return incorrect results -+ if (weights_md.dims[I_dim] % block_by != 0) ++ } ++ ++ memory_desc_t want_wei_md = weights_md; ++ ++ int ic_multiply = ic; ++ if(ic % block_by != 0) { ++ ic_multiply = utils::div_up(ic, block_by) * block_by; ++ // Also we need to set padded dimensions as well ++ want_wei_md.padded_dims[1] = ic_multiply; ++ } else { ++ // If we do not need to pad input channels for fast math mode ++ // then it would be faster to run convolution with im2row ++ // instead of using indirect buffer ++ if(acp.fast_math && acp.is_indirect) { + return status::unimplemented; ++ } ++ } ++ if(oc % interleaved_by != 0) { ++ int padded_dim = utils::div_up(oc, interleaved_by) * interleaved_by; ++ want_wei_md.padded_dims[0] = padded_dim; ++ } ++ ++ // Set strides based on blocking information ++ want_wei_md.format_desc.blocking.strides[0] = interleaved_by*ic_multiply*kw*kh; ++ want_wei_md.format_desc.blocking.strides[1] = interleaved_by*block_by; ++ want_wei_md.format_desc.blocking.strides[2] = interleaved_by*ic_multiply*kw; ++ want_wei_md.format_desc.blocking.strides[3] = interleaved_by*ic_multiply; ++ ++ acl_utils::update_strides_y_and_z( ++ acp.wei_info, ++ want_wei_md.format_desc.blocking.strides[0] * wei_d.data_type_size(), ++ acp.wei_info.strides_in_bytes().z()); ++ ++ // Set blocking ++ want_wei_md.format_desc.blocking.inner_nblks = (block_by > 1) + 1; ++ want_wei_md.format_desc.blocking.inner_idxs[0] = 0; // second to last dimension in abcd format ++ want_wei_md.format_desc.blocking.inner_blks[0] = interleaved_by; ++ ++ if(block_by > 1) { ++ want_wei_md.format_desc.blocking.inner_idxs[1] = 1; // second to last dimension in abcd format ++ want_wei_md.format_desc.blocking.inner_blks[1] = block_by; ++ } ++ ++ if(is_fast_math_kernel) { ++ // If it is fast math mode we need weights in BFloat16 ++ want_wei_md.data_type = dnnl_bf16; + } + -+ acl_utils::reorder_to_weight_format(acp.wei_tensor_info, weights_md, -+ expected_weight_format, I_dim, O_dim, {W_dim, H_dim}, {}); ++ weights_md = want_wei_md; + return status::success; } -@@ -226,10 +259,10 @@ status_t init_conf_gemm(acl_conv_conf_t &acp, memory_desc_t &src_md, - // clang-format off - // Validate convolution manually to check for return status - ACL_CHECK_VALID(arm_compute::NEGEMMConvolutionLayer::validate( -- &acp.src_info, -- &acp.wei_info, -- acp.with_bias ? &acp.bia_info : nullptr, -- &acp.dst_info, -+ &acp.src_tensor_info, -+ &acp.wei_tensor_info, -+ acp.with_bias ? &acp.bia_tensor_info : nullptr, -+ &acp.dst_tensor_info, - acp.padstride_info, - acp.weights_info, - acp.dilation_info, -@@ -244,28 +277,38 @@ status_t init_conf_indirect_gemm(acl_conv_conf_t &acp, memory_desc_t &src_md, +@@ -219,6 +300,7 @@ status_t init_conf_gemm(acl_conv_conf_t &acp, memory_desc_t &src_md, memory_desc_t &weights_md, memory_desc_t &dst_md, memory_desc_t &bias_md, const convolution_desc_t &cd, const primitive_attr_t &attr) { -- // Indirect convolution results in slowdown for low thread count or 1x1 -- // kernels, so fall back to GEMM-based convolution in these cases -- if (one_of(true, weights_md.dims[2] == 1, // kh -- weights_md.dims[3] == 1, // kw -- dnnl_get_max_threads() < 28)) { -+ -+ // Indirect is slower for small convolution kernels -+ if (weights_md.dims[2] == 1 && weights_md.dims[3] == 1) - return status::unimplemented; -- } ++ acp.is_indirect = false; + // General Compute Library checks, memory tags are also set there CHECK(acl_init_conf(acp, src_md, weights_md, dst_md, bias_md, cd, attr)); +@@ -244,11 +326,13 @@ status_t init_conf_indirect_gemm(acl_conv_conf_t &acp, memory_desc_t &src_md, + memory_desc_t &weights_md, memory_desc_t &dst_md, + memory_desc_t &bias_md, const convolution_desc_t &cd, + const primitive_attr_t &attr) { ++ acp.is_indirect = true; ++ auto math_mode = get_fpmath_mode(); + // Indirect convolution results in slowdown for low thread count or 1x1 + // kernels, so fall back to GEMM-based convolution in these cases + if (one_of(true, weights_md.dims[2] == 1, // kh + weights_md.dims[3] == 1, // kw +- dnnl_get_max_threads() < 28)) { ++ (!math_mode && dnnl_get_max_threads() < 28))) { + return status::unimplemented; + } -+ // Indirect is slower than gemm for low thread counts, except for fast math -+ if (dnnl_get_max_threads() < 28 && !acp.fast_math) -+ return status::unimplemented; -+ -+ // If we do not need to pad input channels for fast math mode then it would -+ // be faster to run convolution with im2row instead of using indirect kernel -+ int block_by = arm_compute::block_by(acp.weights_info.weight_format()); -+ int ic = src_md.dims[1]; -+ if (acp.fast_math && ic % block_by == 0) return status::unimplemented; -+ -+ // TODO: remove this once NEGEMMConv2d::validate allows src and weights to mismatch -+ acp.wei_tensor_info.set_data_layout(arm_compute::DataLayout::NHWC); -+ - // clang-format off - // NOTE: indirect convolution method supports only nhwc layout. - ACL_CHECK_VALID(arm_compute::NEGEMMConv2d::validate( -- &acp.src_info, -- &acp.wei_info, -- acp.with_bias ? &acp.bia_info : nullptr, -- &acp.dst_info, -+ &acp.src_tensor_info, -+ &acp.wei_tensor_info, -+ acp.with_bias ? &acp.bia_tensor_info : nullptr, -+ &acp.dst_tensor_info, - arm_compute::Conv2dInfo(acp.padstride_info, - acp.dilation_info, - acp.act_info, - acp.fast_math, -- 1))); -+ 1, {}, acp.weights_info))); - // clang-format on +@@ -275,6 +359,7 @@ status_t init_conf_wino(acl_conv_conf_t &acp, memory_desc_t &src_md, + memory_desc_t &weights_md, memory_desc_t &dst_md, + memory_desc_t &bias_md, const convolution_desc_t &cd, + const primitive_attr_t &attr) { ++ acp.is_indirect = false; - return status::success; + // Under these conditions, fallback to faster GEMM-based convolution + // unless the user explicitly specifies Winograd algorithm diff --git a/src/cpu/aarch64/acl_convolution_utils.hpp b/src/cpu/aarch64/acl_convolution_utils.hpp -index 0398ab06b9..e3d40a5e75 100644 +index 3e56245fa..44dc8eecb 100644 --- a/src/cpu/aarch64/acl_convolution_utils.hpp +++ b/src/cpu/aarch64/acl_convolution_utils.hpp -@@ -38,17 +38,17 @@ struct acl_obj_t { - - struct acl_conv_conf_t { - bool with_bias; -- bool is_int8; - bool fast_math; +@@ -43,6 +43,7 @@ struct acl_conv_conf_t { // If this is true, the result of the convolution goes into a temporarily // allocated ACL tensor to be accumulated into the oneDNN dst during postops bool use_dst_acc; -- arm_compute::TensorInfo src_info; -- arm_compute::TensorInfo wei_info; -- arm_compute::TensorInfo bia_info; -- arm_compute::TensorInfo dst_info; -+ arm_compute::TensorInfo src_tensor_info; -+ arm_compute::TensorInfo wei_tensor_info; -+ arm_compute::TensorInfo bia_tensor_info; -+ arm_compute::TensorInfo dst_tensor_info; - arm_compute::PadStrideInfo padstride_info; - arm_compute::Size2D dilation_info; -+ // Additional information about the weights not included in wei_tensor_info - arm_compute::WeightsInfo weights_info; - // Note: this will default to not enabled, and will do nothing - arm_compute::ActivationLayerInfo act_info; -diff --git a/src/cpu/aarch64/acl_gemm_convolution.hpp b/src/cpu/aarch64/acl_gemm_convolution.hpp -index 485db954ea..da58e4f610 100644 ---- a/src/cpu/aarch64/acl_gemm_convolution.hpp -+++ b/src/cpu/aarch64/acl_gemm_convolution.hpp -@@ -1,5 +1,5 @@ - /******************************************************************************* --* Copyright 2020-2022 Arm Ltd. and affiliates -+* Copyright 2020-2023 Arm Ltd. and affiliates - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. -@@ -36,10 +36,10 @@ struct acl_resource_t : public resource_t { - if (!acl_obj_) return status::out_of_memory; - - // Init Compute Library tensors based on info from descriptor -- acl_obj_->src_tensor.allocator()->init(acp.src_info); -- acl_obj_->wei_tensor.allocator()->init(acp.wei_info); -- acl_obj_->dst_tensor.allocator()->init(acp.dst_info); -- acl_obj_->bia_tensor.allocator()->init(acp.bia_info); -+ acl_obj_->src_tensor.allocator()->init(acp.src_tensor_info); -+ acl_obj_->wei_tensor.allocator()->init(acp.wei_tensor_info); -+ acl_obj_->dst_tensor.allocator()->init(acp.dst_tensor_info); -+ acl_obj_->bia_tensor.allocator()->init(acp.bia_tensor_info); - - acl_obj_->conv.configure(&acl_obj_->src_tensor, &acl_obj_->wei_tensor, - acp.with_bias ? &acl_obj_->bia_tensor : nullptr, ++ bool is_indirect; + arm_compute::TensorInfo src_info; + arm_compute::TensorInfo wei_info; + arm_compute::TensorInfo bia_info; diff --git a/src/cpu/aarch64/acl_indirect_gemm_convolution.hpp b/src/cpu/aarch64/acl_indirect_gemm_convolution.hpp -index bcf031a771..b7c8dce894 100644 +index bcf031a77..4ddc8cf91 100644 --- a/src/cpu/aarch64/acl_indirect_gemm_convolution.hpp +++ b/src/cpu/aarch64/acl_indirect_gemm_convolution.hpp -@@ -1,5 +1,5 @@ - /******************************************************************************* --* Copyright 2021-2022 Arm Ltd. and affiliates -+* Copyright 2021-2023 Arm Ltd. and affiliates - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. -@@ -35,10 +35,10 @@ struct acl_indirect_gemm_resource_t : public resource_t { - if (!acl_obj_) return status::out_of_memory; - - // Init Compute Library tensors based on info from descriptor -- acl_obj_->src_tensor.allocator()->init(acp.src_info); -- acl_obj_->wei_tensor.allocator()->init(acp.wei_info); -- acl_obj_->dst_tensor.allocator()->init(acp.dst_info); -- acl_obj_->bia_tensor.allocator()->init(acp.bia_info); -+ acl_obj_->src_tensor.allocator()->init(acp.src_tensor_info); -+ acl_obj_->wei_tensor.allocator()->init(acp.wei_tensor_info); -+ acl_obj_->dst_tensor.allocator()->init(acp.dst_tensor_info); -+ acl_obj_->bia_tensor.allocator()->init(acp.bia_tensor_info); +@@ -41,6 +41,7 @@ struct acl_indirect_gemm_resource_t : public resource_t { + acl_obj_->bia_tensor.allocator()->init(acp.bia_info); // clang-format off ++ arm_compute::experimental::PostOpList empty_post_ops = arm_compute::experimental::PostOpList {}; acl_obj_->conv.configure( -@@ -50,7 +50,9 @@ struct acl_indirect_gemm_resource_t : public resource_t { + &acl_obj_->src_tensor, + &acl_obj_->wei_tensor, +@@ -50,7 +51,9 @@ struct acl_indirect_gemm_resource_t : public resource_t { acp.dilation_info, acp.act_info, acp.fast_math, - 1)); + 1, -+ {}, ++ empty_post_ops, + acp.weights_info)); // clang-format on return status::success; diff --git a/src/cpu/aarch64/acl_inner_product.hpp b/src/cpu/aarch64/acl_inner_product.hpp -index c5e507085f..a27df640fb 100644 +index c5e507085..163ff066e 100644 --- a/src/cpu/aarch64/acl_inner_product.hpp +++ b/src/cpu/aarch64/acl_inner_product.hpp -@@ -40,11 +40,13 @@ struct acl_ip_conf_t { - // If this is true, the result of the inner product goes into a temporarily - // allocated ACL tensor to be accumulated into the oneDNN dst during postops - bool use_dst_acc; -- arm_compute::TensorInfo src_info; -- arm_compute::TensorInfo wei_info; -- arm_compute::TensorInfo bia_info; -- arm_compute::TensorInfo dst_info; -+ arm_compute::TensorInfo src_tensor_info; -+ arm_compute::TensorInfo wei_tensor_info; -+ arm_compute::TensorInfo bia_tensor_info; -+ arm_compute::TensorInfo dst_tensor_info; +@@ -45,6 +45,7 @@ struct acl_ip_conf_t { + arm_compute::TensorInfo bia_info; + arm_compute::TensorInfo dst_info; arm_compute::FullyConnectedLayerInfo fc_info; -+ // Additional information about the weights not included in wei_tensor_info + arm_compute::WeightsInfo weights_info; }; struct acl_ip_resource_t : public resource_t { acl_ip_resource_t() : acl_ip_obj_(utils::make_unique()) {} -@@ -53,10 +55,10 @@ struct acl_ip_resource_t : public resource_t { - if (!acl_ip_obj_) return status::out_of_memory; - - // Init Compute Library tensors based on info from descriptor -- acl_ip_obj_->src_tensor.allocator()->init(aip.src_info); -- acl_ip_obj_->wei_tensor.allocator()->init(aip.wei_info); -- acl_ip_obj_->dst_tensor.allocator()->init(aip.dst_info); -- acl_ip_obj_->bia_tensor.allocator()->init(aip.bia_info); -+ acl_ip_obj_->src_tensor.allocator()->init(aip.src_tensor_info); -+ acl_ip_obj_->wei_tensor.allocator()->init(aip.wei_tensor_info); -+ acl_ip_obj_->dst_tensor.allocator()->init(aip.dst_tensor_info); -+ acl_ip_obj_->bia_tensor.allocator()->init(aip.bia_tensor_info); - - // clang-format off - acl_ip_obj_->fc.configure( -@@ -64,7 +66,8 @@ struct acl_ip_resource_t : public resource_t { +@@ -64,7 +65,8 @@ struct acl_ip_resource_t : public resource_t { &acl_ip_obj_->wei_tensor, aip.with_bias ? &acl_ip_obj_->bia_tensor : nullptr, &acl_ip_obj_->dst_tensor, @@ -496,126 +195,41 @@ index c5e507085f..a27df640fb 100644 // clang-format on return status::success; -@@ -89,12 +92,16 @@ struct acl_inner_product_fwd_t : public primitive_t { - DECLARE_COMMON_PD_T("acl", acl_inner_product_fwd_t); +@@ -156,8 +158,8 @@ struct acl_inner_product_fwd_t : public primitive_t { + src_shape = (src_tag == nc) ? arm_compute::TensorShape(ic, n) + : arm_compute::TensorShape(n, ic); - status_t init(engine_t *engine) { -- const bool ok = is_fwd() && !has_zero_dim_memory() -- && expect_data_types(data_type::f32, data_type::f32, -- data_type::f32, data_type::f32, data_type::f32) -+ using namespace data_type; -+ const bool is_fp16_ok = expect_data_types(f16, f16, f16, f16, undef) -+ && attr()->has_default_values( -+ primitive_attr_t::skip_mask_t::post_ops, f16); -+ const bool is_fp32_ok = expect_data_types(f32, f32, f32, f32, undef) - && attr()->has_default_values( -- primitive_attr_t::skip_mask_t::post_ops, -- data_type::f32) -+ primitive_attr_t::skip_mask_t::post_ops, f32); -+ const bool ok = is_fwd() && !has_zero_dim_memory() -+ && utils::one_of(true, is_fp16_ok, is_fp32_ok) - && set_default_params() == status::success; - - if (!ok) return status::unimplemented; -@@ -121,88 +128,46 @@ struct acl_inner_product_fwd_t : public primitive_t { - ACL_CHECK_SUPPORT( - !(is_2d || is_4d), "ACL supports only 2d or 4d cases"); - -- // batch size -- const int n = src_md()->dims[0]; -- -- // input and output channels -- const int ic = src_md()->dims[1]; -- const int oc = dst_md()->dims[1]; -- -- // source spatial dimensions -- const int ih = is_4d ? src_md()->dims[ndims - 2] : 0; -- const int iw = is_4d ? src_md()->dims[ndims - 1] : 0; -- -- // weights spatial dimensions -- const int kh = is_4d ? weights_md()->dims[ndims - 2] : 0; -- const int kw = is_4d ? weights_md()->dims[ndims - 1] : 0; -- -- // Only NCHW or NHWC derivatives supported by ACL kernels - using namespace format_tag; -- auto src_tag = memory_desc_matches_one_of_tag( -- src_md_, nhwc, nchw, nc, cn); -- auto wei_tag = memory_desc_matches_one_of_tag( -- weights_md_, ohwi, oihw, oi, io); -- auto dst_tag = memory_desc_matches_one_of_tag(dst_md_, nc, cn); -+ auto src_tag -+ = memory_desc_matches_one_of_tag(src_md_, nhwc, nchw, nc); -+ auto dst_tag = memory_desc_matches_one_of_tag(dst_md_, nc); - - ACL_CHECK_SUPPORT( -- utils::one_of(format_tag::undef, src_tag, wei_tag, dst_tag), -+ utils::one_of(format_tag::undef, src_tag, dst_tag), - "unsupported memory layout"); - - ACL_CHECK_SUPPORT(is_2d && src_tag != dst_tag, - "for src and dst layouts must match"); - -- arm_compute::TensorShape src_shape, wei_shape; -- if (is_2d) { -- src_shape = (src_tag == nc) ? arm_compute::TensorShape(ic, n) -- : arm_compute::TensorShape(n, ic); -- - wei_shape = (wei_tag == io) ? arm_compute::TensorShape(oc, ic) - : arm_compute::TensorShape(ic, oc); -- } -- if (is_4d) { -- src_shape = (src_tag == nhwc) -- ? arm_compute::TensorShape(ic, iw, ih, n) -- : arm_compute::TensorShape(iw, ih, ic, n); -- -- // ACL requires the weights to be in 2D flattened shape -- const int flattened_ic = is_4d ? ic * kh * kw : ic; ++ // For fixed format kernels weight shape is always io ++ wei_shape = arm_compute::TensorShape(oc, ic); + } + if (is_4d) { + src_shape = (src_tag == nhwc) +@@ -166,7 +168,8 @@ struct acl_inner_product_fwd_t : public primitive_t { + + // ACL requires the weights to be in 2D flattened shape + const int flattened_ic = is_4d ? ic * kh * kw : ic; - wei_shape = arm_compute::TensorShape(flattened_ic, oc); -- } -- -- arm_compute::DataLayout src_layout = (src_tag == nhwc) -- ? arm_compute::DataLayout::NHWC -- : arm_compute::DataLayout::NCHW; -+ const dim_t ic_total = IC_total(); -+ const dim_t n = MB(); -+ const dim_t oc = OC(); - -- arm_compute::DataLayout wei_layout = (wei_tag == ohwi) -- ? arm_compute::DataLayout::NHWC -- : arm_compute::DataLayout::NCHW; -+ aip.src_tensor_info = arm_compute::TensorInfo( -+ arm_compute::TensorShape(ic_total, n), 1, -+ acl_utils::get_acl_data_t(src_md()->data_type)); - -- aip.src_info = arm_compute::TensorInfo( -- src_shape, 1, arm_compute::DataType::F32, src_layout); -+ // ACL requires the weights to be in 2D flattened shape -+ aip.wei_tensor_info = arm_compute::TensorInfo( -+ arm_compute::TensorShape(oc, ic_total), 1, -+ acl_utils::get_acl_data_t(weights_md(0)->data_type)); - -- aip.wei_info = arm_compute::TensorInfo( -- wei_shape, 1, arm_compute::DataType::F32, wei_layout); -- -- aip.dst_info -- = arm_compute::TensorInfo(arm_compute::TensorShape(oc, n), -- 1, arm_compute::DataType::F32); -+ auto acl_dst_data_t -+ = acl_utils::get_acl_data_t(dst_md()->data_type); -+ aip.dst_tensor_info = arm_compute::TensorInfo( -+ arm_compute::TensorShape(oc, n), 1, acl_dst_data_t); - - aip.with_bias = desc()->bias_desc.format_kind != format_kind::undef; -- aip.bia_info = arm_compute::TensorInfo(aip.with_bias -+ auto acl_bia_data_t = aip.with_bias -+ ? acl_utils::get_acl_data_t(weights_md(1)->data_type) -+ : acl_dst_data_t; -+ aip.bia_tensor_info = arm_compute::TensorInfo(aip.with_bias - ? arm_compute::TensorShape(oc) - : arm_compute::TensorShape(), ++ // For fixed format kernels weights shape is always io ++ wei_shape = arm_compute::TensorShape(oc, flattened_ic); + } + + arm_compute::DataLayout src_layout = (src_tag == nhwc) +@@ -183,6 +186,9 @@ struct acl_inner_product_fwd_t : public primitive_t { + aip.wei_info = arm_compute::TensorInfo( + wei_shape, 1, arm_compute::DataType::F32, wei_layout); + ++ aip.weights_info = arm_compute::WeightsInfo( ++ false, 1, 1, is_4d ? ic * kh *kw : ic, false, arm_compute::WeightFormat::ANY); ++ + aip.dst_info + = arm_compute::TensorInfo(arm_compute::TensorShape(oc, n), + 1, arm_compute::DataType::F32); +@@ -194,15 +200,7 @@ struct acl_inner_product_fwd_t : public primitive_t { 1, arm_compute::DataType::F32); -- aip.fc_info.weights_trained_layout = wei_layout; + aip.fc_info.weights_trained_layout = wei_layout; - if (is_2d && wei_tag != src_tag) { - // weights are already transposed - aip.fc_info.transpose_weights = false; @@ -629,536 +243,294 @@ index c5e507085f..a27df640fb 100644 // Fast math mode auto math_mode = get_fpmath_mode(); -@@ -214,15 +179,103 @@ struct acl_inner_product_fwd_t : public primitive_t { +@@ -214,6 +212,80 @@ struct acl_inner_product_fwd_t : public primitive_t { aip.fc_info.activation_info)); aip.use_dst_acc = post_ops.has_sum(); -+ // WeightFormat::ANY tells ACL we can handle any format -+ aip.weights_info = arm_compute::WeightsInfo(false, 1, 1, ic_total, -+ false, arm_compute::WeightFormat::ANY); -+ -+ // Get the format that the ACL kernel will expect the weights to be -+ // in (if a kernel exists) Note that these are referred to as fixed -+ // format kernels, because they require one specific weights format + arm_compute::WeightFormat expected_weight_format; -+ ACL_CHECK_VALID(arm_compute::NEFullyConnectedLayer::has_opt_impl( -+ expected_weight_format, &aip.src_tensor_info, -+ &aip.wei_tensor_info, -+ aip.with_bias ? &aip.bia_tensor_info : nullptr, -+ &aip.dst_tensor_info, aip.fc_info, aip.weights_info)); ++ auto acl_st = arm_compute::NEFullyConnectedLayer::has_opt_impl( ++ expected_weight_format, ++ &aip.src_info, ++ &aip.wei_info, ++ aip.with_bias ? &aip.bia_info : nullptr, ++ &aip.dst_info, ++ aip.fc_info, ++ aip.weights_info); ++ if(acl_st.error_code() != arm_compute::ErrorCode::OK) { ++ return status::unimplemented; ++ } + -+ // Set weights info to the one returned by has_opt_impl + aip.weights_info.set_weight_format(expected_weight_format); + -+ // has_opt_impl may return a non fast math kernel, even if requested -+ aip.fc_info.enable_fast_math -+ = arm_compute::is_fixed_format_fast_math( -+ expected_weight_format); ++ int interleaved_by = arm_compute::interleave_by(expected_weight_format); ++ int block_by = arm_compute::block_by(expected_weight_format); ++ bool is_fast_math_kernel = arm_compute::is_fixed_format_fast_math(expected_weight_format); + -+ // Inner product is the same as the matmul n x (chw) * (ihw) x o -+ // (note that the src c and weights i both correspond to the input -+ // channel). ACL FullyConnectedLayer assumes the chw dimensions of -+ // src and ihw dimensions of weights are collapsed, so we need to -+ // make sure that they have the same layout. Given that weights are -+ // more often fixed, (so reorders can be hoisted) it makes sense to -+ // reorder the weights to fit the src. ++ if(!is_fast_math_kernel) { ++ // FP32 kernel might be faster for some cases then BF16 ++ aip.fc_info.enable_fast_math = false; ++ } ++ ++ memory_desc_t want_wei_md = weights_md_; + -+ // For 4D tensors we need to: -+ // - reorder the ihw of the weights to match the src chw -+ // - collapse ihw -+ // - pad the collapsed ihw -+ // But there is not yet a way to express this collapse+pad as a -+ // reorder. So we try to reorder the weights to match the src, -+ // implicitly collapse ihw in our definition of the weights -+ // TensorInfo and hope that the inner_dim has zero padding -+ // (weights_md_.dims[inner_dim] % block_by == 0). If it does, we -+ // fall back to a kernel without blocking (currently this is -+ // equivalent to non-fastmath). ++ int ic_multiply = ic; ++ if(is_4d) { ++ ic_multiply = ic * kh * kw; + -+ // 2D just works because we just pad the only dimension. ++ // Since we are flattening dimensions the memory descriptor ++ // should also be for 2D ++ want_wei_md.ndims = 2; + -+ // o_dim is always the first logical dimension (oihw, ohwi, oi) -+ dim_t o_dim = 0; -+ dim_t inner_dim; -+ // Rest of logical dimensions in order of innermost to outermost -+ std::vector remaining_dims = {}; ++ want_wei_md.dims[1] = ic_multiply; ++ want_wei_md.padded_dims[1] = ic_multiply; ++ want_wei_md.format_desc.blocking.strides[1] = 1; + -+ if (src_tag == nchw) { -+ inner_dim = 3; // w -+ remaining_dims = {2, 1}; // h, i -+ } else if (src_tag == nhwc) { -+ inner_dim = 1; // i -+ remaining_dims = {3, 2}; // w, h -+ } else { // Only remaining case is 2D (nc) -+ inner_dim = 1; // i -+ remaining_dims = {}; // No other dimensions for 2D ++ want_wei_md.dims[0] = oc; ++ want_wei_md.padded_dims[0] = want_wei_md.padded_dims[1]; ++ want_wei_md.padded_dims[0] = oc; + } + -+ // Fallback -+ int block_by = arm_compute::block_by(expected_weight_format); -+ if (is_4d && weights_md_.dims[inner_dim] % block_by != 0 -+ && aip.fc_info.enable_fast_math) { -+ aip.fc_info.enable_fast_math = false; -+ aip.weights_info.set_weight_format( -+ arm_compute::WeightFormat::ANY); -+ ACL_CHECK_VALID( -+ arm_compute::NEFullyConnectedLayer::has_opt_impl( -+ expected_weight_format, &aip.src_tensor_info, -+ &aip.wei_tensor_info, -+ aip.with_bias ? &aip.bia_tensor_info : nullptr, -+ &aip.dst_tensor_info, aip.fc_info, -+ aip.weights_info)); -+ aip.weights_info.set_weight_format(expected_weight_format); -+ block_by = arm_compute::block_by(expected_weight_format); -+ if (weights_md_.dims[inner_dim] % block_by != 0) -+ return status::unimplemented; ++ want_wei_md.format_desc.blocking.strides[1] = interleaved_by * block_by; ++ if(want_wei_md.dims[1] % block_by != 0) { ++ want_wei_md.padded_dims[1] = utils::div_up(want_wei_md.dims[1], block_by) * block_by; + } ++ want_wei_md.format_desc.blocking.strides[0] = interleaved_by * want_wei_md.padded_dims[1]; + -+ acl_utils::reorder_to_weight_format(aip.wei_tensor_info, -+ weights_md_, expected_weight_format, inner_dim, o_dim, -+ remaining_dims, {}); ++ if(oc % interleaved_by != 0) { ++ int padded_dim = utils::div_up(oc, interleaved_by) * interleaved_by; ++ want_wei_md.padded_dims[0] = padded_dim; ++ } + - // clang-format off ++ int data_type_size = memory_desc_wrapper(want_wei_md).data_type_size(); ++ acl_utils::update_strides_y_and_z( ++ aip.wei_info, ++ want_wei_md.format_desc.blocking.strides[0] * data_type_size, ++ want_wei_md.format_desc.blocking.strides[1] * data_type_size); ++ ++ want_wei_md.format_desc.blocking.inner_nblks = (block_by > 1) + 1; ++ want_wei_md.format_desc.blocking.inner_idxs[0] = 0; ++ want_wei_md.format_desc.blocking.inner_blks[0] = interleaved_by; ++ if(block_by > 1) { ++ want_wei_md.format_desc.blocking.inner_idxs[1] = 1; ++ want_wei_md.format_desc.blocking.inner_blks[1] = block_by; ++ } ++ ++ if(is_fast_math_kernel) { ++ want_wei_md.data_type = dnnl_bf16; ++ } ++ ++ weights_md_ = want_wei_md; + + // clang-format off // Validate fully connected layer manually to check for return status ACL_CHECK_VALID(arm_compute::NEFullyConnectedLayer::validate( -- &aip.src_info, -- &aip.wei_info, -- aip.with_bias ? &aip.bia_info : nullptr, -- &aip.dst_info, -- aip.fc_info)); -+ &aip.src_tensor_info, -+ &aip.wei_tensor_info, -+ aip.with_bias ? &aip.bia_tensor_info : nullptr, -+ &aip.dst_tensor_info, -+ aip.fc_info, -+ aip.weights_info)); - // clang-format on -+ - return status::success; - } - }; // pd_t diff --git a/src/cpu/aarch64/acl_utils.cpp b/src/cpu/aarch64/acl_utils.cpp -index 79ea775d6d..5792fd4911 100644 +index 79ea775d6..7ee4c7398 100644 --- a/src/cpu/aarch64/acl_utils.cpp +++ b/src/cpu/aarch64/acl_utils.cpp -@@ -1,5 +1,5 @@ - /******************************************************************************* --* Copyright 2021-2022 Arm Ltd. and affiliates -+* Copyright 2021-2023 Arm Ltd. and affiliates - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. -@@ -261,6 +261,75 @@ int reorder_dimensions_by_stride(std::vector permuted_mds, - return reordered_dims; +@@ -157,6 +157,28 @@ status_t tensor_info( + return status::success; } -+void reorder_to_weight_format(arm_compute::TensorInfo &info, memory_desc_t &md, -+ arm_compute::WeightFormat wf, dim_t I_dim, dim_t O_dim, -+ std::vector spatial_dims, std::vector batch_dims) { -+ -+ md.format_kind = format_kind::blocked; -+ md.format_desc.blocking = blocking_desc_t {}; -+ const int interleaved_by = arm_compute::interleave_by(wf); -+ const int block_by = arm_compute::block_by(wf); -+ -+ // I dimension becomes densest (apart from blocking) -+ md.format_desc.blocking.strides[I_dim] = interleaved_by * block_by; -+ md.padded_dims[I_dim] = utils::rnd_up(md.dims[I_dim], block_by); -+ -+ // Then any spatial dimensions (e.g. HW) -+ dim_t ldb = interleaved_by * md.padded_dims[I_dim]; -+ for (dim_t sd : spatial_dims) { -+ md.format_desc.blocking.strides[sd] = ldb; -+ ldb *= md.padded_dims[sd]; -+ } ++status_t update_strides_y_and_z( ++ arm_compute::TensorInfo &info, const int y, const int z) { + -+ // O dim (which was the innermost) becomes the outermost (apart from batching) -+ md.format_desc.blocking.strides[O_dim] = ldb; -+ md.padded_dims[O_dim] = utils::rnd_up(md.dims[O_dim], interleaved_by); -+ -+ // Update the batch dimensions, starting with stride of the innermost batch -+ const dim_t innermost_batch_stride -+ = md.padded_dims[I_dim] * md.padded_dims[O_dim]; -+ dim_t batch_stride = innermost_batch_stride; -+ for (dim_t bd : batch_dims) { -+ md.format_desc.blocking.strides[bd] = batch_stride; -+ batch_stride *= md.padded_dims[bd]; -+ } -+ -+ // Weights can only be blocked if they are also interleaved -+ if (interleaved_by > 1) { -+ md.format_desc.blocking.inner_nblks = 1 + (block_by > 1); -+ -+ md.format_desc.blocking.inner_idxs[0] = O_dim; -+ md.format_desc.blocking.inner_blks[0] = interleaved_by; -+ if (block_by > 1) { -+ md.format_desc.blocking.inner_idxs[1] = I_dim; -+ md.format_desc.blocking.inner_blks[1] = block_by; -+ } -+ } ++ arm_compute::TensorShape shape = info.tensor_shape(); ++ arm_compute::Strides old_strides_in_bytes = info.strides_in_bytes(); + -+ if (arm_compute::is_fixed_format_fast_math(wf)) { -+ md.data_type = dnnl_bf16; -+ info.set_data_type(arm_compute::DataType::BFLOAT16); ++ arm_compute::Strides new_strides_in_bytes; ++ for(size_t i = 0; i < shape.num_dimensions(); ++i) { ++ new_strides_in_bytes.set(i, old_strides_in_bytes[i]); + } + -+ // The data layout is now determined by the manually set strides -+ info.set_data_layout(arm_compute::DataLayout::UNKNOWN); -+ -+ // x is ignored in fixed format kernels -+ // y is the leading dimension of b (ldb) in the GEMM d = a*b + c -+ // This is the stride of O_dim in the md -+ // z is the batch dimension (not strictly needed if there's only 1 batch) -+ // i.e. how much do I need to stride to get to the next matmul (ignoring -+ // the interleaving). Note that we use the innermost_batch_stride -+ // because all the batched dimensions are collapsed (as required by ACL). -+ arm_compute::Strides new_strides_in_bytes = info.strides_in_bytes(); -+ new_strides_in_bytes.set(1, ldb * info.element_size()); -+ new_strides_in_bytes.set(2, innermost_batch_stride * info.element_size()); ++ // set y ++ new_strides_in_bytes.set(1, y); ++ // set z ++ new_strides_in_bytes.set(2, z); + + info.init(info.tensor_shape(), info.num_channels(), info.data_type(), -+ new_strides_in_bytes, info.offset_first_element_in_bytes(), -+ memory_desc_wrapper(md).size()); ++ new_strides_in_bytes, info.offset_first_element_in_bytes(), info.total_size()); ++ ++ return status::success; +} + - } // namespace acl_utils + status_t insert_singleton_dimension(arm_compute::TensorInfo &ti, size_t dim_i) { - } // namespace aarch64 + // Max 6 dims in ACL, so we can't insert another diff --git a/src/cpu/aarch64/acl_utils.hpp b/src/cpu/aarch64/acl_utils.hpp -index 28693bb167..d9affe1c8f 100644 +index 28693bb16..c7c9e1278 100644 --- a/src/cpu/aarch64/acl_utils.hpp +++ b/src/cpu/aarch64/acl_utils.hpp -@@ -1,5 +1,5 @@ - /******************************************************************************* --* Copyright 2021-2022 Arm Ltd. and affiliates -+* Copyright 2021-2023 Arm Ltd. and affiliates - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. -@@ -74,6 +74,28 @@ status_t insert_singleton_dimension(arm_compute::TensorInfo &ti, size_t dim_i); - int reorder_dimensions_by_stride(std::vector permuted_mds, - std::vector mds); +@@ -62,6 +62,9 @@ status_t tensor_info(arm_compute::TensorInfo &info, const memory_desc_t &md); + status_t tensor_info( + arm_compute::TensorInfo &info, const memory_desc_wrapper &md); -+// Reorder a memory_desc_t and set the strides on a arm_compute::TensorInfo to -+// match an arm_compute::WeightFormat. You are required to specify how various -+// logical dimensions in oneDNN correspond to logical dimensions in arm_compute. -+// info TensorInfo where the strides will be changed to match the reordering -+// md memory descriptor where the stride and padded dimensions will be -+// changed or reordering -+// wf Describes the memory format/layout of the weights -+// I_dim The logical dimension of md corresponding to the input channel of -+// a convolution or the K dimension in a matmul -+// O_dim The logical dimension of md corresponding to the output channel of a -+//   convolution or the N dimension in a matmul -+// spatial_dims The logical dimensions of md corresponding to the spatial -+// dimensions of the weights (H, W, D for example). These will be -+// the next densest after the inner blocks and the input channel. -+// batch_dims The logical dimensions of md related to the batch in a batched -+// matmul, ordered from innermost to outermost. ACL calls these -+// the multi_stride_b. These will become the outermost (least dense) -+// dimensions and will be collapsed. -+void reorder_to_weight_format(arm_compute::TensorInfo &info, memory_desc_t &md, -+ arm_compute::WeightFormat wf, dim_t I_dim, dim_t O_dim, -+ std::vector spatial_dims, std::vector batch_dims = {}); ++// Update y and z strides in arm_compute::TensorInfo ++status_t update_strides_y_and_z(arm_compute::TensorInfo &info, const int y, const int z); + - // Logs a custom 'info' line describing an unsupported case - #define LOG_ACL_UNSUPPORTED(msg) \ - do { \ -diff --git a/src/cpu/aarch64/matmul/acl_matmul.cpp b/src/cpu/aarch64/matmul/acl_matmul.cpp -index dce220fb6e..ca1c7eb47e 100644 ---- a/src/cpu/aarch64/matmul/acl_matmul.cpp -+++ b/src/cpu/aarch64/matmul/acl_matmul.cpp -@@ -1,5 +1,5 @@ - /******************************************************************************* --* Copyright 2021-2022 Arm Ltd. and affiliates -+* Copyright 2021-2023 Arm Ltd. and affiliates - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. -@@ -31,36 +31,19 @@ status_t acl_matmul_t::execute_forward(const exec_ctx_t &ctx) const { - auto wei_base = CTX_IN_MEM(const data_t *, DNNL_ARG_WEIGHTS); - - bool is_transA = pd()->amp_.is_transA; -- bool is_transB = pd()->amp_.is_transB; - bool use_dst_acc = pd()->amp_.use_dst_acc; - - std::lock_guard _lock {this->mtx}; - auto *acl_resource = ctx.get_resource_mapper()->get(this); - acl_matmul_obj_t &acl_obj = acl_resource->get_acl_obj(); - // Run transpose kernel -- if (is_transA && !is_transB) { -+ if (is_transA) { - acl_obj.src_tensor.allocator()->allocate(); - acl_obj.src_acc_tensor.allocator()->import_memory( - const_cast(src_base)); - acl_obj.transA.run(); - acl_obj.wei_tensor.allocator()->import_memory( - const_cast(wei_base)); -- } else if (is_transB && !is_transA) { -- acl_obj.wei_tensor.allocator()->allocate(); -- acl_obj.wei_acc_tensor.allocator()->import_memory( -- const_cast(wei_base)); -- acl_obj.transB.run(); -- acl_obj.src_tensor.allocator()->import_memory( -- const_cast(src_base)); -- } else if (is_transA && is_transB) { -- acl_obj.src_tensor.allocator()->allocate(); -- acl_obj.src_acc_tensor.allocator()->import_memory( -- const_cast(src_base)); -- acl_obj.wei_tensor.allocator()->allocate(); -- acl_obj.wei_acc_tensor.allocator()->import_memory( -- const_cast(wei_base)); -- acl_obj.transA.run(); -- acl_obj.transB.run(); - } else { - acl_obj.src_tensor.allocator()->import_memory( - const_cast(src_base)); -@@ -69,7 +52,7 @@ status_t acl_matmul_t::execute_forward(const exec_ctx_t &ctx) const { - } - - if (use_dst_acc) { -- // Put the result in a new tensor, it will be accumalated to the dst -+ // Put the result in a new tensor, it will be accumulated to the dst - // during the post ops - acl_obj.dst_tensor.allocator()->allocate(); - } else { -@@ -82,7 +65,6 @@ status_t acl_matmul_t::execute_forward(const exec_ctx_t &ctx) const { - acl_obj.src_tensor.allocator()->free(); - acl_obj.wei_tensor.allocator()->free(); - if (is_transA) acl_obj.src_acc_tensor.allocator()->free(); -- if (is_transB) acl_obj.wei_acc_tensor.allocator()->free(); + // Insert a dimension of size 1 at the index dim_i of TensorInfo + status_t insert_singleton_dimension(arm_compute::TensorInfo &ti, size_t dim_i); - void *dst = acl_obj.dst_tensor.buffer(); - pd()->post_ops.execute(ctx, dst); -diff --git a/src/cpu/aarch64/matmul/acl_matmul.hpp b/src/cpu/aarch64/matmul/acl_matmul.hpp -index cdc942e995..832b1dbb68 100644 ---- a/src/cpu/aarch64/matmul/acl_matmul.hpp -+++ b/src/cpu/aarch64/matmul/acl_matmul.hpp -@@ -32,20 +32,15 @@ struct acl_resource_t : public resource_t { - - status_t configure(const acl_matmul_conf_t &) { - if (!acl_obj_) return status::out_of_memory; -- acl_obj_->src_tensor.allocator()->init(amp.src_info); -- acl_obj_->wei_tensor.allocator()->init(amp.wei_info); -- acl_obj_->dst_tensor.allocator()->init(amp.dst_info); -+ acl_obj_->src_tensor.allocator()->init(amp.src_tensor_info); -+ acl_obj_->wei_tensor.allocator()->init(amp.wei_tensor_info); -+ acl_obj_->dst_tensor.allocator()->init(amp.dst_tensor_info); - // Configure transpose kernel for src, wei or both - if (amp.is_transA) { - acl_obj_->src_acc_tensor.allocator()->init(amp.src_acc_info); - acl_obj_->transA.configure( - &acl_obj_->src_acc_tensor, &acl_obj_->src_tensor); - } -- if (amp.is_transB) { -- acl_obj_->wei_acc_tensor.allocator()->init(amp.wei_acc_info); -- acl_obj_->transB.configure( -- &acl_obj_->wei_acc_tensor, &acl_obj_->wei_tensor); -- } - // Configure GEMM - acl_obj_->gemm.configure(&acl_obj_->src_tensor, &acl_obj_->wei_tensor, - nullptr, &acl_obj_->dst_tensor, amp.alpha, 0.0f, amp.gemm_info); -@@ -72,12 +67,20 @@ struct acl_matmul_t : public primitive_t { - - status_t init(engine_t *engine) { - using smask_t = primitive_attr_t::skip_mask_t; -- bool ok = src_md()->data_type == data_type::f32 -- && weights_md()->data_type == data_type::f32 -- && desc()->accum_data_type == data_type::f32 -- && dst_md()->data_type == data_type::f32 -- && platform::has_data_type_support(data_type::f32) -+ const bool is_fp32_ok -+ = utils::everyone_is(data_type::f32, src_md()->data_type, -+ weights_md()->data_type, dst_md()->data_type, -+ desc()->accum_data_type) -+ && platform::has_data_type_support(data_type::f32); -+ const bool is_fp16_ok -+ = utils::everyone_is(data_type::f16, src_md()->data_type, -+ weights_md()->data_type, dst_md()->data_type) -+ && platform::has_data_type_support(data_type::f16); -+ bool ok = is_dense_data() -+ && utils::one_of(true, is_fp32_ok, is_fp16_ok) - && !has_zero_dim_memory() -+ && set_default_formats() - && attr()->has_default_values( - smask_t::oscale | smask_t::post_ops) - && attr_oscale_ok() && !has_runtime_dims_or_strides(); -@@ -92,9 +95,9 @@ struct acl_matmul_t : public primitive_t { - amp_.use_dst_acc = post_ops.has_sum(); - - // Validate ACL GEMM -- ACL_CHECK_VALID(arm_compute::NEGEMM::validate(&_.src_info, -- &_.wei_info, nullptr, &_.dst_info, amp_.alpha, 0.0f, -- amp_.gemm_info)); -+ ACL_CHECK_VALID(arm_compute::NEGEMM::validate(&_.src_tensor_info, -+ &_.wei_tensor_info, nullptr, &_.dst_tensor_info, -+ amp_.alpha, 0.0f, amp_.gemm_info)); - - return status::success; - } diff --git a/src/cpu/aarch64/matmul/acl_matmul_utils.cpp b/src/cpu/aarch64/matmul/acl_matmul_utils.cpp -index 679baec3a4..30bc2c1443 100644 +index 679baec3a..853277e37 100644 --- a/src/cpu/aarch64/matmul/acl_matmul_utils.cpp +++ b/src/cpu/aarch64/matmul/acl_matmul_utils.cpp -@@ -41,6 +41,7 @@ status_t init_conf_matmul(acl_matmul_conf_t &, memory_desc_t &src_md, - const dim_t src_batch = helper.src_batch(); - const dim_t wei_batch = helper.wei_batch(); - -+ // We can only broadcast on one of src or wei at once - // ACL supports broadcast for 3D shapes, and 4D shapes - // for e.g when ab in abcd is 1x1 - bool batch_ok = IMPLICATION(src_batch > 1, wei_batch == 1) -@@ -53,44 +54,33 @@ status_t init_conf_matmul(acl_matmul_conf_t &, memory_desc_t &src_md, - bool with_bias = md.bias_desc.format_kind != format_kind::undef; - ACL_CHECK_SUPPORT(with_bias, "ACL does not support bias for matmul"); +@@ -66,15 +66,12 @@ status_t init_conf_matmul(acl_matmul_conf_t &, memory_desc_t &src_md, -+ // The two innermost dimensions can be transposed, but the batch dimensions -+ // must be the outermost - using namespace format_tag; - auto src_tag = memory_desc_matches_one_of_tag( - src_md, abcd, abdc, abc, acb, ab, ba); -- auto wei_tag = memory_desc_matches_one_of_tag( -- wei_md, abcd, abdc, abc, acb, ab, ba); -- auto dst_tag -- = memory_desc_matches_one_of_tag(dst_md, abcd, abc, acb, ab, ba); -- ACL_CHECK_SUPPORT( -- utils::one_of(format_tag::undef, src_tag, wei_tag, dst_tag), -+ auto dst_tag = memory_desc_matches_one_of_tag(dst_md, abcd, abc, ab, ba); -+ ACL_CHECK_SUPPORT(utils::one_of(format_tag::undef, src_tag, dst_tag), - "Format tag is undefined"); - -- // Transpose A (src) or B (wei) -+ // Transpose A (src) + // Transpose A (src) or B (wei) amp.is_transA = helper.transA() == 'T'; - amp.is_transB = helper.transB() == 'T'; -+ -+ auto acl_src_data_t = acl_utils::get_acl_data_t(src_md.data_type); -+ auto acl_wei_data_t = acl_utils::get_acl_data_t(wei_md.data_type); -+ auto acl_dst_data_t = acl_utils::get_acl_data_t(dst_md.data_type); ++ amp.is_transB = false; + if (amp.is_transA) amp.src_acc_info = arm_compute::TensorInfo( arm_compute::TensorShape(M, K, 1, src_batch), 1, -- arm_compute::DataType::F32); + arm_compute::DataType::F32); - if (amp.is_transB) - amp.wei_acc_info = arm_compute::TensorInfo( - arm_compute::TensorShape(K, N, wei_batch), 1, - arm_compute::DataType::F32); -- -- amp.src_info = arm_compute::TensorInfo( -- arm_compute::TensorShape(K, M, 1, src_batch), 1, -- arm_compute::DataType::F32); -- amp.wei_info -- = arm_compute::TensorInfo(arm_compute::TensorShape(N, K, wei_batch), -- 1, arm_compute::DataType::F32); -- amp.dst_info = arm_compute::TensorInfo( -- arm_compute::TensorShape(N, M, 1, dst_batch), 1, -- arm_compute::DataType::F32); -- -- // Fast-math mode -- auto math_mode = get_fpmath_mode(); -- bool is_fastmath_enabled -- = utils::one_of(math_mode, fpmath_mode::bf16, fpmath_mode::any); -- amp.gemm_info.set_fast_math(is_fastmath_enabled); -+ acl_src_data_t); -+ -+ amp.src_tensor_info = arm_compute::TensorInfo( -+ arm_compute::TensorShape(K, M, 1, src_batch), 1, acl_src_data_t); -+ amp.wei_tensor_info = arm_compute::TensorInfo( -+ arm_compute::TensorShape(N, K, wei_batch), 1, acl_wei_data_t); -+ amp.dst_tensor_info = arm_compute::TensorInfo( -+ arm_compute::TensorShape(N, M, 1, dst_batch), 1, acl_dst_data_t); - // Set alpha (output scaling) - amp.alpha = attr.output_scales_.scales_[0]; -@@ -98,10 +88,45 @@ status_t init_conf_matmul(acl_matmul_conf_t &, memory_desc_t &src_md, - // Validate ACL transpose - if (amp.is_transA) + amp.src_info = arm_compute::TensorInfo( + arm_compute::TensorShape(K, M, 1, src_batch), 1, +@@ -103,6 +100,140 @@ status_t init_conf_matmul(acl_matmul_conf_t &, memory_desc_t &src_md, ACL_CHECK_VALID(arm_compute::NETranspose::validate( -- &.src_acc_info, &.src_info)); -- if (amp.is_transB) -- ACL_CHECK_VALID(arm_compute::NETranspose::validate( -- &.wei_acc_info, &.wei_info)); -+ &.src_acc_info, &.src_tensor_info)); -+ -+ bool is_fastmath_enabled = utils::one_of( -+ attr.fpmath_mode_, fpmath_mode::bf16, fpmath_mode::any); -+ amp.gemm_info.set_fast_math(is_fastmath_enabled); + &.wei_acc_info, &.wei_info)); + ++ arm_compute::WeightFormat expected_weight_format; + + amp.gemm_info.set_fixed_format(true); -+ -+ // WeightFormat::ANY tells ACL we can handle any format + amp.gemm_info.set_weight_format(arm_compute::WeightFormat::ANY); + -+ // Get the format that the ACL kernel will expect the weights to be -+ // in (if a kernel exists). Note that these are referred to as fixed format -+ // kernels, because they require one specific weights format -+ arm_compute::WeightFormat expected_weight_format; -+ ACL_CHECK_VALID(arm_compute::NEGEMM::has_opt_impl(expected_weight_format, -+ &.src_tensor_info, &.wei_tensor_info, nullptr, -+ &.dst_tensor_info, amp.alpha, 0.0f, amp.gemm_info)); ++ auto acl_st = arm_compute::NEGEMM::has_opt_impl( ++ expected_weight_format, ++ &.src_info, ++ &.wei_info, ++ nullptr, ++ &.dst_info, ++ amp.alpha, ++ 0.0f, ++ amp.gemm_info); ++ ++ if(acl_st.error_code() != arm_compute::ErrorCode::OK) { ++ return status::unimplemented; ++ } + -+ // Set gemm weights info to the one returned by has_opt_impl + amp.gemm_info.set_weight_format(expected_weight_format); + -+ // has_opt_impl may return a non fast math kernel, even if we requested one -+ amp.gemm_info.set_fast_math( -+ arm_compute::is_fixed_format_fast_math(expected_weight_format)); ++ memory_desc_t want_wei_md = wei_md; + -+ // Logical dimension indices -+ dim_t innermost_dim = wei_md.ndims - 1; -+ dim_t N_dim = innermost_dim; -+ dim_t K_dim = innermost_dim - 1; ++ // We need to transpose second to last dimension and use blocking ++ // as returned by interleave by from expecting strides ++ int interleaved_by = arm_compute::interleave_by(expected_weight_format); ++ int block_by = arm_compute::block_by(expected_weight_format); ++ bool is_fast_math_kernel = arm_compute::is_fixed_format_fast_math(expected_weight_format); ++ if(!is_fast_math_kernel) { ++ amp.gemm_info.set_fast_math(false); ++ } + -+ // The logical indices of dimensions related to the batch, ordered from -+ // innermost to outermost -+ std::vector batch_dims = {}; -+ for (dim_t i = K_dim - 1; i >= 0; --i) -+ batch_dims.push_back(i); ++ int blocked_first_dimension = -1; ++ int blocked_second_dimension = -1; ++ ++ // Assume that interleaved by is X and blocked by is Y ++ switch(want_wei_md.ndims) { ++ case 2: { ++ // For 2D case the format that we need to pass is BaXb and ++ // when doing fast mode BAXbYa ++ want_wei_md.format_desc.blocking.strides[0] = interleaved_by * block_by; ++ // check to see whether we need to pad ++ if(want_wei_md.dims[0] % block_by != 0) { ++ want_wei_md.padded_dims[0] = utils::div_up(want_wei_md.dims[0], block_by) * block_by; ++ } ++ want_wei_md.format_desc.blocking.strides[1] = interleaved_by * want_wei_md.padded_dims[0]; ++ if(want_wei_md.dims[1] % interleaved_by != 0) { ++ want_wei_md.padded_dims[1] = utils::div_up(want_wei_md.dims[1], interleaved_by) * interleaved_by; ++ } ++ ++ acl_utils::update_strides_y_and_z( ++ amp.wei_info, ++ want_wei_md.format_desc.blocking.strides[1] * wei_d.data_type_size(), ++ want_wei_md.format_desc.blocking.strides[0] * wei_d.data_type_size()); ++ ++ blocked_first_dimension = 1; ++ blocked_second_dimension = 0; ++ ++ break; ++ } ++ ++ case 3: { ++ // For 3D case the format we need to pass is aCbXc and ++ // when doing fast mode is aCBXcYb ++ want_wei_md.format_desc.blocking.strides[1] = interleaved_by*block_by; ++ if(want_wei_md.dims[1] % block_by != 0) { ++ want_wei_md.padded_dims[1] = utils::div_up(want_wei_md.dims[1], block_by) * block_by; ++ } ++ want_wei_md.format_desc.blocking.strides[2] = interleaved_by * want_wei_md.padded_dims[1]; ++ if(want_wei_md.dims[2] % interleaved_by != 0) { ++ want_wei_md.padded_dims[2] = utils::div_up(want_wei_md.dims[2], interleaved_by) * interleaved_by; ++ } ++ want_wei_md.format_desc.blocking.strides[0] = want_wei_md.padded_dims[2] * want_wei_md.padded_dims[1]; ++ ++ acl_utils::update_strides_y_and_z( ++ amp.wei_info, ++ want_wei_md.format_desc.blocking.strides[2] * wei_d.data_type_size(), ++ want_wei_md.format_desc.blocking.strides[0] * wei_d.data_type_size()); ++ ++ blocked_first_dimension = 2; ++ blocked_second_dimension = 1; ++ ++ break; ++ } ++ ++ case 4: { ++ // For 4D case the format we need to pass is abDcXd and ++ // when doing fast mode is abDCxdYc ++ int D_padded = want_wei_md.dims[3]; ++ if(D_padded % interleaved_by != 0) { ++ D_padded = utils::div_up(D_padded, interleaved_by) * interleaved_by; ++ want_wei_md.padded_dims[3] = D_padded; ++ } ++ ++ int C_padded = want_wei_md.dims[2]; ++ if(C_padded % block_by != 0) { ++ C_padded = utils::div_up(C_padded, block_by) * block_by; ++ want_wei_md.padded_dims[2] = C_padded; ++ } ++ ++ want_wei_md.format_desc.blocking.strides[0] = want_wei_md.dims[1]*D_padded*C_padded; ++ want_wei_md.format_desc.blocking.strides[1] = D_padded*C_padded; ++ want_wei_md.format_desc.blocking.strides[2] = interleaved_by*block_by; ++ want_wei_md.format_desc.blocking.strides[3] = interleaved_by*C_padded; ++ ++ acl_utils::update_strides_y_and_z( ++ amp.wei_info, ++ want_wei_md.format_desc.blocking.strides[3] * wei_d.data_type_size(), ++ want_wei_md.format_desc.blocking.strides[1] * wei_d.data_type_size()); ++ ++ blocked_first_dimension = 3; ++ blocked_second_dimension = 2; ++ ++ break; ++ } ++ ++ default: ++ return status::unimplemented; ++ } ++ ++ want_wei_md.format_desc.blocking.inner_nblks = (block_by > 1) + 1; ++ want_wei_md.format_desc.blocking.inner_idxs[0] = blocked_first_dimension; ++ want_wei_md.format_desc.blocking.inner_blks[0] = interleaved_by; ++ if(block_by > 1) { ++ want_wei_md.format_desc.blocking.inner_idxs[1] = blocked_second_dimension; ++ want_wei_md.format_desc.blocking.inner_blks[1] = block_by; ++ } ++ ++ if(is_fast_math_kernel) { ++ want_wei_md.data_type = dnnl_bf16; ++ } ++ ++ wei_md = want_wei_md; + -+ acl_utils::reorder_to_weight_format(amp.wei_tensor_info, wei_md, -+ expected_weight_format, K_dim, N_dim, {}, batch_dims); - return status::success; } -diff --git a/src/cpu/aarch64/matmul/acl_matmul_utils.hpp b/src/cpu/aarch64/matmul/acl_matmul_utils.hpp -index 0a5ee6a987..67bb2e78eb 100644 ---- a/src/cpu/aarch64/matmul/acl_matmul_utils.hpp -+++ b/src/cpu/aarch64/matmul/acl_matmul_utils.hpp -@@ -1,5 +1,5 @@ - /******************************************************************************* --* Copyright 2021-2022 Arm Ltd. and affiliates -+* Copyright 2021-2023 Arm Ltd. and affiliates - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. -@@ -29,25 +29,21 @@ namespace aarch64 { - struct acl_matmul_obj_t { - arm_compute::NEGEMM gemm; - arm_compute::NETranspose transA; -- arm_compute::NETranspose transB; - arm_compute::Tensor src_tensor; - arm_compute::Tensor src_acc_tensor; - arm_compute::Tensor wei_tensor; -- arm_compute::Tensor wei_acc_tensor; - arm_compute::Tensor dst_tensor; - }; - struct acl_matmul_conf_t { - bool is_transA; -- bool is_transB; - // If this is true, the result of the matmul goes into a temporarily - // allocated ACL tensor to be accumulated into the oneDNN dst during postops - bool use_dst_acc; -- arm_compute::TensorInfo src_info; -+ arm_compute::TensorInfo src_tensor_info; - arm_compute::TensorInfo src_acc_info; -- arm_compute::TensorInfo wei_info; -- arm_compute::TensorInfo wei_acc_info; -- arm_compute::TensorInfo dst_info; -+ arm_compute::TensorInfo wei_tensor_info; -+ arm_compute::TensorInfo dst_tensor_info; - arm_compute::GEMMInfo gemm_info; - float alpha; - }; diff --git a/third_party/mkl_dnn/onednn_acl_remove_winograd.patch b/third_party/mkl_dnn/onednn_acl_remove_winograd.patch deleted file mode 100644 index 18abcc8f54e922..00000000000000 --- a/third_party/mkl_dnn/onednn_acl_remove_winograd.patch +++ /dev/null @@ -1,326 +0,0 @@ - ******************************************************************************* - Copyright 2023 Arm Limited and affiliates. - SPDX-License-Identifier: Apache-2.0 - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ******************************************************************************* -diff --git a/src/cpu/aarch64/acl_convolution_utils.cpp b/src/cpu/aarch64/acl_convolution_utils.cpp -index c46d697575..37f8ecbc06 100644 ---- a/src/cpu/aarch64/acl_convolution_utils.cpp -+++ b/src/cpu/aarch64/acl_convolution_utils.cpp -@@ -271,54 +271,6 @@ status_t init_conf_indirect_gemm(acl_conv_conf_t &acp, memory_desc_t &src_md, - return status::success; - } - --status_t init_conf_wino(acl_conv_conf_t &acp, memory_desc_t &src_md, -- memory_desc_t &weights_md, memory_desc_t &dst_md, -- memory_desc_t &bias_md, const convolution_desc_t &cd, -- const primitive_attr_t &attr) { -- -- // Under these conditions, fallback to faster GEMM-based convolution -- // unless the user explicitly specifies Winograd algorithm -- // clang-format off -- if (one_of(true, src_md.dims[2] > 112, // ih -- src_md.dims[3] > 112, // iw -- src_md.dims[1] < 64, // ic -- dst_md.dims[1] < 64, // oc -- dnnl_get_max_threads() > 28) -- && cd.alg_kind == alg_kind::convolution_auto) { -- return status::unimplemented; -- } -- // clang-format on -- -- // General Compute Library checks, memory tags are also set there -- CHECK(acl_init_conf(acp, src_md, weights_md, dst_md, bias_md, cd, attr)); -- -- const bool shape_ok -- // only unit strides allowed -- = (acp.padstride_info.stride() == std::pair {1, 1}) -- // Note: Compute Library supports arbitrary padding for wino kernels -- // but we only allow small padding to be consistent with oneDNN -- && (acp.padstride_info.pad().first <= 1) // padding left/right -- && (acp.padstride_info.pad().second <= 1) // padding top/bottom -- // only non-dilated convolutions allowed -- && (acp.dilation_info == arm_compute::Size2D(1, 1)); -- -- ACL_CHECK_SUPPORT(!shape_ok, "shape not supported by winograd kernels"); -- -- // clang-format off -- // Validate convolution manually to check for return status -- ACL_CHECK_VALID(arm_compute::NEWinogradConvolutionLayer::validate( -- &acp.src_info, -- &acp.wei_info, -- acp.with_bias ? &acp.bia_info : nullptr, -- &acp.dst_info, -- acp.padstride_info, -- acp.act_info, -- true)); // enable_fast_math flag in ACL Winograd -- // clang-format on -- -- return status::success; --} -- - } // namespace acl_convolution_utils - - } // namespace aarch64 -diff --git a/src/cpu/aarch64/acl_convolution_utils.hpp b/src/cpu/aarch64/acl_convolution_utils.hpp -index 3e56245faf..0398ab06b9 100644 ---- a/src/cpu/aarch64/acl_convolution_utils.hpp -+++ b/src/cpu/aarch64/acl_convolution_utils.hpp -@@ -66,11 +66,6 @@ status_t init_conf_indirect_gemm(acl_conv_conf_t &acp, memory_desc_t &src_md, - memory_desc_t &bias_md, const convolution_desc_t &cd, - const primitive_attr_t &attr); - --status_t init_conf_wino(acl_conv_conf_t &acp, memory_desc_t &src_md, -- memory_desc_t &weights_md, memory_desc_t &dst_md, -- memory_desc_t &bias_md, const convolution_desc_t &cd, -- const primitive_attr_t &attr); -- - } // namespace acl_convolution_utils - - template _lock {this->mtx}; -- // Retrieve primitive resource and configured Compute Library objects -- auto *acl_resource -- = ctx.get_resource_mapper()->get(this); -- acl_obj_t &acl_wino_obj -- = acl_resource->get_acl_obj(); -- -- return execute_forward_conv_acl< -- acl_obj_t, pd_t, data_t>( -- ctx, acl_wino_obj, pd()); --} -- --} // namespace aarch64 --} // namespace cpu --} // namespace impl --} // namespace dnnl -diff --git a/src/cpu/aarch64/acl_winograd_convolution.hpp b/src/cpu/aarch64/acl_winograd_convolution.hpp -deleted file mode 100644 -index 215635fe3f..0000000000 ---- a/src/cpu/aarch64/acl_winograd_convolution.hpp -+++ /dev/null -@@ -1,146 +0,0 @@ --/******************************************************************************* --* Copyright 2020-2022 Arm Ltd. and affiliates --* --* Licensed under the Apache License, Version 2.0 (the "License"); --* you may not use this file except in compliance with the License. --* You may obtain a copy of the License at --* --* http://www.apache.org/licenses/LICENSE-2.0 --* --* Unless required by applicable law or agreed to in writing, software --* distributed under the License is distributed on an "AS IS" BASIS, --* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. --* See the License for the specific language governing permissions and --* limitations under the License. --*******************************************************************************/ -- --#ifndef CPU_AARCH64_ACL_WINOGRAD_CONVOLUTION_HPP --#define CPU_AARCH64_ACL_WINOGRAD_CONVOLUTION_HPP -- --#include "cpu/cpu_convolution_pd.hpp" -- --#include "cpu/aarch64/acl_convolution_utils.hpp" -- --namespace dnnl { --namespace impl { --namespace cpu { --namespace aarch64 { -- --struct acl_wino_resource_t : public resource_t { -- acl_wino_resource_t() -- : acl_wino_obj_(utils::make_unique< -- acl_obj_t>()) {} -- -- status_t configure(const acl_conv_conf_t &acp) { -- if (!acl_wino_obj_) return status::out_of_memory; -- -- // Init Compute Library tensors based on info from descriptor -- acl_wino_obj_->src_tensor.allocator()->init(acp.src_info); -- acl_wino_obj_->wei_tensor.allocator()->init(acp.wei_info); -- acl_wino_obj_->dst_tensor.allocator()->init(acp.dst_info); -- acl_wino_obj_->bia_tensor.allocator()->init(acp.bia_info); -- -- // clang-format off -- acl_wino_obj_->conv.configure( -- &acl_wino_obj_->src_tensor, -- &acl_wino_obj_->wei_tensor, -- acp.with_bias ? &acl_wino_obj_->bia_tensor : nullptr, -- &acl_wino_obj_->dst_tensor, -- acp.padstride_info, -- acp.act_info, -- true); // to support 5x5, 7x7 filter shapes in addition to 3x3 -- // clang-format on -- -- return status::success; -- } -- -- acl_obj_t &get_acl_obj() const { -- return *acl_wino_obj_; -- } -- -- DNNL_DISALLOW_COPY_AND_ASSIGN(acl_wino_resource_t); -- --private: -- std::unique_ptr> -- acl_wino_obj_; --}; // acl_wino_resource_t -- --struct acl_wino_convolution_fwd_t : public primitive_t { -- struct pd_t : public cpu_convolution_fwd_pd_t { -- pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, -- const typename pd_t::base_class *hint_fwd_pd) -- : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd) -- , acp_() -- , post_ops() {} -- -- DECLARE_COMMON_PD_T( -- "wino:acl", acl_wino_convolution_fwd_t, USE_GLOBAL_SCRATCHPAD); -- -- status_t init(engine_t *engine) { -- bool ok = is_fwd() -- && utils::one_of(desc()->alg_kind, -- alg_kind::convolution_auto, -- alg_kind::convolution_winograd) -- && expect_data_types(data_type::f32, data_type::f32, -- data_type::f32, data_type::f32, data_type::f32) -- && attr()->has_default_values( -- primitive_attr_t::skip_mask_t::post_ops, -- data_type::f32) -- && !has_zero_dim_memory(); -- if (!ok) return status::unimplemented; -- -- CHECK(acl_convolution_utils::init_conf_wino(acp_, src_md_, -- weights_md_, dst_md_, bias_md_, *desc(), *attr())); -- -- set_default_alg_kind(alg_kind::convolution_winograd); -- -- CHECK(post_ops.init( -- engine, attr_.post_ops_, dst_md_, acp_.act_info)); -- acp_.use_dst_acc = post_ops.has_sum(); -- -- return status::success; -- } -- -- acl_conv_conf_t acp_; -- acl_post_ops_t post_ops; -- }; -- -- acl_wino_convolution_fwd_t(const pd_t *apd) : primitive_t(apd) {} -- -- status_t create_resource( -- engine_t *engine, resource_mapper_t &mapper) const override { -- if (mapper.has_resource(this)) return status::success; -- -- auto r = utils::make_unique(); -- if (!r) return status::out_of_memory; -- -- // Configure the resource based on information from primitive descriptor -- CHECK(r->configure(pd()->acp_)); -- mapper.add(this, std::move(r)); -- -- CHECK(pd()->post_ops.create_resource(engine, mapper)); -- -- return status::success; -- } -- -- ~acl_wino_convolution_fwd_t() {} -- -- typedef typename prec_traits::type data_t; -- -- status_t execute(const exec_ctx_t &ctx) const override { -- return execute_forward(ctx); -- } -- --private: -- // To guard the const execute_forward(), the mutex must be 'mutable' -- mutable std::mutex mtx; -- status_t execute_forward(const exec_ctx_t &ctx) const; -- const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } --}; // acl_wino_convolution_fwd_t -- --} // namespace aarch64 --} // namespace cpu --} // namespace impl --} // namespace dnnl -- --#endif // CPU_AARCH64_ACL_WINOGRAD_CONVOLUTION_HPP -diff --git a/src/cpu/cpu_convolution_list.cpp b/src/cpu/cpu_convolution_list.cpp -index 4142dbc7e7..094c73aa36 100644 ---- a/src/cpu/cpu_convolution_list.cpp -+++ b/src/cpu/cpu_convolution_list.cpp -@@ -65,7 +65,6 @@ using namespace dnnl::impl::cpu::x64; - #if DNNL_AARCH64 && DNNL_AARCH64_USE_ACL - #include "cpu/aarch64/acl_gemm_convolution.hpp" - #include "cpu/aarch64/acl_indirect_gemm_convolution.hpp" --#include "cpu/aarch64/acl_winograd_convolution.hpp" - #endif - using namespace dnnl::impl::cpu::aarch64; - #endif -@@ -100,7 +99,6 @@ const std::map> &impl_list_map() - CPU_INSTANCE_SSE41(jit_sse41_1x1_convolution_fwd_t) - CPU_INSTANCE_AVX2(jit_avx2_convolution_fwd_t) - CPU_INSTANCE_SSE41(jit_sse41_convolution_fwd_t) -- CPU_INSTANCE_AARCH64_ACL(acl_wino_convolution_fwd_t) - CPU_INSTANCE_AARCH64(jit_sve_512_dw_convolution_fwd_t) - CPU_INSTANCE_AARCH64(jit_sve_512_1x1_convolution_fwd_f32_t) - CPU_INSTANCE_AARCH64(jit_sve_512_convolution_fwd_t) -diff --git a/tests/gtests/test_iface_wino_convolution.cpp b/tests/gtests/test_iface_wino_convolution.cpp -index 03861b1de4..2235ceae36 100644 ---- a/tests/gtests/test_iface_wino_convolution.cpp -+++ b/tests/gtests/test_iface_wino_convolution.cpp -@@ -59,9 +59,6 @@ protected: - input_f16.wino_supported = is_gpu; - input_int8.wino_supported = is_cpu && has_avx512_core; - input_f32.backward_supported = is_cpu && impl::dnnl_thr_syncable(); --#elif DNNL_AARCH64 && DNNL_AARCH64_USE_ACL -- const bool is_cpu = get_test_engine_kind() == engine::kind::cpu; -- input_f32.wino_supported = is_cpu; - #endif - - #else diff --git a/third_party/mkl_dnn/onednn_acl_reorder.patch b/third_party/mkl_dnn/onednn_acl_reorder.patch deleted file mode 100644 index 05ef1160e1469b..00000000000000 --- a/third_party/mkl_dnn/onednn_acl_reorder.patch +++ /dev/null @@ -1,352 +0,0 @@ -diff --git a/src/cpu/aarch64/acl_reorder.cpp b/src/cpu/aarch64/acl_reorder.cpp -new file mode 100644 -index 000000000..061751b55 ---- /dev/null -+++ b/src/cpu/aarch64/acl_reorder.cpp -@@ -0,0 +1,52 @@ -+/******************************************************************************* -+* Copyright 2023 Arm Ltd. and affiliates -+* -+* Licensed under the Apache License, Version 2.0 (the "License"); -+* you may not use this file except in compliance with the License. -+* You may obtain a copy of the License at -+* -+* http://www.apache.org/licenses/LICENSE-2.0 -+* -+* Unless required by applicable law or agreed to in writing, software -+* distributed under the License is distributed on an "AS IS" BASIS, -+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+* See the License for the specific language governing permissions and -+* limitations under the License. -+*******************************************************************************/ -+ -+#include "cpu/aarch64/acl_reorder.hpp" -+ -+namespace dnnl { -+namespace impl { -+namespace cpu { -+namespace aarch64 { -+ -+status_t acl_reorder_fwd_t::execute_forward(const exec_ctx_t &ctx) const { -+ // Lock here is needed because resource_mapper does not support -+ // concurrent multithreaded access. -+ std::lock_guard _lock {this->mtx}; -+ -+ auto src = CTX_IN_MEM(const void *, DNNL_ARG_FROM); -+ auto dst = CTX_OUT_MEM(void *, DNNL_ARG_TO); -+ -+ // Retrieve primitive resource and configured Compute Library objects -+ auto *acl_resource -+ = ctx.get_resource_mapper()->get(this); -+ -+ acl_reorder_obj_t &acl_obj = acl_resource->get_acl_obj(); -+ -+ acl_obj.src_tensor.allocator()->import_memory(const_cast(src)); -+ acl_obj.dst_tensor.allocator()->import_memory(dst); -+ -+ acl_obj.reorder.run(); -+ -+ acl_obj.src_tensor.allocator()->free(); -+ acl_obj.dst_tensor.allocator()->free(); -+ -+ return status::success; -+} -+ -+} // namespace aarch64 -+} // namespace cpu -+} // namespace impl -+} // namespace dnnl -diff --git a/src/cpu/aarch64/acl_reorder.hpp b/src/cpu/aarch64/acl_reorder.hpp -new file mode 100644 -index 000000000..91d23e06d ---- /dev/null -+++ b/src/cpu/aarch64/acl_reorder.hpp -@@ -0,0 +1,260 @@ -+/******************************************************************************* -+* Copyright 2023 Arm Ltd. and affiliates -+* -+* Licensed under the Apache License, Version 2.0 (the "License"); -+* you may not use this file except in compliance with the License. -+* You may obtain a copy of the License at -+* -+* http://www.apache.org/licenses/LICENSE-2.0 -+* -+* Unless required by applicable law or agreed to in writing, software -+* distributed under the License is distributed on an "AS IS" BASIS, -+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+* See the License for the specific language governing permissions and -+* limitations under the License. -+*******************************************************************************/ -+#ifndef CPU_AARCH64_ACL_REORDER_HPP -+#define CPU_AARCH64_ACL_REORDER_HPP -+ -+#include "cpu/aarch64/acl_utils.hpp" -+#include "cpu/reorder/cpu_reorder_pd.hpp" -+#include "arm_compute/core/Types.h" -+#include "common/utils.hpp" -+ -+namespace dnnl { -+namespace impl { -+namespace cpu { -+namespace aarch64 { -+ -+struct acl_reorder_obj_t { -+ arm_compute::NEReorderLayer reorder; -+ arm_compute::Tensor src_tensor; -+ arm_compute::Tensor dst_tensor; -+ arm_compute::WeightFormat src_wf; -+ arm_compute::WeightFormat dst_wf; -+}; -+ -+struct acl_reorder_conf_t { -+ arm_compute::TensorInfo src_info; -+ arm_compute::TensorInfo dst_info; -+ arm_compute::WeightFormat src_wf; -+ arm_compute::WeightFormat dst_wf; -+}; -+ -+struct acl_reorder_resource_t : public resource_t { -+ acl_reorder_resource_t() : acl_obj_(utils::make_unique()) {} -+ -+ status_t configure(const acl_reorder_conf_t &app) { -+ if (!acl_obj_) return status::out_of_memory; -+ -+ // Init Compute Library tensors based on info from descriptor -+ acl_obj_->src_tensor.allocator()->init(app.src_info); -+ acl_obj_->dst_tensor.allocator()->init(app.dst_info); -+ -+ // clang-format off -+ acl_obj_->reorder.configure( -+ &acl_obj_->src_tensor, -+ &acl_obj_->dst_tensor, -+ app.src_wf, -+ app.dst_wf -+ ); -+ // clang-format on -+ -+ return status::success; -+ } -+ -+ acl_reorder_obj_t &get_acl_obj() const { return *acl_obj_; } -+ DNNL_DISALLOW_COPY_AND_ASSIGN(acl_reorder_resource_t); -+ -+private: -+ std::unique_ptr acl_obj_; -+}; // acl_reorder_resource_t -+ -+struct acl_reorder_fwd_t : public primitive_t { -+ using primitive_t::primitive_t; -+ struct pd_t : public cpu_reorder_pd_t { -+ -+ using cpu_reorder_pd_t::cpu_reorder_pd_t; -+ -+ DECLARE_COMMON_PD_T("acl", acl_reorder_fwd_t); -+ -+ static status_t create(reorder_pd_t **reorder_pd, engine_t *engine, -+ const primitive_attr_t *attr, engine_t *src_engine, -+ const memory_desc_t *src_md, engine_t *dst_engine, -+ const memory_desc_t *dst_md) { -+ -+ using namespace acl_utils; -+ // using skip_mask_t = dnnl_primitive_attr::skip_mask_t; -+ -+ bool ok = src_md->data_type -+ == dst_md->data_type // ACL only supports matching src/dst data types -+ && utils::one_of(src_md->data_type, -+ data_type::f32) // Only supports f32 for now -+ && attr->has_default_values(); -+ if (!ok) return status::unimplemented; -+ -+ int mask = -1; -+ bool is_set = false; -+ // CHECK(attr->scales_.get(DNNL_ARG_DST, &mask, &is_set)); -+ const memory_desc_wrapper input_d(src_md); -+ if (input_d.has_runtime_dims_or_strides() && is_set && mask > 0) -+ return status::unimplemented; -+ -+ // Create and check primitive descriptor -+ auto _pd = new pd_t(attr, src_engine->kind(), src_md, -+ dst_engine->kind(), dst_md); -+ if (_pd == nullptr) return status::out_of_memory; -+ if (_pd->init(engine, src_engine, dst_engine) != status::success) { -+ delete _pd; -+ return status::unimplemented; -+ } -+ -+ const memory_desc_wrapper src_d(*src_md); -+ const memory_desc_wrapper dst_d(*dst_md); -+ -+ const int ndims = src_d.ndims(); -+ -+ auto src_tag = memory_desc_matches_one_of_tag( -+ *src_md, format_tag::ba, format_tag::cdba); -+ ACL_CHECK_SUPPORT( -+ utils::one_of(format_tag::undef, src_tag), -+ ""); -+ -+ arm_compute::TensorShape acl_tensor_shape_in; -+ arm_compute::TensorShape acl_tensor_shape_out; -+ // Need even amount of dims in dim 0 for ACL kernel (eg mulitple of 8 rows when blocking by 8) -+ int dim_0_rounded_up; -+ -+ // Switch for 2 or 4 dim tensors -+ switch(ndims) -+ { -+ // Currently for Ab4a and Ab8a -+ // No format_tag for these, have to deduce from stride -+ case 2: -+ { -+ if(dst_md->dims[0] == 1 || dst_md->dims[1] == 1){ -+ return status::unimplemented; -+ } -+ int dst_dim_1 = dst_md->dims[1]; -+ int dst_dim_0_stride = dst_md->format_desc.blocking.strides[0]; -+ int dst_dim_1_stride = dst_md->format_desc.blocking.strides[1]; -+ // Interleave of 4 or 8 that stride for dim 1 -+ if (dst_dim_1_stride != 4 && dst_dim_1_stride != 8){ -+ return status::unimplemented; -+ } -+ // Check to ensure it's a blocking transpose -+ if (dst_dim_1 * dst_dim_1_stride != dst_dim_0_stride){ -+ return status::unimplemented; -+ } -+ if(dst_dim_1_stride == 4){ -+ // Set Dest WeightFormat -+ _pd->app_.dst_wf = arm_compute::WeightFormat::OHWIo4; -+ dim_0_rounded_up -+ = utils::rnd_up(src_md->dims[0], 4); -+ } else { -+ // Set Dest WeightFormat -+ _pd->app_.dst_wf = arm_compute::WeightFormat::OHWIo8; -+ dim_0_rounded_up -+ = utils::rnd_up(src_md->dims[0], 8); -+ } -+ acl_tensor_shape_in = arm_compute::TensorShape(src_md->dims[1], src_md->dims[0]); -+ acl_tensor_shape_out = arm_compute::TensorShape(src_md->dims[1], dim_0_rounded_up); -+ -+ break; -+ } -+ // Currently for Acdb4a and Acdb8a -+ case 4: -+ { -+ -+ auto dst_tag = memory_desc_matches_one_of_tag( -+ *dst_md, format_tag::Acdb4a, format_tag::Acdb8a); -+ ACL_CHECK_SUPPORT( -+ utils::one_of(format_tag::undef, dst_tag), -+ ""); -+ if(dst_tag == format_tag::Acdb4a){ -+ // Set Dest WeightFormat -+ _pd->app_.dst_wf = arm_compute::WeightFormat::OHWIo4; -+ dim_0_rounded_up -+ = utils::rnd_up(src_md->dims[0], 4); -+ } -+ else{ -+ // Set Dest WeightFormat -+ _pd->app_.dst_wf = arm_compute::WeightFormat::OHWIo8; -+ dim_0_rounded_up -+ = utils::rnd_up(src_md->dims[0], 8); -+ } -+ // Currently only supporting AxBx1x1 cases -+ if(dst_md->dims[2] != 1 || dst_md->dims[3] != 1){ -+ return status::unimplemented; -+ } -+ -+ acl_tensor_shape_in = arm_compute::TensorShape(src_md->dims[3], src_md->dims[2], src_md->dims[1], src_md->dims[0]); -+ acl_tensor_shape_out = arm_compute::TensorShape(src_md->dims[3], src_md->dims[2], src_md->dims[1], dim_0_rounded_up); -+ break; -+ } -+ default: -+ return status::unimplemented; -+ } -+ -+ // Choose the data layout -+ // bool is_nspc = utils::one_of(src_tag, format_tag::nhwc); -+ const auto acl_layout = arm_compute::DataLayout::NCHW; -+ -+ // Set Source WeightFormat -+ _pd->app_.src_wf = arm_compute::WeightFormat::OHWI; -+ -+ // Create ACL tensor infos -+ const data_type_t data_type = src_d.data_type(); -+ const arm_compute::DataType acl_data_t -+ = acl_utils::get_acl_data_t(data_type); -+ _pd->app_.src_info = arm_compute::TensorInfo( -+ acl_tensor_shape_in, 1, acl_data_t, acl_layout); -+ _pd->app_.dst_info = arm_compute::TensorInfo( -+ acl_tensor_shape_out, 1, acl_data_t, acl_layout); -+ -+ // Init scratch memory, not used so 0 in this implementation -+ _pd->init_scratchpad_md(); -+ -+ return safe_ptr_assign(*reorder_pd, _pd); -+ } // create -+ -+ friend dnnl::impl::impl_list_item_t; -+ acl_reorder_conf_t app_; -+ -+ }; // pd_t -+ -+ acl_reorder_fwd_t(const pd_t *apd) : primitive_t(apd) {} -+ -+ status_t create_resource( -+ engine_t *engine, resource_mapper_t &mapper) const override { -+ if (mapper.has_resource(this)) return status::success; -+ -+ auto r = utils::make_unique(); -+ if (!r) return status::out_of_memory; -+ -+ // Configure the resource based on information from primitive descriptor -+ CHECK(r->configure(pd()->app_)); -+ -+ mapper.add(this, std::move(r)); -+ return status::success; -+ } -+ -+ status_t execute(const exec_ctx_t &ctx) const override { -+ return execute_forward(ctx); -+ } -+ -+private: -+ // To guard the const execute_forward, the mutex must be 'mutable' -+ mutable std::mutex mtx; -+ status_t execute_forward(const exec_ctx_t &ctx) const; -+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } -+ -+ -+}; // acl_reorder_fwd_t -+ -+} // namespace aarch64 -+} // namespace cpu -+} // namespace impl -+} // namespace dnnl -+ -+#endif // CPU_AARCH64_ACL_REORDER_HPP -diff --git a/src/cpu/reorder/cpu_reorder_regular_f32_f32.cpp b/src/cpu/reorder/cpu_reorder_regular_f32_f32.cpp -index bccd2f75f..5e5ea331b 100644 ---- a/src/cpu/reorder/cpu_reorder_regular_f32_f32.cpp -+++ b/src/cpu/reorder/cpu_reorder_regular_f32_f32.cpp -@@ -15,6 +15,7 @@ - *******************************************************************************/ - - #include "cpu/reorder/cpu_reorder.hpp" -+#include "cpu/aarch64/acl_reorder.hpp" - - namespace dnnl { - namespace impl { -@@ -27,6 +28,7 @@ const impl_list_map_t ®ular_f32_f32_impl_list_map() { - // f32 -> f32 - {{f32, f32, 0}, { - REG_FAST_DIRECT_COPY_F32_F32 -+ DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::acl_reorder_fwd_t)) - - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) -@@ -64,6 +66,7 @@ const impl_list_map_t ®ular_f32_f32_impl_list_map() { - nullptr, - }}, - {{f32, f32, 4}, { -+ DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::acl_reorder_fwd_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::wino_reorder_t)) - - CPU_REORDER_INSTANCE(rnn_weights_reorder_t) diff --git a/third_party/mkl_dnn/onednn_acl_reorder_padded.patch b/third_party/mkl_dnn/onednn_acl_reorder_padded.patch deleted file mode 100644 index f290f21ec87e9b..00000000000000 --- a/third_party/mkl_dnn/onednn_acl_reorder_padded.patch +++ /dev/null @@ -1,858 +0,0 @@ - ******************************************************************************* - Copyright 2022 Arm Limited and affiliates. - SPDX-License-Identifier: Apache-2.0 - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ******************************************************************************* - -diff --git a/src/cpu/aarch64/jit_uni_reorder.cpp b/src/cpu/aarch64/jit_uni_reorder.cpp -index 24d6220cf..a6cefaa20 100644 ---- a/src/cpu/aarch64/jit_uni_reorder.cpp -+++ b/src/cpu/aarch64/jit_uni_reorder.cpp -@@ -1,6 +1,7 @@ - /******************************************************************************* - * Copyright 2018-2021 Intel Corporation - * Copyright 2020-2021 FUJITSU LIMITED -+* Copyright 2022 Arm Ltd. and affiliates - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. -@@ -54,6 +55,35 @@ namespace aarch64 { - - namespace tr { - -+static bool prb_has_small_strides(const prb_t &prb) { -+ constexpr ptrdiff_t max_stride = (1LL << 31) - 1; -+ for (int d = 0; d < prb.ndims; ++d) { -+ const ptrdiff_t cms = max_stride / prb.nodes[d].n; -+ const bool small_strides = true -+ && prb.nodes[d].is < cms / (int)data_type_size(prb.itype) -+ && prb.nodes[d].os < cms / (int)data_type_size(prb.otype); -+ if (!small_strides) return false; -+ } -+ return true; -+} -+ -+static bool prb_tail_friendly(const prb_t &prb) { -+ /* find optimal ndims to makes it easier to -+ * identify the blk_chunk in the loop*/ -+ int ndims = prb.full_ndims - prb.ndims; -+ -+ int n = prb.nodes[0].is; -+ for (int d = 1; d < prb.ndims; ++d) { -+ if (d != prb.blk_chunk_idx) n *= prb.nodes[d].n; -+ } -+ if (prb.ip_tail > 0 -+ && ((ndims == 0 && n != 1) -+ || (ndims > 0 && prb.ndims > prb.blk_chunk_idx))) -+ return false; -+ -+ return true; -+} -+ - /** Minimal reasonable/desirable kernel size. - * The constant might be used to determine how a problem should be split - * between kernel and threading driver. */ -@@ -121,18 +151,10 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { - && utils::one_of(p.otype, f32, s32, data_type::s8, u8) - && utils::everyone_is(0, p.ioff, p.ooff) /* do we need this? */ - && utils::one_of(p.beta, 0.f, 1.f) /* anything else? */ -- && simple_impl_desc_init(p, nullptr); -+ && simple_impl_desc_init(p, nullptr) && prb_has_small_strides(p) -+ && prb_tail_friendly(p); - if (!ok) return false; - -- const ptrdiff_t max_stride = (1LL << 31) - 1; -- for (int d = 0; d < p.ndims; ++d) { -- const ptrdiff_t cms = max_stride / p.nodes[d].n; -- bool strides_ok = true -- && p.nodes[d].is < cms / (int)data_type_size(p.itype) -- && p.nodes[d].os < cms / (int)data_type_size(p.otype); -- if (!strides_ok) return false; -- } -- - return true; - } - -@@ -153,6 +175,13 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { - return (int)prb_.nodes[d].ss; - } - -+ int blk_cnt() { -+ assert(prb_.blk_chunk_idx < prb_.full_ndims); -+ return (int)prb_.nodes[prb_.blk_chunk_idx].n - 1; -+ } -+ int op_padding() { return prb_.op_tail ? prb_.iblock - prb_.op_tail : 0; } -+ int ip_padding() { return prb_.ip_tail ? prb_.oblock - prb_.ip_tail : 0; } -+ - void step(int off, int prev_i_off, int prev_o_off, int prev_s_off, - int &i_off, int &o_off, int &s_off, int step_size = 1) { - i_off = prev_i_off; -@@ -385,6 +414,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { - prb_.otype, u8, data_type::s8, s32, f32))) - && utils::everyone_is(8, n(0), n(1)) - && utils::everyone_is(1, os(0), is(1)) -+ && utils::everyone_is(0, prb_.ip_tail, prb_.op_tail) - && prb_.scale_type == scale_type_t::NONE && prb_.beta == 0.f; - } - -@@ -405,17 +435,14 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { - bool process_direct_copy(int len) { - using namespace data_type; - -- const int simd_w = cpu_isa_traits::vlen == 16 -- ? cpu_isa_traits::vlen / itype_sz /* use 128-bit VReg */ -- : cpu_isa_traits::vlen / itype_sz -- / 2; /* use lower half of 512-bit ZReg */ -- -+ const int simd_w = cpu_isa_traits::vlen / itype_sz; - bool can_do = true && mayiuse(isa) - && utils::everyone_is(1, os(0), is(0)) - && (false || prb_.itype == prb_.otype - || (prb_.itype == s32 && prb_.otype == f32) - || (prb_.itype == f32 && prb_.otype == s32)) - && len % simd_w == 0 && n(0) % len == 0 -+ && prb_.ip_tail % simd_w == 0 && prb_.op_tail % simd_w == 0 - && prb_.scale_type == scale_type_t::NONE && prb_.beta == 0.f; - if (!can_do) return false; - -@@ -511,7 +538,8 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { - } - - void process_unroll_generic_step(int reg_unroll, const int *i_off, -- const int *o_off, const int *s_off) { -+ const int *o_off, const int *s_off, const int *ip_padding, -+ const bool h_padded) { - using namespace data_type; - - auto cvt2ps -@@ -571,6 +599,8 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { - for (int ur = 1; ur < reg_unroll; ++ur) - if (o_off[ur] != o_off[ur - 1] + 1) can_store_xmm = false; - const int ur_step = can_store_xmm ? 4 : 1; -+ const int load_tail_step -+ = !can_load_xmm && can_store_xmm ? ur_step : load_step; - - const bool interim_f32 = false - || utils::one_of(f32, prb_.itype, prb_.otype) -@@ -579,55 +609,85 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { - const bool need_saturation - = (utils::one_of(prb_.otype, u8, data_type::s8, s32) - && interim_f32); -- -- if (!can_load_xmm && can_store_xmm) { -- assert(ur_step == 4); -- /* load with stride */ -- for (int ur = 0; ur < reg_unroll; ur += ur_step) { -- -+ if (h_padded) { -+ for (int ur = 0; ur < reg_unroll; ur += load_tail_step) { -+ if (itype_sz == 4) -+ movi(VReg4S(ur), 0); -+ else if (itype_sz == 2) -+ movi(VReg8H(ur), 0); -+ else -+ movi(VReg16B(ur), 0); - /* x_tmp_vec = X_TMP_0 - X_TMP_4 - Do not use X_TMP_? as the last arg. */ -- for (int r = 0; r < ur_step; ++r) { -- add_imm(x_tmp_vec[r], x_ptr_in_off, -- i_off[ur + r] * itype_sz, X_DEFAULT_ADDR); -+ for (int r = 0; r < load_tail_step; ++r) { -+ if (ip_padding[ur + r] == 0) { -+ add_imm(x_tmp_vec[r], x_ptr_in_off, -+ i_off[ur + r] * itype_sz, X_DEFAULT_ADDR); -+ } - } - -- for (int r = 0; r < ur_step; ++r) { -- if (itype_sz == 4) -- ld1(VReg4S(ur)[r], ptr(x_tmp_vec[r])); -- else if (itype_sz == 2) -- ld1(VReg8H(ur)[r], ptr(x_tmp_vec[r])); -- else -- ld1(VReg16B(ur)[r], ptr(x_tmp_vec[r])); -+ for (int r = 0; r < load_tail_step; ++r) { -+ if (ip_padding[ur + r] == 0) { -+ if (itype_sz == 4) -+ ld1(VReg4S(ur)[r], ptr(x_tmp_vec[r])); -+ else if (itype_sz == 2) -+ ld1(VReg8H(ur)[r], ptr(x_tmp_vec[r])); -+ else -+ ld1(VReg16B(ur)[r], ptr(x_tmp_vec[r])); -+ } - } - } - } else { -- int ur = 0; -- int tmp_ur = 0; -- while (ur < reg_unroll) { -- int count = 0; -+ if (!can_load_xmm && can_store_xmm) { -+ assert(ur_step == 4); -+ /* load with stride */ -+ for (int ur = 0; ur < reg_unroll; ur += ur_step) { - -- do { -- add_imm(x_tmp_vec[count++], x_ptr_in_off, -- i_off[ur] * itype_sz, X_DEFAULT_ADDR); -- ur += load_step; -- } while (ur < reg_unroll && count < x_tmp_vec_size); -+ /* x_tmp_vec = X_TMP_0 - X_TMP_4 -+ Do not use X_TMP_? as the last arg. */ -+ for (int r = 0; r < ur_step; ++r) { -+ add_imm(x_tmp_vec[r], x_ptr_in_off, -+ i_off[ur + r] * itype_sz, X_DEFAULT_ADDR); -+ } - -- for (int i = 0; i < count; i++) { -+ for (int r = 0; r < ur_step; ++r) { -+ if (itype_sz == 4) -+ ld1(VReg4S(ur)[r], ptr(x_tmp_vec[r])); -+ else if (itype_sz == 2) -+ ld1(VReg8H(ur)[r], ptr(x_tmp_vec[r])); -+ else -+ ld1(VReg16B(ur)[r], ptr(x_tmp_vec[r])); -+ } -+ } -+ } else { -+ int ur = 0; -+ int tmp_ur = 0; -+ while (ur < reg_unroll) { -+ int count = 0; -+ -+ do { -+ add_imm(x_tmp_vec[count++], x_ptr_in_off, -+ i_off[ur] * itype_sz, X_DEFAULT_ADDR); -+ ur += load_step; -+ } while (ur < reg_unroll && count < x_tmp_vec_size); -+ -+ for (int i = 0; i < count; i++) { - -- switch (load_step * itype_sz) { -- case 16: ldr(QReg(tmp_ur), ptr(x_tmp_vec[i])); break; -- case 8: ldr(DReg(tmp_ur), ptr(x_tmp_vec[i])); break; -- case 4: ldr(SReg(tmp_ur), ptr(x_tmp_vec[i])); break; -- case 2: ldr(HReg(tmp_ur), ptr(x_tmp_vec[i])); break; -- case 1: ldr(BReg(tmp_ur), ptr(x_tmp_vec[i])); break; -- default: assert(!"unreachable"); -+ switch (load_step * itype_sz) { -+ case 16: -+ ldr(QReg(tmp_ur), ptr(x_tmp_vec[i])); -+ break; -+ case 8: ldr(DReg(tmp_ur), ptr(x_tmp_vec[i])); break; -+ case 4: ldr(SReg(tmp_ur), ptr(x_tmp_vec[i])); break; -+ case 2: ldr(HReg(tmp_ur), ptr(x_tmp_vec[i])); break; -+ case 1: ldr(BReg(tmp_ur), ptr(x_tmp_vec[i])); break; -+ default: assert(!"unreachable"); -+ } -+ tmp_ur += load_step; - } -- tmp_ur += load_step; - } - } - } -- - /* xmm[:] <-- (f32)xmm[:] */ - if (interim_f32) { - const int cvt_step = nstl::max(load_step, ur_step); -@@ -708,7 +768,8 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { - if (s_off[r] != s_off[r - 1] + 0) - scale_load_type = scale_load_type_t::load; - -- if (scale_load_type == scale_load_type_t::bcast) { -+ if (scale_load_type == scale_load_type_t::bcast -+ && !h_padded) { - VReg4S v(xmm_scale.getIdx()); - VReg4S v_dst(ur); - add_imm(X_TMP_0, x_ptr_scale_off, s_off[ur] * stype_sz, -@@ -724,7 +785,8 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { - if (s_off[r] != s_off[r - 1] + 1) - scale_load_type = scale_load_type_t::gather; - -- if (scale_load_type == scale_load_type_t::load) { -+ if (scale_load_type == scale_load_type_t::load -+ && !h_padded) { - uint32_t idx = xmm_scale.getIdx(); - VReg4S v_dst(ur); - add_imm(X_TMP_0, x_ptr_scale_off, s_off[ur] * stype_sz, -@@ -739,14 +801,18 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { - // so gather the scale factors one by one - /*ur_step is 1 or 4. */ - for (int r = ur; r < ur + ur_step; ++r) { -- /* x_tmp_vec = X_TMP_0 - X_TMP_4 -+ if (ip_padding[r] == 0 || !h_padded) { -+ /* x_tmp_vec = X_TMP_0 - X_TMP_4 - Do not use X_TMP_? as the last arg. */ -- add_imm(x_tmp_vec[r - ur], x_ptr_scale_off, -- s_off[r] * stype_sz, X_DEFAULT_ADDR); -+ add_imm(x_tmp_vec[r - ur], x_ptr_scale_off, -+ s_off[r] * stype_sz, X_DEFAULT_ADDR); -+ } - } - for (int r = ur; r < ur + ur_step; ++r) { -- VReg4S v(xmm_scale.getIdx()); -- ld1(v[r - ur], ptr(x_tmp_vec[r - ur])); -+ if (ip_padding[r] == 0 || !h_padded) { -+ VReg4S v(xmm_scale.getIdx()); -+ ld1(v[r - ur], ptr(x_tmp_vec[r - ur])); -+ } - } - fmul(VReg4S(ur), VReg4S(ur), xmm_scale); - } -@@ -925,7 +991,15 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { - } - } - -- void process_unroll_generic(int len) { -+ void comp_padding_flag(int ndims, int off, int len, int &i_tail) { -+ const int ip_without_padding -+ = ndims == 0 ? len - ip_padding() : prb_.ip_tail; -+ if ((ndims == 0 && off >= ip_without_padding) -+ || (ndims > 0 && (off % prb_.oblock) >= ip_without_padding)) -+ i_tail = 1; -+ } -+ -+ void process_unroll_generic(const int ndims, int len, const bool h_padded) { - const int blk = 8; - - int i_off[2 * blk] = {0}; -@@ -936,22 +1010,37 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { - - for (int off = 0; off < len; off += blk) { - const int reg_unroll = nstl::min(off + blk, len) - off; -+ int ip_padding[blk] = {0}; - -- /* compute offsets */ -+ /* compute offsets and tail*/ - for (int ur = off != 0 ? 0 : 1; ur < reg_unroll; ++ur) { - const int ur_c = curr * blk + ur; - const int ur_p = (ur_c - 1 + 2 * blk) % (2 * blk); // prev ur - step(off + ur, i_off[ur_p], o_off[ur_p], s_off[ur_p], - i_off[ur_c], o_off[ur_c], s_off[ur_c]); -+ if (h_padded) -+ comp_padding_flag(ndims, off + ur, len, ip_padding[ur]); - } -- - process_unroll_generic_step(reg_unroll, i_off + curr * blk, -- o_off + curr * blk, s_off + curr * blk); -+ o_off + curr * blk, s_off + curr * blk, ip_padding, -+ h_padded); - - curr = 1 - curr; - } - } - -+ void compute_ker( -+ const int ndims, const int len_unroll, const bool h_padded) { -+ bool optimized = false; -+ optimized = optimized -+ || (process_direct_copy(len_unroll) && !h_padded); -+ optimized = optimized -+ || (process_direct_copy(len_unroll) && !h_padded); -+ optimized -+ = optimized || (process_unroll_tr8x8(len_unroll) && !h_padded); -+ if (!optimized) process_unroll_generic(ndims, len_unroll, h_padded); -+ } -+ - void loop_begin(Label &l, XReg reg_cnt, int len) { - mov(reg_cnt, len); - L(l); -@@ -985,6 +1074,28 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { - } - } - -+ void compute_blk_ker(const int len_unroll) { -+ int omp_ndims = prb_.full_ndims - prb_.ndims; -+ Label no_last_blk, end_label; -+ -+ if (prb_.ip_tail > 0 && prb_.op_tail == 0) { -+ if (omp_ndims == 0) { -+ cmp(reg_last_loop_cnt, 1); -+ bne(no_last_blk); -+ compute_ker(omp_ndims, len_unroll, true); -+ } else { -+ cmp(reg_blk_chunks, blk_cnt()); -+ bne(no_last_blk); -+ compute_ker(omp_ndims, len_unroll, true); -+ } -+ b(end_label); -+ } -+ -+ L(no_last_blk); -+ compute_ker(omp_ndims, len_unroll, false); -+ L(end_label); -+ } -+ - bool simple_impl() { - simple_impl_desc_t d; - if (!simple_impl_desc_init(prb_, &d)) return false; -@@ -1013,11 +1124,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { - if (n_jit_loops > 0) - loop_begin(l_loop[0], reg_cnt[0], n(nfu + 0) / ldu); - -- bool optimized = false; -- optimized = optimized || process_direct_copy(d.len_unroll); -- optimized = optimized || process_direct_copy(d.len_unroll); -- optimized = optimized || process_unroll_tr8x8(d.len_unroll); -- if (!optimized) process_unroll_generic(d.len_unroll); -+ compute_blk_ker(d.len_unroll); - - if (n_jit_loops > 0) - loop_end(l_loop[0], reg_cnt[0], n(nfu + 0) / ldu, is(nfu + 0) * ldu, -@@ -1236,9 +1343,13 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { - } - add_imm(X_TMP_0, abi_param1, PARAM(in), X_TMP_2); - add_imm(X_TMP_1, abi_param1, PARAM(out), X_TMP_2); -+ add_imm(reg_blk, abi_param1, PARAM(blk_chunks), reg_blk); - ldr(reg_ptr_in, ptr(X_TMP_0)); - ldr(reg_ptr_out, ptr(X_TMP_1)); -+ ldr(reg_blk_chunks, ptr(reg_blk)); -+ - #undef PARAM -+ mov_imm(reg_last_loop_cnt, 1); - - mov(x_ptr_in_off, XReg(reg_ptr_in.getIdx())); - mov(x_ptr_out_off, XReg(reg_ptr_out.getIdx())); -@@ -1282,6 +1393,10 @@ private: - XReg reg_off_out = x9; - XReg reg_off_scale = x10; - -+ XReg reg_blk = x11; -+ XReg reg_blk_chunks = x12; -+ XReg reg_last_loop_cnt = x11; -+ - XReg reg_tmp = x0; - - VReg4S xmm_scale = v15.s; -@@ -1416,10 +1531,16 @@ static void prb_thread_kernel_balance( - for (int d = 0; d < prb.ndims; ++d) - sz_total *= prb.nodes[d].n; - -+ /* The general expression for sz_drv_thr can be written as -+ * sz_drv_min = C0 + FC * (nthr > 1 ? 1 : 0) + VC * (nthr - 1) -+ * where FC and VC are fixed and variable costs respectively. -+ * Though for now, the below heuristic seems to be good enough */ -+ const size_t sz_drv_thr = (nthr > 1) ? 16 * nthr : 1; -+ - /* sz_drv_min is the minimal size for the parallel - * driver required for good parallelization */ - const size_t sz_drv_min -- = nstl::min(16 * nthr, utils::div_up(sz_total, 1024)); -+ = nstl::min(sz_drv_thr, utils::div_up(sz_total, 1024)); - - /* kdims -- # of dimensions processed by a kernel - * sz_ker_cur -- product of the dimension processed by a kernel -@@ -1440,7 +1561,8 @@ static void prb_thread_kernel_balance( - * (less than tr::ker_prb_size_min). In that case try to split the - * innermost driver dimension into two, to increase sz_ker_cur. */ - bool want_borrow_ker_from_drv = true && kdims < prb.ndims -- && sz_ker_cur < tr::ker_prb_size_min && sz_drv_cur > sz_drv_min; -+ && sz_ker_cur < tr::ker_prb_size_min && sz_drv_cur > sz_drv_min -+ && kdims != prb.blk_chunk_idx; - if (want_borrow_ker_from_drv) { - /* sz_want_borrow is the minimal sz, so that: - * o) sz_ker_cur * sz_want_borrow >= tr::ker_prb_size_min -@@ -1464,7 +1586,7 @@ static void prb_thread_kernel_balance( - * try to split the outermost kernel dimension into two, to increase - * sz_drv_cur. */ - bool want_borrow_drv_from_ker = true && sz_ker_cur > tr::ker_prb_size_min -- && sz_drv_cur < sz_drv_min; -+ && sz_drv_cur < sz_drv_min && kdims != prb.blk_chunk_idx; - if (want_borrow_drv_from_ker) { - size_t sz_want_borrow = utils::div_up(sz_drv_min, sz_drv_cur); - for (; prb.nodes[kdims - 1].n % sz_want_borrow; ++sz_want_borrow) -@@ -1518,6 +1640,8 @@ status_t jit_uni_reorder_t::pd_t::create(reorder_pd_t **reorder_pd, - prb_dump(prb); - }); - -+ CHECK(prb_check_blk(prb, *dst_md)); -+ - int ndims_ker_max; - int nthr = dnnl_get_max_threads(); - prb_thread_kernel_balance(prb, ndims_ker_max, nthr); -@@ -1552,7 +1676,7 @@ status_t jit_uni_reorder_t::pd_t::create(reorder_pd_t **reorder_pd, - - void jit_uni_reorder_t::omp_driver_0d( - int off, const char *in, char *out, const float *scale) const { -- tr::call_param_t c {in, out, scale}; -+ tr::call_param_t c {in, out, scale, 0}; - (*kernel_)(&c); - } - -@@ -1564,6 +1688,7 @@ void jit_uni_reorder_t::omp_driver_1d(int ithr, int nthr, int off, - c.in = in + d0 * ns[0].is * data_type_size(pd()->prb_.itype); - c.out = out + d0 * ns[0].os * data_type_size(pd()->prb_.otype); - c.scale = scale + d0 * ns[0].ss; -+ c.blk_chunks = d0; - (*kernel_)(&c); - }); - } -@@ -1571,6 +1696,7 @@ void jit_uni_reorder_t::omp_driver_1d(int ithr, int nthr, int off, - void jit_uni_reorder_t::omp_driver_2d(int ithr, int nthr, int off, - const char *in, char *out, const float *scale) const { - const tr::node_t *ns = pd()->prb_.nodes + off; -+ const int blk_idx_off = pd()->prb_.blk_chunk_idx - off; - for_nd(ithr, nthr, (ptrdiff_t)ns[1].n, (ptrdiff_t)ns[0].n, - [&](ptrdiff_t d1, ptrdiff_t d0) { - auto c = tr::call_param_t(); -@@ -1581,6 +1707,7 @@ void jit_uni_reorder_t::omp_driver_2d(int ithr, int nthr, int off, - + (d0 * ns[0].os + d1 * ns[1].os) - * data_type_size(pd()->prb_.otype); - c.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss; -+ c.blk_chunks = utils::pick(blk_idx_off, d0, d1); - (*kernel_)(&c); - }); - } -@@ -1588,6 +1715,7 @@ void jit_uni_reorder_t::omp_driver_2d(int ithr, int nthr, int off, - void jit_uni_reorder_t::omp_driver_3d(int ithr, int nthr, int off, - const char *in, char *out, const float *scale) const { - const tr::node_t *ns = pd()->prb_.nodes + off; -+ const int blk_idx_off = pd()->prb_.blk_chunk_idx - off; - for_nd(ithr, nthr, (ptrdiff_t)ns[2].n, (ptrdiff_t)ns[1].n, - (ptrdiff_t)ns[0].n, [&](ptrdiff_t d2, ptrdiff_t d1, ptrdiff_t d0) { - auto c = tr::call_param_t(); -@@ -1598,6 +1726,7 @@ void jit_uni_reorder_t::omp_driver_3d(int ithr, int nthr, int off, - + (d0 * ns[0].os + d1 * ns[1].os + d2 * ns[2].os) - * data_type_size(pd()->prb_.otype); - c.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss + d2 * ns[2].ss; -+ c.blk_chunks = utils::pick(blk_idx_off, d0, d1, d2); - (*kernel_)(&c); - }); - } -@@ -1605,6 +1734,7 @@ void jit_uni_reorder_t::omp_driver_3d(int ithr, int nthr, int off, - void jit_uni_reorder_t::omp_driver_4d(int ithr, int nthr, int off, - const char *in, char *out, const float *scale) const { - const tr::node_t *ns = pd()->prb_.nodes + off; -+ const int blk_idx_off = pd()->prb_.blk_chunk_idx - off; - for_nd(ithr, nthr, (ptrdiff_t)ns[3].n, (ptrdiff_t)ns[2].n, - (ptrdiff_t)ns[1].n, (ptrdiff_t)ns[0].n, - [&](ptrdiff_t d3, ptrdiff_t d2, ptrdiff_t d1, ptrdiff_t d0) { -@@ -1619,6 +1749,7 @@ void jit_uni_reorder_t::omp_driver_4d(int ithr, int nthr, int off, - * data_type_size(pd()->prb_.otype); - c.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss + d2 * ns[2].ss - + d3 * ns[3].ss; -+ c.blk_chunks = utils::pick(blk_idx_off, d0, d1, d2, d3); - (*kernel_)(&c); - }); - } -diff --git a/src/cpu/aarch64/jit_uni_reorder.hpp b/src/cpu/aarch64/jit_uni_reorder.hpp -index 88762756c..2fb6f0f89 100644 ---- a/src/cpu/aarch64/jit_uni_reorder.hpp -+++ b/src/cpu/aarch64/jit_uni_reorder.hpp -@@ -1,6 +1,7 @@ - /******************************************************************************* - * Copyright 2018-2020 Intel Corporation - * Copyright 2020 FUJITSU LIMITED -+* Copyright 2022 Arm Ltd. and affiliates - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. -@@ -52,11 +53,19 @@ struct prb_t { - ptrdiff_t ooff; - scale_type_t scale_type; - float beta; -+ int full_ndims; -+ int ip_tail; -+ int op_tail; -+ int iblock; -+ int oblock; -+ int blk_chunk_idx; - }; - - status_t prb_init(prb_t &prb, const memory_desc_t &imd, - const memory_desc_t &omd, const primitive_attr_t *attr); - -+status_t prb_check_blk(prb_t &prb, const memory_desc_t &imd); -+ - /** sorts the problem nodes so that output strides come in ascending order */ - void prb_normalize(prb_t &p); - -@@ -82,6 +91,7 @@ struct call_param_t { - const void *in; - void *out; - const float *scale; -+ size_t blk_chunks; - }; - - struct kernel_t { -diff --git a/src/cpu/aarch64/jit_uni_reorder_utils.cpp b/src/cpu/aarch64/jit_uni_reorder_utils.cpp -index 3d6e424e3..7123811f8 100644 ---- a/src/cpu/aarch64/jit_uni_reorder_utils.cpp -+++ b/src/cpu/aarch64/jit_uni_reorder_utils.cpp -@@ -1,6 +1,7 @@ - /******************************************************************************* - * Copyright 2018-2021 Intel Corporation - * Copyright 2020 FUJITSU LIMITED -+* Copyright 2022 Arm Ltd. and affiliates - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. -@@ -15,7 +16,8 @@ - * limitations under the License. - *******************************************************************************/ - --#include -+#include -+#include - - #include "common/c_types_map.hpp" - #include "common/dnnl_thread.hpp" -@@ -46,8 +48,65 @@ struct layout_desc_t { - strides_t strides; - }; - --status_t cvt_mem_desc_to_layout_desc( -- const memory_desc_t &md_, layout_desc_t &ld, const dims_t &blocks) { -+static status_t compute_blk_and_tail( -+ const memory_desc_t &md_, const int idx, int &blk, int &tail) { -+ const auto md = memory_desc_wrapper(md_); -+ const auto &bd = md.blocking_desc(); -+ if (tail == 0) return status::success; -+ -+ const std::set unique_inner_idxs( -+ bd.inner_idxs, bd.inner_idxs + bd.inner_nblks); -+ std::set dims_with_multiple_blks; -+ for (dim_t dim : unique_inner_idxs) { -+ if (std::count(bd.inner_idxs, bd.inner_idxs + bd.inner_nblks, dim) > 1) -+ dims_with_multiple_blks.insert(dim); -+ } -+ -+ // Dims that have a tail and have multiple blocks are not supported by the jit kernel yet. -+ // For example: -+ // src_tag = abcd -+ // dst_tag = ABcd16b16a4b -+ // 16x15x3x3 -+ // In this case, 'b' dim has two blocks and has a tail. It is not a supported case. -+ if (dims_with_multiple_blks.find(idx) != dims_with_multiple_blks.end()) -+ return status::unimplemented; -+ -+ // Only supports inconsistent padding in single and double blocks -+ // and the total block size <= 256 -+ for (int iblk = bd.inner_nblks - 1; iblk > 0; --iblk) { -+ if (bd.inner_idxs[iblk] == idx) break; -+ blk *= bd.inner_blks[iblk]; -+ tail *= bd.inner_blks[iblk]; -+ } -+ if (unique_inner_idxs.size() > 2 || blk > 256) return status::unimplemented; -+ -+ return status::success; -+} -+ -+static status_t compute_chunk_idx(const prb_t &p, const memory_desc_t &imd_, -+ const memory_desc_t &omd_, const int blk_idx, int &chunk_idx) { -+ const auto imd = memory_desc_wrapper(imd_); -+ const auto omd = memory_desc_wrapper(omd_); -+ const auto &ibd = imd.blocking_desc(); -+ const auto &obd = omd.blocking_desc(); -+ if (p.ip_tail == 0 && p.op_tail == 0) return status::success; -+ -+ const ptrdiff_t is -+ = ibd.strides[blk_idx] * obd.inner_blks[obd.inner_idxs[blk_idx]]; -+ const ptrdiff_t os = obd.strides[blk_idx]; -+ -+ for (int i = blk_idx; i < omd.ndims(); ++i) { -+ if (p.nodes[i].os == os && p.nodes[i].is == is) { -+ chunk_idx = i; -+ return status::success; -+ } -+ } -+ -+ return status::invalid_arguments; -+} -+ -+status_t cvt_mem_desc_to_layout_desc(const memory_desc_t &md_, -+ layout_desc_t &ld, const dims_t &blocks, const dims_t &ext_padding) { - const auto md = memory_desc_wrapper(md_); - - bool ok = true && md.is_blocking_desc() && md.extra().flags == 0; -@@ -75,7 +134,7 @@ status_t cvt_mem_desc_to_layout_desc( - stride *= bd.inner_blks[iblk]; - } - } -- P(d, md.padded_dims()[d] / blocks[d], bd.strides[d]); -+ P(d, (md.padded_dims()[d] + ext_padding[d]) / blocks[d], bd.strides[d]); - - // TODO: NOW: revisit, do we need a reverse? - // TODO: NOW: consider using strides instead of block sizes in md -@@ -98,7 +157,8 @@ status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd, - - auto check_post_ops = [](const primitive_attr_t *attr) { - const auto &po = attr->post_ops_; -- return po.len() == 0 || (po.len() == 1 && po.entry_[0].is_sum(false)); -+ return po.len() == 0 -+ || (po.len() == 1 && po.contain(primitive_kind::sum, 0)); - }; - - bool ok = im_d.is_blocking_desc() && om_d.is_blocking_desc() -@@ -110,26 +170,58 @@ status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd, - && check_post_ops(attr); - if (!ok) return unimplemented; - -- dims_t iblocks, oblocks; -+ dims_t iblocks, oblocks, ip_padding, op_padding; - im_d.compute_blocks(iblocks); - om_d.compute_blocks(oblocks); -+ utils::array_set(ip_padding, 0, im_d.ndims()); -+ utils::array_set(op_padding, 0, om_d.ndims()); -+ -+ /* padding_dim consistency check -+ * only supports inconsitent padding for src -+ * TODO: Add inconsistent padding support for dst */ -+ int ip_tail = 0; -+ int op_tail = 0; -+ int iblk_w_tail = 1; -+ int oblk_w_tail = 1; -+ int blk_idx = 0; - -- /* padding_dim consistency check */ - for (int d = 0; d < im_d.ndims(); ++d) { -- const auto pdim = im_d.padded_dims()[d]; -- bool ok = true && pdim == om_d.padded_dims()[d] -- && pdim % iblocks[d] == 0 && pdim % oblocks[d] == 0; -- if (!ok) return unimplemented; -+ const int ip_tmp_dim = im_d.padded_dims()[d]; -+ const int op_tmp_dim = om_d.padded_dims()[d]; -+ const int ip_tmp_tail = ip_tmp_dim % oblocks[d]; -+ const int op_tmp_tail = op_tmp_dim % iblocks[d]; -+ -+ const bool pdim_consistent = ip_tmp_dim == op_tmp_dim -+ && ip_tmp_tail == 0 && op_tmp_tail == 0; -+ const bool pdim_tail = ip_tmp_tail > 0 -+ && (ip_tmp_dim + oblocks[d] - ip_tmp_tail) == op_tmp_dim -+ && op_tmp_tail == 0 && ip_tail == 0; -+ if (!pdim_consistent && !pdim_tail) return status::unimplemented; -+ if (pdim_tail) { -+ blk_idx = d; -+ ip_tail = ip_tmp_tail; -+ op_tail = op_tmp_tail; -+ iblk_w_tail = iblocks[d]; -+ oblk_w_tail = oblocks[d]; -+ ip_padding[d] = oblocks[d] - ip_tmp_tail; -+ op_padding[d] = iblocks[d] - op_tmp_tail; -+ } - } -+ CHECK(compute_blk_and_tail(omd, blk_idx, oblk_w_tail, ip_tail)); - - layout_desc_t ild, old; -- status_t status = cvt_mem_desc_to_layout_desc(imd, ild, iblocks); -+ status_t status -+ = cvt_mem_desc_to_layout_desc(imd, ild, iblocks, ip_padding); - if (status != success) return status; -- status = cvt_mem_desc_to_layout_desc(omd, old, oblocks); -+ status = cvt_mem_desc_to_layout_desc(omd, old, oblocks, op_padding); - if (status != success) return status; - - p.itype = ild.dt; - p.otype = old.dt; -+ p.ip_tail = ip_tail; -+ p.op_tail = op_tail; -+ p.iblock = iblk_w_tail; -+ p.oblock = oblk_w_tail; - - p.scale_type = attr->output_scales_.has_default_values() - ? scale_type_t::NONE -@@ -156,7 +248,6 @@ status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd, - - while (i_pos < ild.ndims && o_pos < old.ndims) { - assert(ild.id[i_pos] == old.id[o_pos]); -- if (ild.id[i_pos] != old.id[o_pos]) return runtime_error; - - assert(ndims < max_ndims); - if (ndims == max_ndims) return runtime_error; -@@ -191,7 +282,12 @@ status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd, - ild.dims[i_pos] = factor; - } - } -+ int blk_chunk_idx = ndims; -+ CHECK(compute_chunk_idx(p, imd, omd, blk_idx, blk_chunk_idx)); -+ - p.ndims = ndims; -+ p.full_ndims = ndims; -+ p.blk_chunk_idx = blk_chunk_idx; - - p.ioff = memory_desc_wrapper(imd).offset0(); - p.ooff = memory_desc_wrapper(omd).offset0(); -@@ -211,8 +307,28 @@ void prb_normalize(prb_t &p) { - && p.nodes[j].n < p.nodes[min_pos].n); - if (new_min) min_pos = j; - } -- if (min_pos != d) nstl::swap(p.nodes[d], p.nodes[min_pos]); -+ if (min_pos != d) { -+ nstl::swap(p.nodes[d], p.nodes[min_pos]); -+ if (p.blk_chunk_idx == min_pos || p.blk_chunk_idx == d) -+ p.blk_chunk_idx = p.blk_chunk_idx == min_pos ? d : min_pos; -+ } -+ } -+} -+ -+status_t prb_check_blk(prb_t &p, const memory_desc_t &md_) { -+ const auto md = memory_desc_wrapper(md_); -+ const auto &bd = md.blocking_desc(); -+ if (p.ip_tail == 0) return status::success; -+ -+ // Check if the inner blocks and p.nodes[blk].n in the firsti nblks -+ // is equivalent in reverse order when has tail in block layout. -+ const int nblk = bd.inner_nblks; -+ for (int iblk = 0; iblk < nblk; ++iblk) { -+ if (bd.inner_blks[nblk - iblk - 1] -+ != static_cast(p.nodes[iblk].n)) -+ return status::unimplemented; - } -+ return status::success; - } - - void prb_simplify(prb_t &p) { -@@ -225,18 +341,29 @@ void prb_simplify(prb_t &p) { - for (int d = 0; d < p.ndims - 1; ++d) { - auto &this_node = p.nodes[d + 0]; - auto &next_node = p.nodes[d + 1]; -+ const bool skip_blk_idx = (p.ip_tail > 0 || p.op_tail > 0) -+ && (p.blk_chunk_idx == d || p.blk_chunk_idx == d + 1); - const bool fold = false -- || next_node.n == (size_t)1 // trivial case, just drop next node -+ || (next_node.n == static_cast(1) -+ && !skip_blk_idx) // trivial case, just drop next node - || (true // or real folding if possible -- && next_node.is == (ptrdiff_t)this_node.n * this_node.is -- && next_node.os == (ptrdiff_t)this_node.n * this_node.os -+ && !skip_blk_idx -+ && next_node.is -+ == static_cast( -+ this_node.n * this_node.is) -+ && next_node.os -+ == static_cast( -+ this_node.n * this_node.os) - && next_node.ss -- == (ptrdiff_t)this_node.n * this_node.ss); -+ == static_cast( -+ this_node.n * this_node.ss)); - if (fold) { - this_node.n *= next_node.n; - for (int j = d + 2; j < p.ndims; ++j) - p.nodes[j - 1] = p.nodes[j]; -+ if (d < p.blk_chunk_idx) --p.blk_chunk_idx; - --p.ndims; -+ --p.full_ndims; - --d; // make another try - } - } -@@ -251,6 +378,8 @@ void prb_node_split(prb_t &p, int dim, size_t n1) { - assert(p.nodes[dim].n % n1 == 0); - - p.ndims += 1; -+ p.full_ndims += 1; -+ if (dim < p.blk_chunk_idx) p.blk_chunk_idx += 1; - - for (int d = p.ndims; d > dim + 1; --d) - p.nodes[d] = p.nodes[d - 1]; diff --git a/third_party/mkl_dnn/onednn_acl_reorder_update.patch b/third_party/mkl_dnn/onednn_acl_reorder_update.patch deleted file mode 100644 index 3ac5a62906ff4c..00000000000000 --- a/third_party/mkl_dnn/onednn_acl_reorder_update.patch +++ /dev/null @@ -1,4193 +0,0 @@ -From b84c533dad4db495a92fc6d390a7db5ebd938a88 Mon Sep 17 00:00:00 2001 -From: Kentaro Kawakami -Date: Tue, 1 Nov 2022 09:33:41 +0900 -Subject: [PATCH] cpu: aarch64: reorder: support jit-ed blk_reorder - ---- - src/cpu/aarch64/jit_generator.hpp | 20 + - src/cpu/aarch64/jit_uni_reorder.cpp | 2315 +++++++++++++---- - src/cpu/aarch64/jit_uni_reorder.hpp | 183 +- - src/cpu/aarch64/jit_uni_reorder_utils.cpp | 482 ++-- - .../reorder/cpu_reorder_regular_f32_f32.cpp | 6 + - .../reorder/cpu_reorder_regular_f32_s32.cpp | 2 + - .../reorder/cpu_reorder_regular_f32_s8.cpp | 2 + - .../reorder/cpu_reorder_regular_f32_u8.cpp | 2 + - src/cpu/reorder/cpu_reorder_regular_s32.cpp | 2 + - src/cpu/reorder/cpu_reorder_regular_s8.cpp | 2 + - src/cpu/reorder/cpu_reorder_regular_u8.cpp | 2 + - 11 files changed, 2272 insertions(+), 746 deletions(-) - -diff --git a/src/cpu/aarch64/jit_generator.hpp b/src/cpu/aarch64/jit_generator.hpp -index dd781a622e1..12de9fa8c01 100644 ---- a/src/cpu/aarch64/jit_generator.hpp -+++ b/src/cpu/aarch64/jit_generator.hpp -@@ -435,6 +435,26 @@ class jit_generator : public Xbyak_aarch64::CodeGenerator, public c_compatible { - Xbyak_aarch64::ZRegD(z3.getIdx())); - } - -+ void uni_ld1rw(const Xbyak_aarch64::VReg4S &dst, -+ const Xbyak_aarch64::XReg &base, const int64_t off) { -+ if (off == 0) { -+ ld1r(dst, ptr(base)); -+ } else { -+ add_imm(X_DEFAULT_ADDR, base, off, X_TMP_0); -+ ld1r(dst, ptr(X_DEFAULT_ADDR)); -+ } -+ } -+ -+ void uni_ld1rw(const Xbyak_aarch64::ZRegS &dst, -+ const Xbyak_aarch64::XReg &base, const int64_t off) { -+ if (-32 <= off && off < 32) { -+ ld1rw(dst, P_ALL_ONE / Xbyak_aarch64::T_z, ptr(base, (int)off)); -+ } else { -+ add_imm(X_DEFAULT_ADDR, base, off, X_TMP_0); -+ ld1rw(dst, P_ALL_ONE / Xbyak_aarch64::T_z, ptr(X_DEFAULT_ADDR)); -+ } -+ } -+ - void uni_ldr( - const Xbyak_aarch64::VReg &dst, const Xbyak_aarch64::XReg &addr) { - ldr(Xbyak_aarch64::QReg(dst.getIdx()), ptr(addr)); -diff --git a/src/cpu/aarch64/jit_uni_reorder.cpp b/src/cpu/aarch64/jit_uni_reorder.cpp -index a6cefaa20e8..a708da808c0 100644 ---- a/src/cpu/aarch64/jit_uni_reorder.cpp -+++ b/src/cpu/aarch64/jit_uni_reorder.cpp -@@ -1,6 +1,6 @@ - /******************************************************************************* --* Copyright 2018-2021 Intel Corporation --* Copyright 2020-2021 FUJITSU LIMITED -+* Copyright 2018-2022 Intel Corporation -+* Copyright 2020-2022 FUJITSU LIMITED - * Copyright 2022 Arm Ltd. and affiliates - * - * Licensed under the Apache License, Version 2.0 (the "License"); -@@ -19,19 +19,21 @@ - #include - #include - --#include "dnnl_debug.h" -+#include "oneapi/dnnl/dnnl_debug.h" - - #include "common/c_types_map.hpp" -+#include "common/dnnl_thread.hpp" - #include "common/memory_desc_wrapper.hpp" - #include "common/nstl.hpp" - #include "common/primitive.hpp" - #include "common/type_helpers.hpp" - #include "common/utils.hpp" - --#include "cpu/aarch64/jit_uni_reorder.hpp" - #include "cpu/cpu_primitive.hpp" - #include "cpu/reorder/cpu_reorder_pd.hpp" - -+#include "cpu/aarch64/jit_uni_reorder.hpp" -+ - #include "cpu/aarch64/jit_generator.hpp" - - // #define TR_DEBUG -@@ -67,23 +69,6 @@ static bool prb_has_small_strides(const prb_t &prb) { - return true; - } - --static bool prb_tail_friendly(const prb_t &prb) { -- /* find optimal ndims to makes it easier to -- * identify the blk_chunk in the loop*/ -- int ndims = prb.full_ndims - prb.ndims; -- -- int n = prb.nodes[0].is; -- for (int d = 1; d < prb.ndims; ++d) { -- if (d != prb.blk_chunk_idx) n *= prb.nodes[d].n; -- } -- if (prb.ip_tail > 0 -- && ((ndims == 0 && n != 1) -- || (ndims > 0 && prb.ndims > prb.blk_chunk_idx))) -- return false; -- -- return true; --} -- - /** Minimal reasonable/desirable kernel size. - * The constant might be used to determine how a problem should be split - * between kernel and threading driver. */ -@@ -96,6 +81,9 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { - void operator()(const call_param_t *c) const override { - jit_generator::operator()(c); - } -+ void operator()(const tail_call_param_t *c) const override { -+ jit_generator::operator()(c); -+ } - - status_t create_kernel() override { return jit_generator::create_kernel(); } - -@@ -105,30 +93,53 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { - }; - - struct simple_impl_desc_t { -- int ndims_full_unroll; -- int len_last_dim_unroll; -- int len_unroll; -+ int ndims_full_unroll = 0; -+ int len_last_dim_unroll = 0; -+ int tail_len_unroll = 0; -+ int len_unroll = 0; - }; - -+#define PARAM(x) \ -+ abi_param1, \ -+ prb_.is_tail_present ? offsetof(tail_call_param_t, base_params) \ -+ + offsetof(call_param_t, x) \ -+ : offsetof(call_param_t, x) -+#define TAIL_PARAM(x) abi_param1, offsetof(tail_call_param_t, x) -+ - static bool simple_impl_desc_init( - const prb_t &prb, simple_impl_desc_t *desc) { - const int ndims = prb.ndims; - - int ndims_full_unroll = 0; - int len_last_dim_unroll = 1; -+ int tail_len_unroll = 0; - int len_unroll = 1; - -- for (int d = 0; d < ndims; ++d) { -- auto &node = prb.nodes[d]; -- if (len_unroll * node.n <= len_unroll_max) { -- ndims_full_unroll++; -- len_unroll *= node.n; -- } else { -- len_last_dim_unroll = len_unroll_max / len_unroll; -- while (node.n % len_last_dim_unroll) -- --len_last_dim_unroll; -- len_unroll *= len_last_dim_unroll; -- break; -+ // It is responsible for finding as many values -+ // as kernel can unroll. If tail is present then -+ // kernel will unroll only last node (possible improvement). -+ // If there is no tail kernel can unroll a few nodes without any loops etc. -+ // ndims_full_unroll - how many nodes will be unrolled -+ // len_last_dim_unroll - what piece of last unrolled node will be unrolled -+ if (prb.is_tail_present) { -+ ndims_full_unroll = 1; -+ len_unroll = prb.nodes[0].n; -+ tail_len_unroll = prb.nodes[0].is_zero_pad_needed -+ ? 0 -+ : static_cast(prb.nodes[0].tail_size); -+ } else { -+ for (int d = 0; d < ndims; ++d) { -+ const auto &node = prb.nodes[d]; -+ if (len_unroll * node.n <= len_unroll_max) { -+ ndims_full_unroll++; -+ len_unroll *= node.n; -+ } else { -+ len_last_dim_unroll = len_unroll_max / len_unroll; -+ while (node.n % len_last_dim_unroll) -+ --len_last_dim_unroll; -+ len_unroll *= len_last_dim_unroll; -+ break; -+ } - } - } - -@@ -137,6 +148,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { - if (desc) { - desc->ndims_full_unroll = ndims_full_unroll; - desc->len_last_dim_unroll = len_last_dim_unroll; -+ desc->tail_len_unroll = tail_len_unroll; - desc->len_unroll = len_unroll; - } - -@@ -151,62 +163,69 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { - && utils::one_of(p.otype, f32, s32, data_type::s8, u8) - && utils::everyone_is(0, p.ioff, p.ooff) /* do we need this? */ - && utils::one_of(p.beta, 0.f, 1.f) /* anything else? */ -- && simple_impl_desc_init(p, nullptr) && prb_has_small_strides(p) -- && prb_tail_friendly(p); -- if (!ok) return false; -+ && simple_impl_desc_init(p, nullptr) -+ && prb_has_small_strides(p); - -- return true; -+ return ok; - } - -- int n(int d) { -- assert(d < prb_.ndims); -- return (int)prb_.nodes[d].n; -- } -- int is(int d) { -- assert(d < prb_.ndims); -- return (int)prb_.nodes[d].is; -- } -- int os(int d) { -- assert(d < prb_.ndims); -- return (int)prb_.nodes[d].os; -+ XReg o_addr(int o_off, bool with_type_multiplier = true) { -+ if (o_off) { -+ add_imm(X_DEFAULT_ADDR, x_ptr_out_off, -+ o_off * (with_type_multiplier ? otype_sz_ : 1), X_TMP_0); -+ return X_DEFAULT_ADDR; -+ } -+ -+ return x_ptr_out_off; - } -- int ss(int d) { -- assert(d < prb_.ndims); -- return (int)prb_.nodes[d].ss; -+ -+ XReg c_addr(int c_off) { -+ if (c_off) { -+ add_imm(X_DEFAULT_ADDR, x_ptr_comp_off, c_off, X_TMP_0); -+ return X_DEFAULT_ADDR; -+ } -+ -+ return x_ptr_comp_off; - } - -- int blk_cnt() { -- assert(prb_.blk_chunk_idx < prb_.full_ndims); -- return (int)prb_.nodes[prb_.blk_chunk_idx].n - 1; -+ XReg data_chunk_addr(int node_id) { -+ add_imm(X_DEFAULT_ADDR, abi_param1, -+ offsetof(tail_call_param_t, curr_data_chunks) -+ + sizeof(int64_t) * (node_id), -+ X_TMP_0); -+ return X_DEFAULT_ADDR; - } -- int op_padding() { return prb_.op_tail ? prb_.iblock - prb_.op_tail : 0; } -- int ip_padding() { return prb_.ip_tail ? prb_.oblock - prb_.ip_tail : 0; } - - void step(int off, int prev_i_off, int prev_o_off, int prev_s_off, -- int &i_off, int &o_off, int &s_off, int step_size = 1) { -+ int prev_c_off, int &i_off, int &o_off, int &s_off, int &c_off, -+ int step_size = 1) { - i_off = prev_i_off; - o_off = prev_o_off; - s_off = prev_s_off; -+ c_off = prev_c_off; - - if (off == 0) return; - - int start_dim = 0, dims_prod = 1; - for (; start_dim < prb_.ndims && dims_prod != step_size; ++start_dim) -- dims_prod *= n(start_dim); -+ dims_prod *= prb_.n(start_dim); - assert(start_dim < prb_.ndims); - off /= step_size; - -- for (int d = start_dim; d < prb_.ndims; ++d) { -- i_off += is(d); -- o_off += os(d); -- s_off += ss(d); -+ for (int dim_id = start_dim; dim_id < prb_.ndims; ++dim_id) { -+ i_off += prb_.is(dim_id); -+ o_off += prb_.os(dim_id); -+ s_off += prb_.ss(dim_id); -+ c_off += prb_.cs(dim_id); -+ -+ if (off % prb_.n(dim_id)) break; - -- if (off % n(d)) break; -+ i_off += -prb_.n(dim_id) * prb_.is(dim_id); -+ o_off += -prb_.n(dim_id) * prb_.os(dim_id); -+ s_off += -prb_.n(dim_id) * prb_.ss(dim_id); -+ c_off += -prb_.n(dim_id) * prb_.cs(dim_id); - -- i_off += -n(d) * is(d); -- o_off += -n(d) * os(d); -- s_off += -n(d) * ss(d); -- off /= n(d); -+ off /= prb_.n(dim_id); - - if (off == 0) break; /* FIXME: is it really required? */ - } -@@ -215,8 +234,8 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { - void step(int off, int prev_i_off, int prev_o_off, int &i_off, int &o_off, - int step_size = 1) { - int dummy = 0; -- step(off, prev_i_off, prev_o_off, dummy, i_off, o_off, dummy, -- step_size); -+ step(off, prev_i_off, prev_o_off, dummy, dummy, i_off, o_off, dummy, -+ dummy, step_size); - } - - void tr8x8_sve256(int i_off, int o_off) { -@@ -278,40 +297,36 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { - && interim_f32); - const uint64_t sveLen = get_sve_length(); - -- add_imm(X_TMP_0, XReg(x_ptr_in_off), i_off * itype_sz, X_DEFAULT_ADDR); -- add_imm(X_TMP_1, X_TMP_0, is(0) * itype_sz, X_DEFAULT_ADDR); -- add_imm(X_TMP_2, X_TMP_1, is(0) * itype_sz, X_DEFAULT_ADDR); -- add_imm(X_TMP_3, X_TMP_2, is(0) * itype_sz, X_DEFAULT_ADDR); -- -- if (unroll * itype_sz == 32) -- for (uint32_t i = 0; i < 4; i++) -- ld1w(ZRegS {i}, p_lsb_256 / T_z, ptr(x_tmp_vec[i])); -- else if (unroll * itype_sz == 16) -- for (uint32_t i = 0; i < 4; i++) -- ldr(QReg {i}, ptr(x_tmp_vec[i])); -- else if (unroll * itype_sz == 8) -- for (uint32_t i = 0; i < 4; i++) -- ldr(DReg {i}, ptr(x_tmp_vec[i])); -- -- add_imm(X_TMP_0, X_TMP_3, is(0) * itype_sz, X_DEFAULT_ADDR); -- add_imm(X_TMP_1, X_TMP_0, is(0) * itype_sz, X_DEFAULT_ADDR); -- add_imm(X_TMP_2, X_TMP_1, is(0) * itype_sz, X_DEFAULT_ADDR); -- add_imm(X_TMP_3, X_TMP_2, is(0) * itype_sz, X_DEFAULT_ADDR); -- -- if (unroll * itype_sz == 32) -- for (uint32_t i = 0; i < 4; i++) -- ld1w(ZRegS {4 + i}, p_lsb_256 / T_z, ptr(x_tmp_vec[i])); -- else if (unroll * itype_sz == 16) -- for (uint32_t i = 0; i < 4; i++) -- ldr(QReg {4 + i}, ptr(x_tmp_vec[i])); -- else if (unroll * itype_sz == 8) -- for (uint32_t i = 0; i < 4; i++) -- ldr(DReg {4 + i}, ptr(x_tmp_vec[i])); -+ PReg p_size(DUMMY_IDX); -+ switch (unroll * itype_sz_) { -+ case 32: p_size = p_lsb_256; break; -+ case 16: p_size = p_lsb_128; break; -+ case 8: p_size = p_lsb_64; break; -+ default: assert(!"unreachable"); -+ } -+ -+ const int node_0_input_stride = prb_.is(0); -+ add_imm(X_TMP_0, XReg(x_ptr_in_off), itype_sz_ * i_off, X_DEFAULT_ADDR); -+ for (int i = 1; i < unroll / 2; i++) { -+ add_imm(x_tmp_vec[i], x_tmp_vec[i - 1], -+ itype_sz_ * node_0_input_stride, X_DEFAULT_ADDR); -+ } -+ -+ for (uint32_t i = 0; i < unroll / 2; i++) -+ ld1w(ZRegS {i}, p_size / T_z, ptr(x_tmp_vec[i])); -+ -+ for (int i = 0; i < unroll / 2; i++) { -+ add_imm(x_tmp_vec[i], x_tmp_vec[(i + 3) % 4], -+ itype_sz_ * node_0_input_stride, X_DEFAULT_ADDR); -+ } -+ -+ for (uint32_t i = 0; i < unroll / 2; i++) -+ ld1w(ZRegS {4 + i}, p_size / T_z, ptr(x_tmp_vec[i])); - - if (interim_f32) cvt2ps(0, unroll, prb_.itype); - - #if 0 -- /* Deubg code */ -+ /* Debug code */ - index(z0.s, 0, 1); - mov(z0.s, P_NOT_256/T_m, 0); - mov(z_tmp_vec[0].s, 16); -@@ -348,9 +363,9 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { - for (uint32_t i = 0; i < unroll / 2; i++) { - ZRegB z {unroll / 2 + i}; - ZRegB z_tmp = z_tmp_vec[unroll / 2 + i].b; -- /* Move bit 128-255 to 0-127. */ -- ext(z, z, 16); - /* Move bit 0-127 to 128-255. */ -+ ext(z, z, 16); -+ /* Move bit 128-255 to 0-127. */ - ext(z_tmp, z_tmp, sveLen - 16); - } - -@@ -363,65 +378,64 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { - } - - if (need_saturation) { -- init_saturate_f32(ymm_zero, ymm_saturation_ubound, reg_tmp, -+ init_saturate_f32(ymm_zero_, ymm_saturation_ubound_, reg_tmp_, - interim_f32 ? f32 : prb_.itype, prb_.otype); - for (int i = 0; i < unroll; i++) -- saturate_f32(ZRegS(i), ymm_zero, ymm_saturation_ubound, -- prb_.otype, p_all); -+ saturate_f32(ZRegS(i), ymm_zero_, ymm_saturation_ubound_, -+ prb_.otype, P_ALL_ONE); - } - - if (prb_.otype != f32) - cvt2odt(0, unroll, prb_.otype, interim_f32 ? f32 : prb_.itype); - -- add_imm(X_TMP_0, XReg(x_ptr_out_off), o_off * otype_sz, X_DEFAULT_ADDR); -- add_imm(X_TMP_1, X_TMP_0, os(1) * otype_sz, X_DEFAULT_ADDR); -- add_imm(X_TMP_2, X_TMP_1, os(1) * otype_sz, X_DEFAULT_ADDR); -- add_imm(X_TMP_3, X_TMP_2, os(1) * otype_sz, X_DEFAULT_ADDR); -- -- if (unroll * otype_sz == 32) -- for (uint32_t i = 0; i < 4; i++) -- st1w(ZRegS {i}, p_lsb_256 / T_z, ptr(x_tmp_vec[i])); -- else if (unroll * otype_sz == 16) -- for (uint32_t i = 0; i < 4; i++) -- str(QReg {i}, ptr(x_tmp_vec[i])); -- else if (unroll * otype_sz == 8) -- for (uint32_t i = 0; i < 4; i++) -- str(DReg {i}, ptr(x_tmp_vec[i])); -- -- add_imm(X_TMP_0, X_TMP_3, os(1) * otype_sz, X_DEFAULT_ADDR); -- add_imm(X_TMP_1, X_TMP_0, os(1) * otype_sz, X_DEFAULT_ADDR); -- add_imm(X_TMP_2, X_TMP_1, os(1) * otype_sz, X_DEFAULT_ADDR); -- add_imm(X_TMP_3, X_TMP_2, os(1) * otype_sz, X_DEFAULT_ADDR); -- -- if (unroll * otype_sz == 32) -- for (uint32_t i = 0; i < 4; i++) -- st1w(ZRegS {4 + i}, p_lsb_256 / T_z, ptr(x_tmp_vec[i])); -- else if (unroll * otype_sz == 16) -- for (uint32_t i = 0; i < 4; i++) -- str(QReg {4 + i}, ptr(x_tmp_vec[i])); -- else if (unroll * otype_sz == 8) -- for (uint32_t i = 0; i < 4; i++) -- str(DReg {4 + i}, ptr(x_tmp_vec[i])); -+ const int node_1_output_stride = prb_.os(1); -+ -+ switch (unroll * otype_sz_) { -+ case 32: p_size = p_lsb_256; break; -+ case 16: p_size = p_lsb_128; break; -+ case 8: p_size = p_lsb_64; break; -+ default: assert(!"unreachable"); -+ } -+ -+ add_imm(X_TMP_0, XReg(x_ptr_out_off), otype_sz_ * o_off, -+ X_DEFAULT_ADDR); -+ for (int i = 1; i < unroll / 2; i++) { -+ add_imm(x_tmp_vec[i], x_tmp_vec[i - 1], -+ otype_sz_ * node_1_output_stride, X_DEFAULT_ADDR); -+ } -+ -+ for (uint32_t i = 0; i < 4; i++) -+ st1w(ZRegS {i}, p_size / T_z, ptr(x_tmp_vec[i])); -+ -+ for (int i = 0; i < unroll / 2; i++) { -+ add_imm(x_tmp_vec[i], x_tmp_vec[(i + 3) % 4], -+ otype_sz_ * node_1_output_stride, X_DEFAULT_ADDR); -+ } -+ -+ for (uint32_t i = 0; i < unroll / 2; i++) -+ st1w(ZRegS {4 + i}, p_size / T_z, ptr(x_tmp_vec[i])); - } - - bool can_do_tr8x8() { - using namespace data_type; - -- return get_sve_length() >= Xbyak_aarch64::util::SVE_256 -- && prb_.ndims >= 2 -+ static constexpr int desirable_node_size = 8; -+ static constexpr int desirable_stride = 1; -+ -+ return mayiuse(sve_256) && prb_.ndims >= 2 - && ((utils::one_of(prb_.itype, u8, data_type::s8, s32, f32) - && utils::one_of( - prb_.otype, u8, data_type::s8, s32, f32))) -- && utils::everyone_is(8, n(0), n(1)) -- && utils::everyone_is(1, os(0), is(1)) -- && utils::everyone_is(0, prb_.ip_tail, prb_.op_tail) -+ && utils::everyone_is(desirable_node_size, prb_.n(0), prb_.n(1)) -+ && utils::everyone_is(desirable_stride, prb_.os(0), prb_.is(1)) -+ && !prb_.is_tail_present - && prb_.scale_type == scale_type_t::NONE && prb_.beta == 0.f; - } - -- bool process_unroll_tr8x8(int len) { -+ bool process_unroll_tr8x8(const int ndims, const int len) { - if (!can_do_tr8x8()) return false; - -- const int step_size = n(0) * n(1); -+ const int step_size = prb_.n(0) * prb_.n(1); - int i_off = 0, o_off = 0; - for (int off = 0; off < len; off += step_size) { - step(off, i_off, o_off, i_off, o_off, step_size); -@@ -432,23 +446,56 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { - } - - template -- bool process_direct_copy(int len) { -+ bool process_direct_copy(const int ndims, const int len) { - using namespace data_type; - -- const int simd_w = cpu_isa_traits::vlen / itype_sz; -- bool can_do = true && mayiuse(isa) -- && utils::everyone_is(1, os(0), is(0)) -- && (false || prb_.itype == prb_.otype -+ static constexpr int desirable_stride = 1; -+ using TRegS = -+ typename utils::conditional::type; -+ const int simd_w = cpu_isa_traits::vlen / itype_sz_; -+ -+ // TODO: support tail_processing for direct copy -+ -+ const bool do_src_zp = prb_.req_src_zp; -+ const bool do_dst_zp = prb_.req_dst_zp; -+ const bool zp_applicable = IMPLICATION( -+ (do_src_zp || do_dst_zp), utils::one_of(prb_.itype, s32, f32)); -+ const bool can_do = true && mayiuse(isa) -+ && compensation_needed_ == false -+ && utils::everyone_is(desirable_stride, prb_.os(0), prb_.is(0)) -+ && (false || (prb_.itype == prb_.otype ? zp_applicable : false) - || (prb_.itype == s32 && prb_.otype == f32) - || (prb_.itype == f32 && prb_.otype == s32)) -- && len % simd_w == 0 && n(0) % len == 0 -- && prb_.ip_tail % simd_w == 0 && prb_.op_tail % simd_w == 0 -+ && len % simd_w == 0 && prb_.n(0) % len == 0 -+ && !prb_.is_tail_present - && prb_.scale_type == scale_type_t::NONE && prb_.beta == 0.f; - if (!can_do) return false; - -+ static constexpr int vmm_zp_last_idx = 15; -+ const auto vmm_src_zp -+ = TRegS(do_dst_zp ? vmm_zp_last_idx - 1 : vmm_zp_last_idx); -+ if (do_src_zp) { -+ uni_ld1rw(vmm_src_zp, PARAM(src_zp)); -+ uni_scvtf(vmm_src_zp, vmm_src_zp); -+ } -+ const auto vmm_dst_zp = TRegS(vmm_zp_last_idx); -+ if (do_dst_zp) { -+ uni_ld1rw(vmm_dst_zp, PARAM(dst_zp)); -+ uni_scvtf(vmm_dst_zp, vmm_dst_zp); -+ } -+ -+ const auto apply_zp_ps = [&](const TRegS vmm) { -+ if (do_src_zp) fsub(vmm, vmm, vmm_src_zp); -+ if (do_dst_zp) fadd(vmm, vmm, vmm_dst_zp); -+ }; -+ - for (int off = 0; off < len;) { -- const int unroll -+ // TODO: we need extra reg for proper saturation if otype == s32 -+ int unroll - = nstl::min(16 - (prb_.otype == s32), (len - off) / simd_w); -+ unroll = (do_src_zp || do_dst_zp) -+ ? nstl::min(unroll, 16 - do_src_zp - do_dst_zp) -+ : unroll; - - int ur = 0; - int tmp_ur = 0; -@@ -458,14 +505,11 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { - - do { - add_imm(x_tmp_vec[count++], x_ptr_in_off, -- (off + ur * simd_w) * itype_sz, X_DEFAULT_ADDR); -+ (off + ur * simd_w) * itype_sz_, X_DEFAULT_ADDR); - ur++; - } while (ur < unroll && count < x_tmp_vec_size); - - for (int i = 0; i < count; i++) { -- /* if (vlen == 64) -- ldr(ZReg(tmp_ur + i), ptr(x_tmp_vec[i])); -- else */ - if (vlen == 64 || vlen == 32) - ld1w(ZRegS(tmp_ur + i), p_lsb_256 / T_z, - ptr(x_tmp_vec[i])); -@@ -478,33 +522,28 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { - } - - if (prb_.itype != prb_.otype) { -- const int vlen = cpu_isa_traits::vlen; - for (int ur = 0; ur < unroll; ++ur) { -+ TRegS r(ur); - if (prb_.itype == s32 && prb_.otype == f32) { -- if (vlen == 64 || vlen == 32) { -- ZRegS r(ur); -- /* MSB side 256 bits are ignored. */ -- scvtf(r, p_all / T_m, r); -- } else if (vlen == 16) { -- VReg4S r(ur); -- scvtf(r, r); -- } else -- assert(!"unreachable"); -+ uni_scvtf(r, r); - } else if (prb_.itype == f32 && prb_.otype == s32) { -- /* Out of order can be expected. */ -- if (vlen == 64 || vlen == 32) { -- ZRegS r(ur); -- frinti(r, p_all / T_m, r); -- fcvtzs(r, p_all / T_m, r); -- } else if (vlen == 16) { -- VReg4S r(ur); -- frinti(r, r); -- fcvtzs(r, r); -- } else -- assert(!"unreachable"); -+ uni_frinti(r, r); -+ uni_fcvtzs(r, r); - } else - assert(!"unreachable"); - } -+ } else if (do_src_zp || do_dst_zp) { -+ for (int ur = 0; ur < unroll; ++ur) { -+ const auto vmm = TRegS(ur); -+ if (prb_.otype == f32) { -+ apply_zp_ps(vmm); -+ } else if (prb_.otype == s32) { -+ uni_scvtf(vmm, vmm); -+ apply_zp_ps(vmm); -+ uni_frinti(vmm, vmm); -+ uni_fcvtzs(vmm, vmm); -+ } -+ } - } - - ur = 0; -@@ -515,7 +554,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { - - do { - add_imm(x_tmp_vec[count++], x_ptr_out_off, -- (off + ur * simd_w) * otype_sz, X_DEFAULT_ADDR); -+ (off + ur * simd_w) * otype_sz_, X_DEFAULT_ADDR); - ur++; - } while (ur < unroll && count < x_tmp_vec_size); - -@@ -538,8 +577,8 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { - } - - void process_unroll_generic_step(int reg_unroll, const int *i_off, -- const int *o_off, const int *s_off, const int *ip_padding, -- const bool h_padded) { -+ const int *o_off, const int *s_off, const int *c_off, -+ const int *zero_padding, const bool tail_processing) { - using namespace data_type; - - auto cvt2ps -@@ -588,76 +627,84 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { - } - }; - -+ auto load_bytes_addr = [=](const int ur, const int r) { -+ add_imm(x_tmp_vec[r], x_ptr_in_off, i_off[ur + r] * itype_sz_, -+ X_DEFAULT_ADDR); -+ }; -+ auto load_bytes = [=](const int ur, int size, int r) { -+ switch (size) { -+ case 4: ld1(VReg4S(ur)[r], ptr(x_tmp_vec[r])); break; -+ case 2: ld1(VReg8H(ur)[r], ptr(x_tmp_vec[r])); break; -+ case 1: ld1(VReg16B(ur)[r], ptr(x_tmp_vec[r])); break; -+ default: assert(!"unreachable"); -+ } -+ }; -+ -+ auto store = [=](const XReg &addr, const VReg ymm, int size) { -+ const uint32_t xmm = ymm.getIdx(); -+ switch (size) { -+ case 16: str(QReg(xmm), ptr(addr)); break; -+ case 8: str(DReg(xmm), ptr(addr)); break; -+ case 4: str(SReg(xmm), ptr(addr)); break; -+ case 2: str(HReg(xmm), ptr(addr)); break; -+ case 1: str(BReg(xmm), ptr(addr)); break; -+ default: assert(!"unreachable"); -+ } -+ }; -+ - /* check whether loading 4 values at once is possible */ -- bool can_load_xmm = reg_unroll % 4 == 0; -+ static constexpr int xmm_vlen = 4; -+ bool can_load_xmm = reg_unroll % xmm_vlen == 0; - for (int ur = 1; ur < reg_unroll; ++ur) -- if (i_off[ur] != i_off[ur - 1] + 1) can_load_xmm = false; -- const int load_step = can_load_xmm ? 4 : 1; -+ if (i_off[ur] != i_off[ur - 1] + 1) { -+ can_load_xmm = false; -+ break; -+ } -+ const int load_step = can_load_xmm ? xmm_vlen : 1; - - /* check whether storing 4 values at once is possible */ -- bool can_store_xmm = reg_unroll % 4 == 0; -+ bool can_store_xmm = reg_unroll % xmm_vlen == 0; - for (int ur = 1; ur < reg_unroll; ++ur) -- if (o_off[ur] != o_off[ur - 1] + 1) can_store_xmm = false; -+ if (o_off[ur] != o_off[ur - 1] + 1) { -+ can_store_xmm = false; -+ break; -+ } - const int ur_step = can_store_xmm ? 4 : 1; - const int load_tail_step - = !can_load_xmm && can_store_xmm ? ur_step : load_step; - -- const bool interim_f32 = false -- || utils::one_of(f32, prb_.itype, prb_.otype) -- || prb_.scale_type != scale_type_t::NONE || prb_.beta != 0.f; -+ const bool interim_f32 = interim_f32_needed(); - - const bool need_saturation - = (utils::one_of(prb_.otype, u8, data_type::s8, s32) - && interim_f32); -- if (h_padded) { -+ -+ std::vector store_masks; -+ if (tail_processing) { - for (int ur = 0; ur < reg_unroll; ur += load_tail_step) { -- if (itype_sz == 4) -- movi(VReg4S(ur), 0); -- else if (itype_sz == 2) -- movi(VReg8H(ur), 0); -- else -- movi(VReg16B(ur), 0); -- /* x_tmp_vec = X_TMP_0 - X_TMP_4 -- Do not use X_TMP_? as the last arg. */ -- for (int r = 0; r < load_tail_step; ++r) { -- if (ip_padding[ur + r] == 0) { -- add_imm(x_tmp_vec[r], x_ptr_in_off, -- i_off[ur + r] * itype_sz, X_DEFAULT_ADDR); -- } -- } -+ uni_clear(VReg(ur)); -+ store_masks.push_back(0); - - for (int r = 0; r < load_tail_step; ++r) { -- if (ip_padding[ur + r] == 0) { -- if (itype_sz == 4) -- ld1(VReg4S(ur)[r], ptr(x_tmp_vec[r])); -- else if (itype_sz == 2) -- ld1(VReg8H(ur)[r], ptr(x_tmp_vec[r])); -- else -- ld1(VReg16B(ur)[r], ptr(x_tmp_vec[r])); -+ if (zero_padding[ur + r] == 0) { -+ store_masks.back() += 1 << r; -+ load_bytes_addr(ur, r); - } - } -+ -+ for (int r = 0; r < load_tail_step; ++r) -+ if (zero_padding[ur + r] == 0) load_bytes(ur, itype_sz_, r); - } - } else { - if (!can_load_xmm && can_store_xmm) { -- assert(ur_step == 4); -+ assert(ur_step == xmm_vlen); - /* load with stride */ - for (int ur = 0; ur < reg_unroll; ur += ur_step) { -- -- /* x_tmp_vec = X_TMP_0 - X_TMP_4 -- Do not use X_TMP_? as the last arg. */ - for (int r = 0; r < ur_step; ++r) { -- add_imm(x_tmp_vec[r], x_ptr_in_off, -- i_off[ur + r] * itype_sz, X_DEFAULT_ADDR); -- } -- -- for (int r = 0; r < ur_step; ++r) { -- if (itype_sz == 4) -- ld1(VReg4S(ur)[r], ptr(x_tmp_vec[r])); -- else if (itype_sz == 2) -- ld1(VReg8H(ur)[r], ptr(x_tmp_vec[r])); -- else -- ld1(VReg16B(ur)[r], ptr(x_tmp_vec[r])); -+ load_bytes_addr(ur, r); - } -+ for (int r = 0; r < ur_step; ++r) -+ load_bytes(ur, itype_sz_, r); - } - } else { - int ur = 0; -@@ -667,13 +714,13 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { - - do { - add_imm(x_tmp_vec[count++], x_ptr_in_off, -- i_off[ur] * itype_sz, X_DEFAULT_ADDR); -+ i_off[ur] * itype_sz_, X_DEFAULT_ADDR); - ur += load_step; - } while (ur < reg_unroll && count < x_tmp_vec_size); - - for (int i = 0; i < count; i++) { - -- switch (load_step * itype_sz) { -+ switch (load_step * itype_sz_) { - case 16: - ldr(QReg(tmp_ur), ptr(x_tmp_vec[i])); - break; -@@ -688,6 +735,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { - } - } - } -+ - /* xmm[:] <-- (f32)xmm[:] */ - if (interim_f32) { - const int cvt_step = nstl::max(load_step, ur_step); -@@ -702,30 +750,32 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { - if (fast_return) { - if (prb_.scale_type == scale_type_t::COMMON) - for (int ur = 0; ur < reg_unroll; ur += load_step) -- fmul(VReg4S(ur), VReg4S(ur), xmm_scale); -+ fmul(VReg4S(ur), VReg4S(ur), xmm_scale_); - if (prb_.otype != f32) { -- init_saturate_f32(xmm_zero, xmm_saturation_ubound, reg_tmp, -- interim_f32 ? f32 : prb_.itype, prb_.otype); -- for (int ur = 0; ur < reg_unroll; ur += load_step) -+ init_saturate_f32(xmm_zero_, xmm_saturation_ubound_, -+ reg_tmp_, interim_f32 ? f32 : prb_.itype, -+ prb_.otype); -+ for (int ur = 0; ur < reg_unroll; ur += load_step) { - if (need_saturation) -- saturate_f32(VReg4S(ur), xmm_zero, -- xmm_saturation_ubound, prb_.otype, p_all); -+ saturate_f32(VReg4S(ur), xmm_zero_, -+ xmm_saturation_ubound_, prb_.otype, -+ P_ALL_ONE); -+ } - - for (int ur = 0; ur < reg_unroll; ur += load_step) - cvt2odt(ur, 1, prb_.otype, - interim_f32 ? f32 : prb_.itype); - } -- /* load_step is 1 or 4. */ - for (int ur = 0; ur < reg_unroll; ur += load_step) { - for (int r = 0; r < load_step; ++r) { - add_imm(x_tmp_vec[r], x_ptr_out_off, -- o_off[ur + r] * otype_sz, X_DEFAULT_ADDR); -+ o_off[ur + r] * otype_sz_, X_DEFAULT_ADDR); - } - - for (int r = 0; r < load_step; ++r) { -- if (otype_sz == 4) -+ if (otype_sz_ == 4) - st1(VReg4S(ur)[r], ptr(x_tmp_vec[r])); -- else if (otype_sz == 2) -+ else if (otype_sz_ == 2) - st1(VReg8H(ur)[r], ptr(x_tmp_vec[r])); - else - st1(VReg16B(ur)[r], ptr(x_tmp_vec[r])); -@@ -735,7 +785,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { - } - - /* scatter elements of xmm into 4 xmms */ -- if (itype_sz == 4 || interim_f32) { -+ if (itype_sz_ == 4 || interim_f32) { - for (int ur = 0; ur < reg_unroll; ur += load_step) - for (int r = 1; r < load_step; ++r) { - VReg4S v(ur); -@@ -747,7 +797,18 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { - for (int ur = 0; ur < reg_unroll; ur += load_step) - for (int r = 1; r < load_step; ++r) - ext(VReg16B(ur + r), VReg16B(ur), VReg16B(ur), -- itype_sz * r); -+ itype_sz_ * r); -+ } -+ } -+ -+ /* src zero point application */ -+ if (prb_.req_src_zp) { -+ for (int ur = 0; ur < reg_unroll; ur += ur_step) { -+ const auto xmm = VReg4S(ur); -+ if (interim_f32) -+ fsub(xmm, xmm, xmm_src_zp_); -+ else -+ sub(xmm, xmm, xmm_src_zp_); - } - } - -@@ -756,7 +817,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { - /* xmm <-- scale * xmm[:] */ - if (prb_.scale_type == scale_type_t::COMMON) { - for (int ur = 0; ur < reg_unroll; ur += ur_step) -- fmul(VReg4S(ur), VReg4S(ur), xmm_scale); -+ fmul(VReg4S(ur), VReg4S(ur), xmm_scale_); - } else if (prb_.scale_type == scale_type_t::MANY) { - enum class scale_load_type_t { bcast, load, gather }; - -@@ -769,13 +830,12 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { - scale_load_type = scale_load_type_t::load; - - if (scale_load_type == scale_load_type_t::bcast -- && !h_padded) { -- VReg4S v(xmm_scale.getIdx()); -+ && !tail_processing) { -+ VReg4S v(xmm_scale_.getIdx()); - VReg4S v_dst(ur); -- add_imm(X_TMP_0, x_ptr_scale_off, s_off[ur] * stype_sz, -+ add_imm(X_TMP_0, x_ptr_scale_off, s_off[ur] * stype_sz_, - X_DEFAULT_ADDR); -- ldr(W_TMP_0, ptr(X_TMP_0)); -- dup(v, W_TMP_0); -+ ld1r(v, ptr(X_TMP_0)); - fmul(v_dst, v_dst, v); - continue; - } -@@ -786,10 +846,10 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { - scale_load_type = scale_load_type_t::gather; - - if (scale_load_type == scale_load_type_t::load -- && !h_padded) { -- uint32_t idx = xmm_scale.getIdx(); -+ && !tail_processing) { -+ uint32_t idx = xmm_scale_.getIdx(); - VReg4S v_dst(ur); -- add_imm(X_TMP_0, x_ptr_scale_off, s_off[ur] * stype_sz, -+ add_imm(X_TMP_0, x_ptr_scale_off, s_off[ur] * stype_sz_, - X_DEFAULT_ADDR); - - ldr(QReg {idx}, ptr(X_TMP_0)); -@@ -799,22 +859,15 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { - - // load doesn't work as well - // so gather the scale factors one by one -- /*ur_step is 1 or 4. */ -- for (int r = ur; r < ur + ur_step; ++r) { -- if (ip_padding[r] == 0 || !h_padded) { -- /* x_tmp_vec = X_TMP_0 - X_TMP_4 -- Do not use X_TMP_? as the last arg. */ -+ for (int r = ur; r < ur + ur_step; ++r) -+ if (zero_padding[r] == 0 || !tail_processing) { - add_imm(x_tmp_vec[r - ur], x_ptr_scale_off, -- s_off[r] * stype_sz, X_DEFAULT_ADDR); -- } -- } -- for (int r = ur; r < ur + ur_step; ++r) { -- if (ip_padding[r] == 0 || !h_padded) { -- VReg4S v(xmm_scale.getIdx()); -- ld1(v[r - ur], ptr(x_tmp_vec[r - ur])); -+ s_off[r] * stype_sz_, X_DEFAULT_ADDR); - } -- } -- fmul(VReg4S(ur), VReg4S(ur), xmm_scale); -+ for (int r = ur; r < ur + ur_step; ++r) -+ if (zero_padding[r] == 0 || !tail_processing) -+ ld1(xmm_scale_[r - ur], ptr(x_tmp_vec[r - ur])); -+ fmul(VReg4S(ur), VReg4S(ur), xmm_scale_); - } - } - -@@ -829,7 +882,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { - - do { - add_imm(x_tmp_vec[count++], x_ptr_out_off, -- o_off[ur] * otype_sz, X_DEFAULT_ADDR); -+ o_off[ur] * otype_sz_, X_DEFAULT_ADDR); - ur += ur_step; - } while (ur < reg_unroll && count < x_tmp_vec_size); - -@@ -873,7 +926,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { - if (prb_.scale_type == scale_type_t::COMMON) { - for (int ur = 0; ur < reg_unroll; ur += ur_step) { - VReg4S tmp(ur); -- fmul(tmp, tmp, VReg4S(xmm_scale.getIdx())); -+ fmul(tmp, tmp, VReg4S(xmm_scale_.getIdx())); - } - } else if (prb_.scale_type == scale_type_t::MANY) { - int ur = 0; -@@ -883,7 +936,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { - - do { - add_imm(x_tmp_vec[count++], x_ptr_scale_off, -- s_off[ur] * stype_sz, X_DEFAULT_ADDR); -+ s_off[ur] * stype_sz_, X_DEFAULT_ADDR); - ur += ur_step; - } while (ur < reg_unroll && count < x_tmp_vec_size); - -@@ -908,7 +961,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { - - do { - add_imm(x_tmp_vec[count++], x_ptr_out_off, -- o_off[ur] * otype_sz, X_DEFAULT_ADDR); -+ o_off[ur] * otype_sz_, X_DEFAULT_ADDR); - ur += ur_step; - } while (ur < reg_unroll && count < (x_tmp_vec_size / 2)); - -@@ -951,94 +1004,272 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { - } - } - -- if (need_saturation) { -- init_saturate_f32( -- xmm_zero, xmm_saturation_ubound, reg_tmp, f32, prb_.otype); -+ /* dst zero point application */ -+ if (prb_.req_dst_zp) { - for (int ur = 0; ur < reg_unroll; ur += ur_step) { -- saturate_f32(VReg4S(ur), xmm_zero, xmm_saturation_ubound, -- prb_.otype, p_all); -+ const auto xmm = VReg4S(ur); -+ if (interim_f32) -+ fadd(xmm, xmm, xmm_dst_zp_); -+ else -+ add(xmm, xmm, xmm_dst_zp_); - } - } - -- for (int ur = 0; ur < reg_unroll; ur += ur_step) { -- if (prb_.otype != f32) -- cvt2odt(ur, 1, prb_.otype, interim_f32 ? f32 : prb_.itype); -+ /* adjust scale application */ -+ if (prb_.scale_adjust != 1.f) { -+ dup(xmm_tmp_, reg_scale_adjust_); -+ for (int ur = 0; ur < reg_unroll; ur += ur_step) { -+ fmul(VReg4S(ur), VReg4S(ur), xmm_tmp_); -+ } -+ } -+ -+ if (need_saturation) { -+ init_saturate_f32(xmm_zero_, xmm_saturation_ubound_, reg_tmp_, f32, -+ prb_.otype); -+ for (int ur = 0; ur < reg_unroll; ur += ur_step) { -+ saturate_f32(VReg4S(ur), xmm_zero_, xmm_saturation_ubound_, -+ prb_.otype, P_ALL_ONE); -+ } -+ -+ // reset back xmm_zero_ if needed. -+ if (compensation_needed_ && (prb_.req_src_zp || prb_.req_dst_zp)) -+ uni_clear(VReg(xmm_zero_.getIdx())); - } - -- int ur = 0; -- int tmp_ur = 0; -- while (ur < reg_unroll) { -- int count = 0; -+ if (compensation_needed_) { -+ const uint32_t xmm_begin = 9; -+ const uint32_t xmm_end = 11; -+ uint32_t xmm_id = xmm_begin; -+ const auto get_temp_xmm = [&] { -+ const Xbyak_aarch64::VReg temp {xmm_id++}; -+ -+ if (xmm_id > xmm_end) { xmm_id = xmm_begin; } -+ -+ return temp; -+ }; -+ if (can_store_xmm) { -+ enum class comp_load_type_t { bcast, load, gather }; -+ -+ for (int ur = 0; ur < reg_unroll; ur += ur_step) { -+ -+ bool all_ip_padding_one = true; -+ bool all_ip_padding_zero = true; -+ for (int r = ur; r < ur + ur_step; r++) { -+ if (zero_padding[r] != 1) -+ all_ip_padding_one = false; -+ else -+ all_ip_padding_zero = false; -+ } -+ if (all_ip_padding_one) continue; -+ -+ comp_load_type_t comp_load_type = comp_load_type_t::bcast; -+ -+ for (int r = ur + 1; r < ur + ur_step; ++r) -+ if (c_off[r] != c_off[r - 1] + 0) { -+ comp_load_type = comp_load_type_t::load; -+ break; -+ } - -- do { -- add_imm(x_tmp_vec[count++], x_ptr_out_off, o_off[ur] * otype_sz, -- X_DEFAULT_ADDR); -- ur += ur_step; -- } while (ur < reg_unroll && count < x_tmp_vec_size); -+ if (comp_load_type == comp_load_type_t::bcast -+ && all_ip_padding_zero) { -+ const auto reduction_xmm = get_temp_xmm().s4; -+ const auto xmm_reorder_result = VReg4S(ur); -+ frinti(reduction_xmm, xmm_reorder_result); -+ addv(SReg(reduction_xmm.getIdx()), reduction_xmm); -+ const auto comp_addr = c_addr(c_off[ur]); -+ const auto xmm_tmp_ = get_temp_xmm().s4; -+ ldr(SReg(xmm_tmp_.getIdx()), ptr(comp_addr)); -+ add(xmm_tmp_, xmm_tmp_, reduction_xmm); -+ str(SReg(xmm_tmp_.getIdx()), ptr(comp_addr)); -+ continue; -+ } -+ -+ if (comp_load_type == comp_load_type_t::load) -+ for (int r = ur + 1; r < ur + ur_step; ++r) -+ if (c_off[r] != c_off[r - 1] + 1) { -+ comp_load_type = comp_load_type_t::gather; -+ break; -+ } -+ -+ if (comp_load_type == comp_load_type_t::load -+ && all_ip_padding_zero) { -+ const auto xmm_reorder_result_dq = get_temp_xmm().s4; -+ const auto xmm_reorder_result = VReg4S(ur); -+ const auto comp_addr = c_addr(c_off[ur]); -+ frinti(xmm_reorder_result_dq, xmm_reorder_result); -+ const auto xmm_tmp_ = get_temp_xmm().s4; -+ ldr(SReg(xmm_tmp_.getIdx()), ptr(comp_addr)); -+ add(xmm_reorder_result_dq, xmm_reorder_result_dq, -+ xmm_tmp_); -+ str(SReg(xmm_tmp_.getIdx()), ptr(comp_addr)); -+ continue; -+ } - -- for (int i = 0; i < count; i++) { -+ const auto xmm_reorder_result_dq = get_temp_xmm().s4; -+ const auto xmm_reorder_result = VReg4S(ur); -+ frinti(xmm_reorder_result_dq, xmm_reorder_result); - -- switch (ur_step * otype_sz) { -- case 16: str(QReg(tmp_ur), ptr(x_tmp_vec[i])); break; -- case 8: str(DReg(tmp_ur), ptr(x_tmp_vec[i])); break; -- case 4: str(SReg(tmp_ur), ptr(x_tmp_vec[i])); break; -- case 2: str(HReg(tmp_ur), ptr(x_tmp_vec[i])); break; -- case 1: str(BReg(tmp_ur), ptr(x_tmp_vec[i])); break; -- default: assert(!"unreachable"); -+ for (int r = ur; r < ur + ur_step; ++r) { -+ if (zero_padding[r] == 0 || !tail_processing) { -+ mov(W_TMP_0, xmm_reorder_result_dq[r]); -+ const auto comp_addr = c_addr(c_off[ur]); -+ str(W_TMP_0, ptr(comp_addr)); -+ } -+ } -+ } -+ } else { -+ for (int ur = 0; ur < reg_unroll; ur += ur_step) { -+ if (zero_padding[ur] == 0 || !tail_processing) { -+ const auto xmm_reorder_result_dq = get_temp_xmm().s4; -+ const auto xmm_reorder_result = VReg4S(ur); -+ const auto comp_addr = c_addr(c_off[ur]); -+ frinti(xmm_reorder_result_dq, xmm_reorder_result); -+ const auto xmm_tmp_ = get_temp_xmm().s4; -+ ldr(SReg(xmm_tmp_.getIdx()), ptr(comp_addr)); -+ add(xmm_reorder_result_dq, xmm_reorder_result_dq, -+ xmm_tmp_); -+ str(SReg(xmm_tmp_.getIdx()), ptr(comp_addr)); -+ } - } -- tmp_ur += ur_step; - } - } -+ -+ for (int ur = 0; ur < reg_unroll; ur += ur_step) { -+ if (prb_.req_src_zp || prb_.req_dst_zp) { -+ const bool use_store_masks = !store_masks.empty(); -+ if (use_store_masks) { -+ const auto mask = (~store_masks[ur / ur_step]) & 0xF; -+ switch (mask) { -+ case 0x0: -+ /* Do nothing */ -+ break; -+ case 0x1: ins(VReg4S(ur)[0], xmm_zero_[0]); break; -+ case 0x2: ins(VReg4S(ur)[1], xmm_zero_[1]); break; -+ case 0x3: -+ ins(VReg2D(ur)[0], VReg2D(xmm_zero_.getIdx())[0]); -+ break; -+ case 0x4: ins(VReg4S(ur)[2], xmm_zero_[2]); break; -+ case 0x5: -+ ins(VReg4S(ur)[0], xmm_zero_[0]); -+ ins(VReg4S(ur)[2], xmm_zero_[2]); -+ break; -+ case 0x6: -+ ins(VReg4S(ur)[1], xmm_zero_[1]); -+ ins(VReg4S(ur)[2], xmm_zero_[2]); -+ break; -+ case 0x7: -+ ins(VReg2D(ur)[0], VReg2D(xmm_zero_.getIdx())[0]); -+ ins(VReg4S(ur)[2], xmm_zero_[2]); -+ break; -+ case 0x8: ins(VReg4S(ur)[3], xmm_zero_[3]); break; -+ case 0x9: -+ ins(VReg4S(ur)[0], xmm_zero_[0]); -+ ins(VReg4S(ur)[3], xmm_zero_[3]); -+ break; -+ case 0xa: -+ ins(VReg4S(ur)[1], xmm_zero_[1]); -+ ins(VReg4S(ur)[3], xmm_zero_[3]); -+ break; -+ case 0xb: -+ ins(VReg2D(ur)[0], VReg2D(xmm_zero_.getIdx())[0]); -+ ins(VReg4S(ur)[3], xmm_zero_[3]); -+ break; -+ case 0xc: -+ ins(VReg2D(ur)[1], VReg2D(xmm_zero_.getIdx())[1]); -+ break; -+ case 0xd: -+ ins(VReg4S(ur)[0], xmm_zero_[0]); -+ ins(VReg2D(ur)[1], VReg2D(xmm_zero_.getIdx())[1]); -+ break; -+ case 0xe: -+ ins(VReg4S(ur)[1], xmm_zero_[1]); -+ ins(VReg2D(ur)[1], VReg2D(xmm_zero_.getIdx())[1]); -+ break; -+ case 0xf: movi(VReg16B(ur), 0); break; -+ default: assert(!"unreachable"); -+ } -+ } -+ } -+ if (prb_.otype != f32) -+ cvt2odt(ur, 1, prb_.otype, interim_f32 ? f32 : prb_.itype); -+ -+ store(o_addr(o_off[ur]), VReg(ur), ur_step * otype_sz_); -+ } - } - -- void comp_padding_flag(int ndims, int off, int len, int &i_tail) { -- const int ip_without_padding -- = ndims == 0 ? len - ip_padding() : prb_.ip_tail; -- if ((ndims == 0 && off >= ip_without_padding) -- || (ndims > 0 && (off % prb_.oblock) >= ip_without_padding)) -- i_tail = 1; -+ bool interim_f32_needed() { -+ using namespace data_type; -+ -+ return utils::one_of(f32, prb_.itype, prb_.otype) -+ || prb_.scale_type != scale_type_t::NONE || prb_.beta != 0.f -+ || ((prb_.req_src_zp || prb_.req_dst_zp) -+ ? !(prb_.itype == s32 && prb_.otype == s32) -+ : false) -+ || (prb_.itype != f32 && compensation_needed_) -+ || prb_.scale_adjust != 1.f; - } - -- void process_unroll_generic(const int ndims, int len, const bool h_padded) { -+ void process_unroll_generic( -+ const int ndims, int len, const bool tail_processing) { -+ assert(IMPLICATION(prb_.nodes[0].tail_size > 0, -+ len == static_cast(prb_.nodes[0].n) -+ || len == static_cast(prb_.nodes[0].tail_size))); -+ - const int blk = 8; - - int i_off[2 * blk] = {0}; - int o_off[2 * blk] = {0}; - int s_off[2 * blk] = {0}; -+ int c_off[2 * blk] = {0}; - - int curr = 0; // will switch between 0 and 1 - -+ const bool interim_f32 = interim_f32_needed(); -+ -+ if (prb_.req_src_zp) { -+ add_imm(X_DEFAULT_ADDR, PARAM(src_zp), X_TMP_0); -+ ld1r(xmm_src_zp_, ptr(X_DEFAULT_ADDR)); -+ if (interim_f32) scvtf(xmm_src_zp_, xmm_src_zp_); -+ } -+ if (prb_.req_dst_zp) { -+ add_imm(X_DEFAULT_ADDR, PARAM(dst_zp), X_TMP_0); -+ ld1r(xmm_dst_zp_, ptr(X_DEFAULT_ADDR)); -+ if (interim_f32) scvtf(xmm_dst_zp_, xmm_dst_zp_); -+ } -+ - for (int off = 0; off < len; off += blk) { - const int reg_unroll = nstl::min(off + blk, len) - off; -- int ip_padding[blk] = {0}; -+ int zero_padding[blk] = {0}; -+ const auto curr_blk = curr * blk; - - /* compute offsets and tail*/ - for (int ur = off != 0 ? 0 : 1; ur < reg_unroll; ++ur) { -- const int ur_c = curr * blk + ur; -+ const int ur_c = curr_blk + ur; - const int ur_p = (ur_c - 1 + 2 * blk) % (2 * blk); // prev ur -+ const bool is_tail -+ = off + ur >= static_cast(prb_.nodes[0].tail_size); - step(off + ur, i_off[ur_p], o_off[ur_p], s_off[ur_p], -- i_off[ur_c], o_off[ur_c], s_off[ur_c]); -- if (h_padded) -- comp_padding_flag(ndims, off + ur, len, ip_padding[ur]); -+ c_off[ur_p], i_off[ur_c], o_off[ur_c], s_off[ur_c], -+ c_off[ur_c]); -+ if (tail_processing && is_tail) zero_padding[ur] = 1; - } -- process_unroll_generic_step(reg_unroll, i_off + curr * blk, -- o_off + curr * blk, s_off + curr * blk, ip_padding, -- h_padded); -+ -+ process_unroll_generic_step(reg_unroll, i_off + curr_blk, -+ o_off + curr_blk, s_off + curr_blk, c_off + curr_blk, -+ zero_padding, tail_processing); - - curr = 1 - curr; - } - } - - void compute_ker( -- const int ndims, const int len_unroll, const bool h_padded) { -+ const int ndims, const int len_unroll, const bool tail_processing) { - bool optimized = false; -- optimized = optimized -- || (process_direct_copy(len_unroll) && !h_padded); -- optimized = optimized -- || (process_direct_copy(len_unroll) && !h_padded); -- optimized -- = optimized || (process_unroll_tr8x8(len_unroll) && !h_padded); -- if (!optimized) process_unroll_generic(ndims, len_unroll, h_padded); -+ optimized = optimized || process_direct_copy(ndims, len_unroll) -+ || process_direct_copy(ndims, len_unroll) -+ || process_unroll_tr8x8(ndims, len_unroll); -+ if (!optimized) -+ process_unroll_generic(ndims, len_unroll, tail_processing); - } - - void loop_begin(Label &l, XReg reg_cnt, int len) { -@@ -1046,97 +1277,287 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { - L(l); - } - -+ void check_if_this_is_last_chunk(const XReg reg_curr_chunk, int node_id) { -+ // Chunks are backwards numered i.e: -+ // [0] -> [node_size] -+ // [1] -> [node_size - 1] -+ // ... -+ // [node_size - 1] -> [1] -+ -+ // It is done like this, because it is easier to decrement counter -+ // and check if it is equal to zero than increment and check -+ // if it is equal to node_size. -+ static constexpr int64_t last_chunk = 1; -+ cmp(reg_curr_chunk, last_chunk); -+ } -+ -+ void zero_dst_memory(const int bytes_to_zeroing) { -+ static constexpr int num_of_bytes_in_xmm = 128 / 8; -+ -+ const int xmms_to_zeroing -+ = std::div(bytes_to_zeroing, num_of_bytes_in_xmm).quot; -+ const int tail_to_zeroing -+ = std::div(bytes_to_zeroing, num_of_bytes_in_xmm).rem; -+ -+ movi(xmm_tmp_, 0); -+ -+ if (xmms_to_zeroing > 0) { -+ Label loop; -+ -+ mov(reg_tmp_, xmms_to_zeroing); -+ L(loop); -+ str(QReg(xmm_tmp_.getIdx()), ptr(o_addr(0))); -+ add_imm(reg_off_out_, reg_off_out_, num_of_bytes_in_xmm, X_TMP_0); -+ add_imm(x_ptr_out_off, x_ptr_out_off, num_of_bytes_in_xmm, X_TMP_0); -+ subs(reg_tmp_, reg_tmp_, 1); -+ mov(X_TMP_0, 32); -+ b(NE, loop); -+ } -+ -+ if (tail_to_zeroing) mov_imm(W_TMP_0, 0); -+ -+ for (int i = 0; i < tail_to_zeroing; i++) -+ strb(W_TMP_0, ptr(o_addr(i, false))); -+ -+ // Restore dst offset to initial value -+ if (xmms_to_zeroing > 0) { -+ sub_imm(reg_off_out_, reg_off_out_, -+ num_of_bytes_in_xmm * xmms_to_zeroing, X_TMP_0); -+ sub_imm(x_ptr_out_off, x_ptr_out_off, -+ num_of_bytes_in_xmm * xmms_to_zeroing, X_TMP_0); -+ } -+ } -+ -+ void finalize_tail_loop(int i_step, int o_step, int s_step, int c_step, -+ const int curr_node_id) { -+ static constexpr int empty_chunk_info = -1; -+ -+ mov(reg_tmp_, empty_chunk_info); -+ str(reg_tmp_, ptr(data_chunk_addr(curr_node_id))); -+ -+ const int padded_area = prb_.nodes[curr_node_id].n -+ - prb_.nodes[curr_node_id].tail_size; -+ -+ if (prb_.nodes[curr_node_id].is_zero_pad_needed) { -+ int num_of_zero_padded_values = padded_area; -+ for (int i = curr_node_id - 1; i >= 0; i--) { -+ num_of_zero_padded_values *= prb_.nodes[i].n; -+ } -+ -+ const int bytes_to_zeroing = num_of_zero_padded_values * otype_sz_; -+ zero_dst_memory(bytes_to_zeroing); -+ } -+ -+ // This function is called by loop_end. At the end -+ // of loop_end is section that is responsible for -+ // restoring offset values. Restoring is based on -+ // len value which is equal to prb.nodes[x].n. -+ // If fill_zero_padded_area is called then it means -+ // offsets were shifted prb.nodes[x].tail_size times. -+ // Therefore, this function has to shift offsets by -+ // zero pad area. -+ add_imm(reg_off_in_, reg_off_in_, padded_area * i_step * itype_sz_, -+ X_TMP_0); -+ add_imm(reg_off_out_, reg_off_out_, padded_area * o_step * otype_sz_, -+ X_TMP_0); -+ add_imm(x_ptr_in_off, x_ptr_in_off, padded_area * i_step * itype_sz_, -+ X_TMP_0); -+ add_imm(x_ptr_out_off, x_ptr_out_off, padded_area * o_step * otype_sz_, -+ X_TMP_0); -+ if (prb_.scale_type == scale_type_t::MANY) { -+ add_imm(reg_off_scale_, reg_off_scale_, -+ padded_area * s_step * stype_sz_, X_TMP_0); -+ add_imm(x_ptr_scale_off, x_ptr_scale_off, -+ padded_area * s_step * stype_sz_, X_TMP_0); -+ } -+ if (compensation_needed_) { -+ add_imm(reg_off_comp_, reg_off_comp_, -+ padded_area * c_step * sizeof(int32_t), X_TMP_0); -+ add_imm(x_ptr_comp_off, x_ptr_comp_off, -+ padded_area * c_step * sizeof(int32_t), X_TMP_0); -+ } -+ } -+ - void loop_end(Label &l, XReg reg_cnt, int len, int i_step, int o_step, -- int s_step) { -- add_imm(reg_off_in, reg_off_in, i_step * itype_sz, X_TMP_0); -- add_imm(reg_off_out, reg_off_out, o_step * otype_sz, X_TMP_0); -- add_imm(x_ptr_in_off, x_ptr_in_off, i_step * itype_sz, X_TMP_0); -- add_imm(x_ptr_out_off, x_ptr_out_off, o_step * otype_sz, X_TMP_0); -+ int s_step, int c_step, const int curr_node_id) { -+ add_imm(reg_off_in_, reg_off_in_, i_step * itype_sz_, X_TMP_0); -+ add_imm(reg_off_out_, reg_off_out_, o_step * otype_sz_, X_TMP_0); -+ add_imm(x_ptr_in_off, x_ptr_in_off, i_step * itype_sz_, X_TMP_0); -+ add_imm(x_ptr_out_off, x_ptr_out_off, o_step * otype_sz_, X_TMP_0); - - if (prb_.scale_type == scale_type_t::MANY) { -- add_imm(reg_off_scale, reg_off_scale, s_step * stype_sz, X_TMP_0); -- add_imm(x_ptr_scale_off, x_ptr_scale_off, s_step * stype_sz, -+ add_imm(reg_off_scale_, reg_off_scale_, s_step * stype_sz_, -+ X_TMP_0); -+ add_imm(x_ptr_scale_off, x_ptr_scale_off, s_step * stype_sz_, - X_TMP_0); - } -+ -+ if (compensation_needed_) { -+ add_imm(reg_off_comp_, reg_off_comp_, c_step * sizeof(int32_t), -+ X_TMP_0); -+ add_imm(x_ptr_comp_off, x_ptr_comp_off, c_step * sizeof(int32_t), -+ X_TMP_0); -+ } -+ - subs(reg_cnt, reg_cnt, 1); - b(NE, l); - -- sub_imm(reg_off_in, reg_off_in, len * i_step * itype_sz, X_TMP_0); -- sub_imm(reg_off_out, reg_off_out, len * o_step * otype_sz, X_TMP_0); -- sub_imm(x_ptr_in_off, x_ptr_in_off, len * i_step * itype_sz, X_TMP_0); -- sub_imm(x_ptr_out_off, x_ptr_out_off, len * o_step * otype_sz, X_TMP_0); -+ if (prb_.tail(curr_node_id) != 0) { -+ Label if_end; -+ -+ // On the stack should be an information if node -+ // was processed with tail or not. -+ ldr(reg_tmp_, post_ptr(X_SP, reg_tmp_.getBit() / 8)); -+ -+ cmp(reg_tmp_, with_tail_info_); -+ b(NE, if_end); -+ finalize_tail_loop(i_step, o_step, s_step, c_step, curr_node_id); -+ L(if_end); -+ } -+ -+ // Restore offset to initial values. It means before -+ // loop execution. -+ sub_imm(reg_off_in_, reg_off_in_, len * i_step * itype_sz_, X_TMP_0); -+ sub_imm(reg_off_out_, reg_off_out_, len * o_step * otype_sz_, X_TMP_0); -+ sub_imm(x_ptr_in_off, x_ptr_in_off, len * i_step * itype_sz_, X_TMP_0); -+ sub_imm(x_ptr_out_off, x_ptr_out_off, len * o_step * otype_sz_, -+ X_TMP_0); - - if (prb_.scale_type == scale_type_t::MANY) { -- sub_imm(reg_off_scale, reg_off_scale, len * s_step * stype_sz, -+ sub_imm(reg_off_scale_, reg_off_scale_, len * s_step * stype_sz_, - X_TMP_0); -- sub_imm(x_ptr_scale_off, x_ptr_scale_off, len * s_step * stype_sz, -+ sub_imm(x_ptr_scale_off, x_ptr_scale_off, len * s_step * stype_sz_, - X_TMP_0); - } -+ if (compensation_needed_) { -+ sub_imm(reg_off_comp_, reg_off_comp_, -+ len * c_step * sizeof(int32_t), X_TMP_0); -+ sub_imm(x_ptr_comp_off, x_ptr_comp_off, -+ len * c_step * sizeof(int32_t), X_TMP_0); -+ } - } - -- void compute_blk_ker(const int len_unroll) { -+ void compute_blk_ker(const simple_impl_desc_t &desc) { -+ static constexpr bool with_tail_processing = true; -+ Label no_last_chunk, end_label; - int omp_ndims = prb_.full_ndims - prb_.ndims; -- Label no_last_blk, end_label; - -- if (prb_.ip_tail > 0 && prb_.op_tail == 0) { -- if (omp_ndims == 0) { -- cmp(reg_last_loop_cnt, 1); -- bne(no_last_blk); -- compute_ker(omp_ndims, len_unroll, true); -- } else { -- cmp(reg_blk_chunks, blk_cnt()); -- bne(no_last_blk); -- compute_ker(omp_ndims, len_unroll, true); -+ if (prb_.nodes[0].tail_size > 0) { -+ if (!prb_.nodes[0].is_parent_empty()) { -+ const int parent_node_id = prb_.nodes[0].parent_node_id; -+ ldr(reg_tmp_, ptr(data_chunk_addr(parent_node_id))); -+ check_if_this_is_last_chunk(reg_tmp_, parent_node_id); -+ b(NE, no_last_chunk); - } -+ -+ const int len_unroll = desc.tail_len_unroll > 0 -+ ? desc.tail_len_unroll -+ : desc.len_unroll; -+ compute_ker(omp_ndims, len_unroll, with_tail_processing); - b(end_label); - } - -- L(no_last_blk); -- compute_ker(omp_ndims, len_unroll, false); -+ L(no_last_chunk); -+ compute_ker(omp_ndims, desc.len_unroll, !with_tail_processing); - L(end_label); - } - -+ void create_loops(const simple_impl_desc_t &desc, -+ const std::array ®_cnt, int jit_loop) { -+ assert(jit_loop <= ndims_jit_loop_max); -+ -+ if (jit_loop > 0) { -+ const int nfu = desc.ndims_full_unroll; -+ const int unroll_factor -+ = jit_loop == 1 ? desc.len_last_dim_unroll : 1; -+ const int curr_node_id = nfu + (jit_loop - 1); -+ const int parent_node_id = prb_.nodes[curr_node_id].parent_node_id; -+ const int tail_size = prb_.tail(curr_node_id) / unroll_factor; -+ const int node_size = prb_.n(curr_node_id) / unroll_factor; -+ const XReg reg_loop_cnt = reg_cnt[jit_loop - 1]; -+ const bool curr_node_has_tail = prb_.tail(curr_node_id) != 0; -+ Label loop, if_no_tail, if_end; -+ -+ if (curr_node_has_tail) { -+ const size_t reg_bytes = reg_tmp_.getBit() / 8; -+ if (prb_.nodes[curr_node_id].is_parent_empty()) { -+ mov(reg_loop_cnt, tail_size); -+ // Put info that node is being processed with tail. -+ mov(reg_tmp_, with_tail_info_); -+ str(reg_tmp_, pre_ptr(X_SP, -reg_bytes)); -+ } else { -+ ldr(reg_tmp_, ptr(data_chunk_addr(parent_node_id))); -+ check_if_this_is_last_chunk(reg_tmp_, parent_node_id); -+ b(NE, if_no_tail); -+ mov(reg_loop_cnt, tail_size); -+ // Put info that node is being processed with tail. -+ mov(reg_tmp_, with_tail_info_); -+ str(reg_tmp_, pre_ptr(X_SP, -reg_bytes)); -+ b(if_end); -+ -+ L(if_no_tail); -+ mov(reg_loop_cnt, node_size); -+ // Put info that node is being processed without tail. -+ mov(reg_tmp_, without_tail_info_); -+ str(reg_tmp_, pre_ptr(X_SP, -reg_bytes)); -+ L(if_end); -+ } -+ } -+ -+ if (prb_.is_tail_in_one_of_child_nodes(curr_node_id)) { -+ if (!curr_node_has_tail) { -+ mov(reg_loop_cnt, node_size); -+ str(reg_loop_cnt, ptr(data_chunk_addr(curr_node_id))); -+ } -+ L(loop); -+ if (!prb_.nodes[curr_node_id].is_parent_empty()) { -+ Label if_no_tail_in_child_node; -+ ldr(reg_tmp_, ptr(data_chunk_addr(parent_node_id))); -+ check_if_this_is_last_chunk(reg_tmp_, parent_node_id); -+ b(NE, if_no_tail_in_child_node); -+ str(reg_loop_cnt, ptr(data_chunk_addr(curr_node_id))); -+ L(if_no_tail_in_child_node); -+ } else { -+ str(reg_loop_cnt, ptr(data_chunk_addr(curr_node_id))); -+ } -+ } else if (curr_node_has_tail) { -+ L(loop); -+ } else { -+ loop_begin(loop, reg_loop_cnt, node_size); -+ } -+ create_loops(desc, reg_cnt, jit_loop - 1); -+ -+ loop_end(loop, reg_loop_cnt, node_size, -+ prb_.is(curr_node_id) * unroll_factor, -+ prb_.os(curr_node_id) * unroll_factor, -+ prb_.ss(curr_node_id) * unroll_factor, -+ prb_.cs(curr_node_id) * unroll_factor, curr_node_id); -+ } else { -+ compute_blk_ker(desc); -+ } -+ } -+ - bool simple_impl() { - simple_impl_desc_t d; - if (!simple_impl_desc_init(prb_, &d)) return false; - -- const int nfu = d.ndims_full_unroll; -- const int ldu = d.len_last_dim_unroll; -- const int n_jit_loops = prb_.ndims - d.ndims_full_unroll; -- assert(n_jit_loops <= ndims_jit_loop_max); -- -- eor(reg_off_in, reg_off_in, reg_off_in); -- eor(reg_off_out, reg_off_out, reg_off_out); -- mov(x_ptr_in_off, XReg(reg_ptr_in.getIdx())); -- mov(x_ptr_out_off, XReg(reg_ptr_out.getIdx())); -+ eor(reg_off_in_, reg_off_in_, reg_off_in_); -+ eor(reg_off_out_, reg_off_out_, reg_off_out_); -+ mov(x_ptr_in_off, reg_ptr_in_); -+ mov(x_ptr_out_off, reg_ptr_out_); - if (prb_.scale_type == scale_type_t::MANY) { -- eor(reg_off_scale, reg_off_scale, reg_off_scale); -- mov(x_ptr_scale_off, XReg(reg_ptr_scale.getIdx())); -+ mov(reg_off_scale_, 0); -+ mov(x_ptr_scale_off, reg_ptr_scale_); -+ } -+ if (compensation_needed_) { -+ eor(reg_off_comp_, reg_off_comp_, reg_off_comp_); -+ mov(x_ptr_comp_off, reg_off_comp_); - } - -- Label l_loop[3]; -- XReg reg_cnt[3] = {x15, x14, x13}; -- -- if (n_jit_loops > 2) loop_begin(l_loop[2], reg_cnt[2], n(nfu + 2)); -- -- if (n_jit_loops > 1) loop_begin(l_loop[1], reg_cnt[1], n(nfu + 1)); -- -- if (n_jit_loops > 0) -- loop_begin(l_loop[0], reg_cnt[0], n(nfu + 0) / ldu); -- -- compute_blk_ker(d.len_unroll); -- -- if (n_jit_loops > 0) -- loop_end(l_loop[0], reg_cnt[0], n(nfu + 0) / ldu, is(nfu + 0) * ldu, -- os(nfu + 0) * ldu, ss(nfu + 0) * ldu); -- -- if (n_jit_loops > 1) -- loop_end(l_loop[1], reg_cnt[1], n(nfu + 1), is(nfu + 1), -- os(nfu + 1), ss(nfu + 1)); -+ std::array reg_cnt({{x15, x14, x13}}); - -- if (n_jit_loops > 2) -- loop_end(l_loop[2], reg_cnt[2], n(nfu + 2), is(nfu + 2), -- os(nfu + 2), ss(nfu + 2)); -+ const int n_jit_loops = prb_.ndims - d.ndims_full_unroll; -+ create_loops(d, reg_cnt, n_jit_loops); - - return true; - } -@@ -1156,7 +1577,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { - inst(__VA_ARGS__); - - void cvt_z_s32_f32(const size_t startIdx, const size_t regNum) { -- UNROLL_INST(scvtf, ZRegS, tmp, p_all / T_m, tmp); -+ UNROLL_INST(scvtf, ZRegS, tmp, P_ALL_ONE / T_m, tmp); - } - - void cvt_v_s32_f32(const size_t startIdx, const size_t regNum) { -@@ -1164,8 +1585,8 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { - } - - void cvt_z_f32_s32(const size_t startIdx, const size_t regNum) { -- UNROLL_INST(frinti, ZRegS, tmp, p_all / T_m, tmp); -- UNROLL_INST(fcvtzs, ZRegS, tmp, p_all / T_m, tmp); -+ UNROLL_INST(frinti, ZRegS, tmp, P_ALL_ONE / T_m, tmp); -+ UNROLL_INST(fcvtzs, ZRegS, tmp, P_ALL_ONE / T_m, tmp); - } - - void cvt_v_f32_s32(const size_t startIdx, const size_t regNum) { -@@ -1175,7 +1596,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { - - void cvt_z_s8_s32(const size_t startIdx, const size_t regNum) { - cvt_z_b_s(startIdx, regNum); -- UNROLL_INST(sxtb, ZRegS, tmp, p_all / T_m, tmp); -+ UNROLL_INST(sxtb, ZRegS, tmp, P_ALL_ONE / T_m, tmp); - } - - void cvt_v_s8_s32(const size_t startIdx, const size_t regNum) { -@@ -1214,7 +1635,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { - - void cvt_z_u8_s32(const size_t startIdx, const size_t regNum) { - cvt_z_b_s(startIdx, regNum); -- UNROLL_INST(uxtb, ZRegS, tmp, p_all / T_m, tmp); -+ UNROLL_INST(uxtb, ZRegS, tmp, P_ALL_ONE / T_m, tmp); - } - - void cvt_v_u8_s32(const size_t startIdx, const size_t regNum) { -@@ -1285,7 +1706,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { - - dupm(z_tmp7.s, 255); - UNROLL_INST2(smax, ZRegS(i), 0); -- UNROLL_INST2(smin, ZRegS(i), p_all / T_m, z_tmp7.s); -+ UNROLL_INST2(smin, ZRegS(i), P_ALL_ONE / T_m, z_tmp7.s); - UNROLL_INST(uzp1, ZRegH, tmp, tmp, tmp); - UNROLL_INST(uzp1, ZRegB, tmp, tmp, tmp); - UNROLL_INST2(mov, ZRegB(i), P_NOT_128 / T_m, 0); -@@ -1320,107 +1741,514 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { - #undef UNROLL_INST - #undef UNROLL_INST - -- jit_uni_reorder_kernel_f32_t(const desc_t &desc) : kernel_t(desc) { -- itype_sz = data_type_size(prb_.itype); -- otype_sz = data_type_size(prb_.otype); -- stype_sz = sizeof(float); -+ jit_uni_reorder_kernel_f32_t(const desc_t &desc) -+ : kernel_t(desc), isa_(get_max_cpu_isa()) { -+ assert(!utils::one_of(isa_, isa_undef, isa_all)); -+ itype_sz_ = data_type_size(prb_.itype); -+ otype_sz_ = data_type_size(prb_.otype); -+ stype_sz_ = sizeof(float); - } - - void generate() override { - using namespace Xbyak_aarch64::util; - uint64_t sveLen = get_sve_length(); -+ Label end_of_kernel; - - preamble(); --#define PARAM(x) offsetof(call_param_t, x) -+ - if (prb_.scale_type == scale_type_t::COMMON) { -- add_imm(X_DEFAULT_ADDR, abi_param1, PARAM(scale), X_TMP_1); -+ add_imm(X_DEFAULT_ADDR, PARAM(scale), X_TMP_1); - ldr(X_TMP_0, ptr(X_DEFAULT_ADDR)); -- ldr(W_TMP_1, ptr(X_TMP_0)); -- dup(xmm_scale, W_TMP_1); -+ ld1r(xmm_scale_, ptr(X_TMP_0)); - } else if (prb_.scale_type == scale_type_t::MANY) { -- add_imm(X_DEFAULT_ADDR, abi_param1, PARAM(scale), X_TMP_0); -- ldr(reg_ptr_scale, ptr(X_DEFAULT_ADDR)); -+ add_imm(X_DEFAULT_ADDR, PARAM(scale), X_TMP_0); -+ ldr(reg_ptr_scale_, ptr(X_DEFAULT_ADDR)); - } -- add_imm(X_TMP_0, abi_param1, PARAM(in), X_TMP_2); -- add_imm(X_TMP_1, abi_param1, PARAM(out), X_TMP_2); -- add_imm(reg_blk, abi_param1, PARAM(blk_chunks), reg_blk); -- ldr(reg_ptr_in, ptr(X_TMP_0)); -- ldr(reg_ptr_out, ptr(X_TMP_1)); -- ldr(reg_blk_chunks, ptr(reg_blk)); -- --#undef PARAM -- mov_imm(reg_last_loop_cnt, 1); -+ if (compensation_needed_) { -+ add_imm(X_DEFAULT_ADDR, PARAM(compensation_scratch), X_TMP_0); -+ ldr(reg_ptr_comp_, ptr(X_DEFAULT_ADDR)); -+ } -+ if (prb_.scale_adjust == 0.5f) { mov(reg_scale_adjust_, 0x3f000000); } -+ add_imm(X_TMP_0, PARAM(in), X_TMP_2); -+ add_imm(X_TMP_1, PARAM(out), X_TMP_2); -+ ldr(reg_ptr_in_, ptr(X_TMP_0)); -+ ldr(reg_ptr_out_, ptr(X_TMP_1)); - -- mov(x_ptr_in_off, XReg(reg_ptr_in.getIdx())); -- mov(x_ptr_out_off, XReg(reg_ptr_out.getIdx())); -- mov(x_ptr_scale_off, XReg(reg_ptr_scale.getIdx())); -+ mov(x_ptr_in_off, reg_ptr_in_); -+ mov(x_ptr_out_off, reg_ptr_out_); -+ mov(x_ptr_scale_off, reg_ptr_scale_); -+ mov(x_ptr_comp_off, reg_ptr_comp_); - - if (sveLen) { /* SVE is available. */ - ptrue(p_lsb_256.b, VL32); -- ptrue(p_all.b); -+ ptrue(p_lsb_128.b, VL16); -+ ptrue(p_lsb_64.b, VL8); - } - -- if (can_do_tr8x8()) { -- dup(ymm_zero, 0); -- -- if (prb_.itype == data_type::u8 && prb_.otype == data_type::s8) { -- mov_imm(reg_tmp, 0x7f7f7f7f7f7f7f7f); -- mov(VReg4S(ymm_8x127b.getIdx())[0], WReg(reg_tmp.getIdx())); -+ bool is_tail_in_drv_dims = false; -+ for (int i = prb_.ndims; i < prb_.full_ndims; i++) -+ if (prb_.nodes[i].tail_size > 0) { -+ is_tail_in_drv_dims = true; -+ break; - } -- } else if (mayiuse(sve_512)) { -- movi(xmm_zero, 0); - -- if (prb_.itype == data_type::u8 && prb_.otype == data_type::s8) { -- mov(WReg(reg_tmp.getIdx()), 0x7f7f7f7f); -- mov(xmm_4x127b[0], WReg(reg_tmp.getIdx())); -+ if (is_tail_in_drv_dims) { -+ Label reorder_kernel; -+ add_imm(X_DEFAULT_ADDR, TAIL_PARAM(skip_kernel_execution), X_TMP_0); -+ ldr(reg_tmp_, ptr(X_DEFAULT_ADDR)); -+ cmp(reg_tmp_, static_cast(true)); -+ b(EQ, end_of_kernel); -+ -+ add_imm(X_DEFAULT_ADDR, TAIL_PARAM(zeroing_data), X_TMP_0); -+ ldr(reg_tmp_, ptr(X_DEFAULT_ADDR)); -+ cmp(reg_tmp_, static_cast(false)); -+ b(EQ, reorder_kernel); -+ // If zeroing data is set then all dst memory -+ // will be zeroed and nothing more will be done. -+ int bytes_to_zeroing = otype_sz_; -+ for (int i = 0; i < prb_.ndims; i++) { -+ bytes_to_zeroing *= prb_.nodes[i].n; - } -+ eor(reg_off_out_, reg_off_out_, reg_off_out_); -+ mov(x_ptr_out_off, reg_ptr_out_); -+ zero_dst_memory(bytes_to_zeroing); -+ b(end_of_kernel); -+ L(reorder_kernel); -+ } -+ -+ if (can_do_tr8x8()) { -+ dup(ymm_zero_, 0); -+ } else { -+ movi(xmm_zero_, 0); - } - - impl(); -+ -+ L(end_of_kernel); - postamble(); - } - -+ ~jit_uni_reorder_kernel_f32_t() override = default; -+ -+#undef TAIL_PARAM -+#undef PARAM -+ - private: -- int itype_sz; -- int otype_sz; -- int stype_sz; -+ static constexpr int64_t with_tail_info_ = static_cast(true); -+ static constexpr int64_t without_tail_info_ = static_cast(false); -+ -+ int itype_sz_; -+ int otype_sz_; -+ int stype_sz_; - -- XReg reg_ptr_in = x6; -- XReg reg_ptr_out = x2; -- XReg reg_ptr_scale = abi_not_param1; -+ const cpu_isa_t isa_; - -- XReg reg_off_in = x8; -- XReg reg_off_out = x9; -- XReg reg_off_scale = x10; -+ const XReg reg_ptr_in_ = x6; -+ const XReg reg_ptr_out_ = x2; -+ const XReg reg_ptr_scale_ = abi_not_param1; -+ const XReg reg_ptr_comp_ = x3; -+ const WReg ®_scale_adjust_ = w5; - -- XReg reg_blk = x11; -- XReg reg_blk_chunks = x12; -- XReg reg_last_loop_cnt = x11; -+ const XReg reg_off_in_ = x8; -+ const XReg reg_off_out_ = x9; -+ const XReg reg_off_scale_ = x10; -+ const XReg reg_off_comp_ = x11; - -- XReg reg_tmp = x0; -+ XReg reg_tmp_ = x12; - -- VReg4S xmm_scale = v15.s; -- VReg4S xmm_zero = v14.s; -- VReg4S xmm_4x127b = v13.s; // TODO: unite with ymm_zero -- ZRegS ymm_zero = z14.s; -- ZRegS ymm_8x127b = z13.s; -- VReg4S xmm_tmp = v12.s; -- VReg4S xmm_saturation_ubound = v12.s; -- ZRegS ymm_saturation_ubound = z12.s; -+ VReg4S xmm_scale_ = v15.s; -+ VReg4S xmm_zero_ = v14.s; -+ ZRegS ymm_zero_ = z14.s; -+ VReg4S xmm_tmp_ = v12.s; -+ const VReg4S xmm_src_zp_ = v9.s; -+ const VReg4S xmm_dst_zp_ = v11.s; -+ VReg4S xmm_saturation_ubound_ = v12.s; -+ ZRegS ymm_saturation_ubound_ = z12.s; - - /* Note: x22 - x28 are already used as temporal registgers - in jit_generator.hpp. -- x_ptr_(in|out|scale)_off keeps (base + offset) address. */ -+ x_ptr_(in|out|scale|comp)_off keeps (base + offset) address. */ - XReg x_ptr_in_off = x16; - XReg x_ptr_out_off = x18; - XReg x_ptr_scale_off = x20; -+ XReg x_ptr_comp_off = x17; - - /* Caution: Chose predicate registers not used by x64's implementation. */ - PReg p_lsb_256 = p7; -- PReg p_all = p6; -+ PReg p_lsb_128 = p6; -+ PReg p_lsb_64 = p4; - PReg p_tmp0 = p5; - - const std::vector tmp_vec_idx = {20, 21, 22, 23, 24, 25, 26, 27}; -+ VReg v_tmp0 = v20; -+ ZReg z_tmp0 = z20; -+ ZReg z_tmp1 = z21; -+ ZReg z_tmp2 = z22; -+ ZReg z_tmp3 = z23; -+ ZReg z_tmp4 = z24; -+ ZReg z_tmp5 = z25; -+ ZReg z_tmp6 = z26; -+ ZReg z_tmp7 = z27; -+ VReg v_tmp7 = v27; -+ -+ const std::vector z_tmp_vec -+ = {z_tmp0, z_tmp1, z_tmp2, z_tmp3, z_tmp4, z_tmp5, z_tmp6, z_tmp7}; -+ constexpr static int z_tmp_vec_size = 8; -+}; -+ -+// Seperate class for no unroll/threading burden -+struct jit_single_blk_kernel_t : public jit_generator { -+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_single_blk_kernel) -+ static bool applicable(const prb_t &p) { -+ using namespace data_type; -+ -+ bool ok = p.ndims >= 2 && mayiuse(sve_256) -+ && p.scale_type == scale_type_t::NONE -+ && utils::one_of(p.itype, f32) && utils::one_of(p.otype, f32) -+ && utils::everyone_is(0, p.ioff, p.ooff) && p.beta == 0.f -+ && prb_has_small_strides(p); -+ if (!ok) return false; -+ -+ int64_t n0 = p.nodes[0].n; -+ auto i0 = p.nodes[0].is; -+ auto o0 = p.nodes[0].os; -+ int64_t n1 = p.nodes[1].n; -+ auto i1 = p.nodes[1].is; -+ auto o1 = p.nodes[1].os; -+ -+ /* -+ * for a transpose of plain to 8c case, nodes would be like: -+ * n is os -+ * m 1 8 -+ * 8 m 1 -+ * or -+ * 8 m 1 -+ * m 1 8 -+ */ -+ ok = (utils::one_of(n0, 8, 16) || utils::one_of(n1, 8, 16)) -+ && ((i0 == 1 && o1 == 1 && n0 == i1 && o0 == n1) -+ || (o0 == 1 && i1 == 1 && n0 == o1 && i0 == n1)); -+ if (!ok) return false; -+ -+ // Do not handle transpose of dimensions other than last 2 -+ for (int i = 2; i < p.ndims; ++i) { -+ if (p.nodes[i].is != p.nodes[i].os) { -+ ok = false; -+ break; -+ } -+ } -+ -+ return ok; -+ } -+ -+ jit_single_blk_kernel_t(const tr::prb_t &prb) -+ : jit_generator() -+ , prb_(prb) -+ , itype_sz_(data_type_size(prb_.itype)) -+ , otype_sz_(data_type_size(prb_.otype)) -+ , block_sz(prb.nodes[0].n) {} -+ -+ void generate() override { -+ auto input_stride -+ = prb_.nodes[0].is != 1 ? prb_.nodes[0].is : prb_.nodes[1].is; -+ auto output_stride -+ = prb_.nodes[0].os != 1 ? prb_.nodes[0].os : prb_.nodes[1].os; -+ -+ Label tail_processing; -+ -+ const auto load_zp = [&](const ZRegS ymm_zp, const XReg reg_zp) { -+ dup(ymm_zp, WReg(reg_zp.getIdx())); -+ scvtf(ymm_zp, P_ALL_ONE / T_m, ymm_zp); -+ }; -+ -+ preamble(); -+ -+ if (prb_.req_src_zp) load_zp(ymm_src_zp, reg_src_zp); -+ -+ if (prb_.req_dst_zp) load_zp(ymm_dst_zp, reg_dst_zp); -+ -+ cmp(reg_ptr_tail, true); -+ b(EQ, tail_processing); -+ -+ if (block_sz == 8) { -+ gen_ker8x8(0, 0, input_stride, output_stride, 8, 8); -+ block_sz = 8; -+ } else if (block_sz == 16) { -+ gen_ker16x16_in_8x8(input_stride, output_stride); -+ block_sz = 16; -+ } else { -+ assert(!"unimplemented"); -+ } -+ -+ postamble(); -+ -+ L(tail_processing); -+ -+ if (block_sz == 8) { -+ auto i_tail = input_stride % 8 != 0 ? input_stride % 8 : 8; -+ auto o_tail = output_stride % 8 != 0 ? output_stride % 8 : 8; -+ if (i_tail != o_tail) { -+ auto t_mask = i_tail == 8 ? o_tail : i_tail; -+ gen_setmask(t_mask); -+ gen_ker8x8(0, 0, input_stride, output_stride, i_tail, o_tail); -+ } -+ } else if (block_sz == 16) { -+ auto i_tail = input_stride % 16 != 0 ? input_stride % 16 : 16; -+ auto o_tail = output_stride % 16 != 0 ? output_stride % 16 : 16; -+ if (i_tail != o_tail) { -+ auto t_mask = i_tail == 16 ? o_tail : i_tail; -+ t_mask %= 8; -+ if (t_mask != 0) gen_setmask(t_mask); -+ gen_ker16x16_in_8x8( -+ input_stride, output_stride, i_tail, o_tail); -+ } -+ } else { -+ assert(!"unimplemented"); -+ } -+ -+ postamble(); -+ } -+ -+ void gen_loadu(const ZRegS ymm, const XReg &addr, int size) { -+ QReg xmm(ymm.getIdx()); -+ switch (size) { -+ case 32: ld1w(ymm, p_lsb_256 / T_z, ptr(addr)); break; -+ case 16: ldr(xmm, ptr(addr)); break; -+ default: assert(!"unreachable"); -+ } -+ } -+ -+ void gen_storeu(const XReg &addr, const ZRegS ymm, int size) { -+ QReg xmm(ymm.getIdx()); -+ switch (size) { -+ case 32: st1w(ymm, p_lsb_256, ptr(addr)); break; -+ case 16: str(xmm, ptr(addr)); break; -+ default: assert(!"unreachable"); -+ } -+ } -+ -+ void gen_maskloadu( -+ const ZRegS ymm, const XReg &addr, const PReg mask, int size) { -+ switch (size) { -+ case 32: -+ case 16: ld1w(ymm, mask / T_z, ptr(addr)); break; -+ default: assert(!"unreachable"); -+ } -+ } -+ -+ void gen_maskstoreu( -+ const XReg &addr, const ZRegS ymm, const PReg mask, int size) { -+ switch (size) { -+ case 32: -+ case 16: st1w(ymm, mask, ptr(addr)); break; -+ default: assert(!"unreachable"); -+ } -+ } -+ -+ // Register allocation xmm0~11 -+ void gen_transpose_8x8() { -+ const uint64_t sveLen = get_sve_length(); -+ constexpr int lane = 8; -+ -+#if 0 -+ /* Debug code -+ z0: 7, 6, 5, 4, 3, 2, 1, 0 -+ z1: 15, 14, 13, 12, 11, 10, 9, 8 -+ ... -+ z17: 63, 62, 61, 60, 59, 58, 57, 56 -+ */ -+ ptrue(P_ALL_ONE.b); -+ ptrue(P_TMP.s, VL8); -+ not_(P_TMP.b, P_ALL_ONE/T_z, P_TMP.b); -+ index(z0.s, 0, 1); -+ mov(z0.s, P_TMP/T_m, 0); -+ mov(z_tmp_vec[0].s, 8); -+ mov(z_tmp_vec[0].s, P_TMP/T_m, 0); -+ for(uint32_t i=1; i nChw()C -+ // or nChw()C -> nchw -+ void gen_setmask(int mask) { -+ mov_imm(x_tmp_0, 0); -+ mov_imm(x_tmp_1, mask); -+ whilelt(p_mask.s, x_tmp_0, x_tmp_1); -+ } -+ -+ // TODO: Mark parameter with type information -+ // XXX: ! -+ // offset in byte offset -+ // stride in element number -+ // -+ // Gen specific 8x8 transform respect to certain tail condition -+ void gen_tr8x8(int i_off, int o_off, int input_stride, int output_stride, -+ int in_tail, int out_tail) { -+ constexpr int lane = 8; -+ -+ if (in_tail == 0 || out_tail == 0) return; -+ -+ for (int i = 0; i < out_tail; ++i) { -+ if (in_tail != lane) { -+ add_imm(x_addr, reg_ptr_in_, -+ i_off + i * input_stride * itype_sz_, x_tmp_0); -+ gen_maskloadu(ZRegS(i), x_addr, p_mask, lane * itype_sz_); -+ } else { -+ add_imm(x_addr, reg_ptr_in_, -+ i_off + i * input_stride * itype_sz_, x_tmp_0); -+ gen_loadu(ZRegS(i), x_addr, lane * itype_sz_); -+ } -+ if (prb_.req_src_zp) { fsub(ZRegS(i), ZRegS(i), ymm_src_zp); } -+ } -+ -+ gen_transpose_8x8(); -+ -+ for (int i = 0; i < in_tail; ++i) { -+ if (prb_.req_dst_zp) { fadd(ZRegS(i), ZRegS(i), ymm_dst_zp); } -+ if (out_tail == lane) { -+ add_imm(x_addr, reg_ptr_out_, -+ o_off + i * output_stride * otype_sz_, x_tmp_0); -+ gen_storeu(x_addr, ZRegS(i), lane * otype_sz_); -+ } else { -+ add_imm(x_addr, reg_ptr_out_, -+ o_off + i * output_stride * otype_sz_, x_tmp_0); -+ gen_maskstoreu(x_addr, ZRegS(i), p_mask, lane * otype_sz_); -+ } -+ } -+ } -+ -+ // tail: 0 ~ 8 -+ // support: either in_tail or out_tail is not 8, but not both -+ void gen_ker8x8(int i_off, int o_off, int input_stride, int output_stride, -+ int in_tail, int out_tail) { -+ gen_tr8x8(i_off, o_off, input_stride, output_stride, in_tail, out_tail); -+ } -+ -+ void gen_ker16x16_in_8x8(int input_stride, int output_stride) { -+ const auto lane = 16; -+ const auto sub_lane = lane / 2; -+ gen_tr8x8(0, 0, input_stride, output_stride, sub_lane, sub_lane); -+ gen_tr8x8(input_stride * sub_lane * itype_sz_, sub_lane * otype_sz_, -+ input_stride, output_stride, sub_lane, sub_lane); -+ gen_tr8x8(sub_lane * itype_sz_, output_stride * sub_lane * otype_sz_, -+ input_stride, output_stride, sub_lane, sub_lane); -+ gen_tr8x8((input_stride * sub_lane + sub_lane) * itype_sz_, -+ (output_stride * sub_lane + sub_lane) * otype_sz_, input_stride, -+ output_stride, sub_lane, sub_lane); -+ } -+ -+ // tail can be 1 ~ 16, using avx2 for now -+ void gen_ker16x16_in_8x8( -+ int input_stride, int output_stride, int in_tail, int out_tail) { -+ constexpr auto lane = 16; -+ constexpr auto sub_lane = lane / 2; -+ auto tail = in_tail != lane ? in_tail : out_tail; -+ -+ const auto l_tail = tail < sub_lane ? tail : sub_lane; -+ const auto u_tail = tail < sub_lane ? 0 : tail - sub_lane; -+ -+ if (tail == in_tail) { -+ gen_tr8x8(0, 0, input_stride, output_stride, l_tail, sub_lane); -+ gen_tr8x8(input_stride * sub_lane * itype_sz_, sub_lane * otype_sz_, -+ input_stride, output_stride, l_tail, sub_lane); -+ gen_tr8x8(sub_lane * itype_sz_, -+ output_stride * sub_lane * otype_sz_, input_stride, -+ output_stride, u_tail, sub_lane); -+ gen_tr8x8(itype_sz_ * (input_stride * sub_lane + sub_lane), -+ otype_sz_ * (output_stride * sub_lane + sub_lane), -+ input_stride, output_stride, u_tail, sub_lane); -+ } else { -+ gen_tr8x8(0, 0, input_stride, output_stride, sub_lane, l_tail); -+ gen_tr8x8(input_stride * sub_lane * itype_sz_, sub_lane * otype_sz_, -+ input_stride, output_stride, sub_lane, u_tail); -+ gen_tr8x8(sub_lane * itype_sz_, -+ output_stride * sub_lane * itype_sz_, input_stride, -+ output_stride, sub_lane, l_tail); -+ gen_tr8x8(itype_sz_ * (input_stride * sub_lane + sub_lane), -+ otype_sz_ * (output_stride * sub_lane + sub_lane), -+ input_stride, output_stride, sub_lane, u_tail); -+ } -+ } -+ -+private: -+ // 6 ~ 12 -+ constexpr static int xmm_save_for_windows = 0; -+ constexpr static int xmm_save_start_from = 6; -+ constexpr static int xmm_width = 16; -+ -+ void preamble() { ptrue(p_lsb_256.b, VL32); } -+ -+ void postamble() { ret(); } -+ -+ const prb_t &prb_; -+ -+ int itype_sz_; -+ int otype_sz_; -+ int block_sz; -+ -+ XReg reg_ptr_in_ = abi_param1; -+ XReg reg_ptr_out_ = abi_param2; -+ XReg reg_ptr_tail = abi_param3; -+ XReg reg_src_zp = abi_param4; -+ XReg reg_dst_zp = abi_param5; -+ -+ XReg x_addr = x10; -+ XReg x_tmp_0 = x11; -+ XReg x_tmp_1 = x12; -+ -+ /* Avoid P_TMP(p7) in jit_generator.hpp. */ -+ PReg p_lsb_256 = p6; -+ PReg p_mask = p5; -+ -+ ZRegS ymm_tmp = z0.s; -+ ZRegS ymm_src_zp = z14.s; -+ ZRegS ymm_dst_zp = z15.s; -+ -+ const std::vector tmp_vec_idx = {20, 21, 22, 23, 24, 25, 26, 27}; -+ VReg v_tmp0 = v20; - ZReg z_tmp0 = z20; - ZReg z_tmp1 = z21; - ZReg z_tmp2 = z22; -@@ -1472,15 +2300,31 @@ kernel_t *kernel_t::create(const kernel_t::desc_t &desc) { - - return nullptr; - } -+ - } // namespace tr - - static void prb_block_for_cache(tr::prb_t &prb) { - /* If strides for 0th and 1st nodes are cache friendly - * then one can altogether do away with blocking ! */ -- const bool cache_blocking_needed = false -- || (prb.nodes[0].is % 64 == 0 && prb.nodes[0].n > 16) -- || (prb.ndims > 1 && prb.nodes[1].is % 64 == 0 -- && prb.nodes[1].n > 16); -+ static constexpr int num_elems_thr = 16; -+ const bool stride_cache_friendly -+ = ((prb.nodes[0].is % 64 == 0 && prb.nodes[0].n > num_elems_thr) -+ || (prb.ndims > 1 && prb.nodes[1].is % num_elems_thr == 0 -+ && prb.nodes[1].n > num_elems_thr)) -+ && !prb.is_tail_present; -+ -+ // performance improvement for shapes with large inner-most dimension -+ const size_t L1_cache_sz -+ = size_t(3) * platform::get_per_core_cache_size(1) / 4; -+ const size_t itype_sz_ = data_type_size(prb.itype); -+ const size_t inner_block_sz = prb.nodes[0].n * itype_sz_; -+ const bool requires_inner_blocking = inner_block_sz > L1_cache_sz -+ // 'is_tail_present' is not supported for cache_blocking when -+ // asymmetric_comp is executed. -+ && IMPLICATION(prb.req_asymmetric_comp, !prb.is_tail_present); -+ -+ const bool cache_blocking_needed -+ = stride_cache_friendly || requires_inner_blocking; - if (!cache_blocking_needed) return; - - int unit_input_stride_idx = -1; -@@ -1496,28 +2340,58 @@ static void prb_block_for_cache(tr::prb_t &prb) { - const auto output_stride = prb.nodes[unit_input_stride_idx].os; - const auto num_elems = prb.nodes[unit_input_stride_idx].n; - -- const bool split_needed = (num_elems > 16) && (num_elems % 16 == 0); -+ const bool split_needed = (num_elems > num_elems_thr) -+ && (num_elems % num_elems_thr == 0); - const int move_location = (output_stride % 4 != 0) ? 0 : 1; -- if (split_needed) prb_node_split(prb, unit_input_stride_idx, 16); -+ if (split_needed) -+ prb_node_split(prb, unit_input_stride_idx, num_elems_thr); - - /* Because of cache-unfriendly nature of unit-output stride node, let - * us move unit-input stride node on or near front! */ -- prb_node_move(prb, unit_input_stride_idx, move_location); -+ if (unit_input_stride_idx != move_location) -+ prb_node_move(prb, unit_input_stride_idx, move_location); - } - - /* Potentially, split the node with os=1 in two and pull in the node with - * is=1 between them for better cache reuse: - * [n0:is0:1][n1:1:os1] --> [16n0:is0:1][n1:1:os1][n0/16:is0*16:16] */ - if (prb.ndims >= 2 && prb.nodes[0].os == 1 && prb.nodes[1].is == 1) { -- const auto input_stride = prb.nodes[0].is; - const auto num_elems = prb.nodes[0].n; - -- const bool split_needed = true && (num_elems > 16) -- && (num_elems % 16 == 0) && (input_stride >= 256) -- && (input_stride % 64 == 0); -+ const bool split_needed = (num_elems > num_elems_thr) -+ && (num_elems % num_elems_thr == 0); - if (split_needed) { -- prb_node_split(prb, 0, 16); -+ prb_node_split(prb, 0, num_elems_thr); - prb_node_move(prb, 1, 2); -+ -+ // Update node information -+ prb_node_dependency(prb); -+ -+ // heuristics - looping over the unrolled dims should maximize reuse -+ // of the already cached data; observation is choosing the smallest -+ // dim from the remaining (from 2 up to ndims) gives good results -+ constexpr int new_position = 2; -+ const auto dim_beg_it = std::begin(prb.nodes); -+ const auto dim_two_it = dim_beg_it + new_position; -+ const auto dim_last_it = dim_beg_it + prb.ndims; -+ const auto min_n_node_it = std::min_element(dim_two_it, dim_last_it, -+ [](const tr::node_t &lhs, const tr::node_t &rhs) { -+ return lhs.n < rhs.n; -+ }); -+ const auto min_idx = std::distance(dim_beg_it, min_n_node_it); -+ // check if min_idx node is parent of node with tail processing which -+ // is currently unsupported (i.e. tail processing can only be handled -+ // at the inner-most dimension) -+ bool inner_block_has_tail = false; -+ for (int idx = min_idx - 1; idx >= new_position; idx--) { -+ if (prb.nodes[idx].parent_node_id == min_idx) { -+ inner_block_has_tail = true; -+ break; -+ } -+ } -+ -+ if (min_idx > new_position && (!inner_block_has_tail)) -+ prb_node_move(prb, min_idx, new_position); - } - } - } -@@ -1527,73 +2401,76 @@ static void prb_block_for_cache(tr::prb_t &prb) { - * parallel driver and the kernel. */ - static void prb_thread_kernel_balance( - tr::prb_t &prb, int &ndims_ker_max, int nthr) { -- size_t sz_total = 1; -+ size_t size_total = 1; - for (int d = 0; d < prb.ndims; ++d) -- sz_total *= prb.nodes[d].n; -+ size_total *= prb.nodes[d].n; - -- /* The general expression for sz_drv_thr can be written as -- * sz_drv_min = C0 + FC * (nthr > 1 ? 1 : 0) + VC * (nthr - 1) -+ /* The general expression for size_drv_thr can be written as -+ * size_drv_min = C0 + FC * (nthr > 1 ? 1 : 0) + VC * (nthr - 1) - * where FC and VC are fixed and variable costs respectively. - * Though for now, the below heuristic seems to be good enough */ -- const size_t sz_drv_thr = (nthr > 1) ? 16 * nthr : 1; -+ const size_t size_drv_thr = (nthr > 1) ? 16 * nthr : 1; - -- /* sz_drv_min is the minimal size for the parallel -+ /* size_drv_min is the minimal size for the parallel - * driver required for good parallelization */ -- const size_t sz_drv_min -- = nstl::min(sz_drv_thr, utils::div_up(sz_total, 1024)); -+ const size_t size_drv_min -+ = nstl::min(size_drv_thr, utils::div_up(size_total, 1024)); - - /* kdims -- # of dimensions processed by a kernel -- * sz_ker_cur -- product of the dimension processed by a kernel -- * sz_drv_cur -- product of the dimension processed by a driver */ -+ * size_ker_cur -- product of the dimension processed by a kernel -+ * size_drv_cur -- product of the dimension processed by a driver */ - - int kdims = prb.ndims; -- size_t sz_drv_cur = 1; -- for (; kdims > 1 && sz_drv_cur < sz_drv_min; --kdims) -- sz_drv_cur *= prb.nodes[kdims - 1].n; -+ size_t size_drv_cur = 1; -+ for (; kdims > 1 && size_drv_cur < size_drv_min; --kdims) -+ size_drv_cur *= prb.nodes[kdims - 1].n; - -- size_t sz_ker_cur = 1; -+ size_t size_ker_cur = 1; - for (int d = 0; d < kdims; ++d) -- sz_ker_cur *= prb.nodes[d].n; -+ size_ker_cur *= prb.nodes[d].n; - -- /* Initially kdims is chosen so that sz_drv_cur >= sz_drv_min. -+ /* Initially kdims is chosen so that size_drv_cur >= size_drv_min. - * -- * It might happen that for chosen kdims the sz_ker_cur is too small -+ * It might happen that for chosen kdims the size_ker_cur is too small - * (less than tr::ker_prb_size_min). In that case try to split the -- * innermost driver dimension into two, to increase sz_ker_cur. */ -- bool want_borrow_ker_from_drv = true && kdims < prb.ndims -- && sz_ker_cur < tr::ker_prb_size_min && sz_drv_cur > sz_drv_min -- && kdims != prb.blk_chunk_idx; -+ * innermost driver dimension into two, to increase size_ker_cur. */ -+ const bool want_borrow_ker_from_drv = kdims < prb.ndims -+ && size_ker_cur < tr::ker_prb_size_min -+ && size_drv_cur > size_drv_min; - if (want_borrow_ker_from_drv) { -- /* sz_want_borrow is the minimal sz, so that: -- * o) sz_ker_cur * sz_want_borrow >= tr::ker_prb_size_min -+ /* size_want_borrow is the minimal size, so that: -+ * o) size_ker_cur * size_want_borrow >= tr::ker_prb_size_min - * o) current innermost driver dimension is divisible by -- * sz_want_borrow (so that we can evenly split that -+ * size_want_borrow (so that we can evenly split that - * dimension into two) - * -- * In the worst case the minimal sz_want_borrow is equal -+ * In the worst case the minimal size_want_borrow is equal - * to the innermost driver dimension itself. In that case - * we will sacrifice it in favor of kernel (is it fine?). */ -- size_t sz_want_borrow = utils::div_up(tr::ker_prb_size_min, sz_ker_cur); -- for (; prb.nodes[kdims].n % sz_want_borrow; ++sz_want_borrow) -+ size_t size_want_borrow -+ = utils::div_up(tr::ker_prb_size_min, size_ker_cur); -+ for (; prb.nodes[kdims].n % size_want_borrow; ++size_want_borrow) - ; -- if (sz_want_borrow != prb.nodes[kdims].n) -- prb_node_split(prb, kdims, sz_want_borrow); -+ -+ if (size_want_borrow != prb.nodes[kdims].n) -+ prb_node_split(prb, kdims, size_want_borrow); - kdims += 1; - } - - /* On the other hand it might happen that for chosen kdims -- * the sz_drv_cur is too small (less than sz_drv_min). In that case -+ * the size_drv_cur is too small (less than size_drv_min). In that case - * try to split the outermost kernel dimension into two, to increase -- * sz_drv_cur. */ -- bool want_borrow_drv_from_ker = true && sz_ker_cur > tr::ker_prb_size_min -- && sz_drv_cur < sz_drv_min && kdims != prb.blk_chunk_idx; -+ * size_drv_cur. */ -+ const bool want_borrow_drv_from_ker = size_ker_cur > tr::ker_prb_size_min -+ && size_drv_cur < size_drv_min; - if (want_borrow_drv_from_ker) { -- size_t sz_want_borrow = utils::div_up(sz_drv_min, sz_drv_cur); -- for (; prb.nodes[kdims - 1].n % sz_want_borrow; ++sz_want_borrow) -+ size_t size_want_borrow = utils::div_up(size_drv_min, size_drv_cur); -+ for (; prb.nodes[kdims - 1].n % size_want_borrow; ++size_want_borrow) - ; -- if (sz_want_borrow != prb.nodes[kdims - 1].n) -+ -+ if (size_want_borrow != prb.nodes[kdims - 1].n) - prb_node_split( -- prb, kdims - 1, prb.nodes[kdims - 1].n / sz_want_borrow); -+ prb, kdims - 1, prb.nodes[kdims - 1].n / size_want_borrow); - } - - ndims_ker_max = kdims; -@@ -1607,6 +2484,33 @@ static void prb_thread_kernel_balance( - } - } - -+status_t jit_uni_reorder_t::pd_t::init( -+ engine_t *engine, engine_t *src_engine, engine_t *dst_engine) { -+ CHECK(cpu_reorder_pd_t::init(engine, src_engine, dst_engine)); -+ -+ const bool compensation_needed -+ = prb_.req_s8s8_comp || prb_.req_asymmetric_comp; -+ if (compensation_needed) init_scratchpad(); -+ -+ return status::success; -+} -+ -+void jit_uni_reorder_t::pd_t::init_scratchpad() { -+ const memory_desc_wrapper od(dst_md()); -+ const auto G = with_groups_ ? od.padded_dims()[0] : 1; -+ const auto N = od.padded_dims()[with_groups_ ? 1 : 0]; -+ static constexpr int cache_line_size = 16; -+ const auto wspace_per_thr_size -+ = utils::rnd_up(G * N, cache_line_size) * sizeof(int32_t); -+ -+ auto scratchpad = scratchpad_registry().registrar(); -+ const auto compensation_reduce_size = wspace_per_thr_size * nthr_; -+ -+ // Every thread gets its own scratchpad space for each N -+ scratchpad.template book(memory_tracking::names::key_reorder_space, -+ compensation_reduce_size); -+} -+ - status_t jit_uni_reorder_t::pd_t::create(reorder_pd_t **reorder_pd, - engine_t *engine, const primitive_attr_t *attr, engine_t *src_engine, - const memory_desc_t *src_md, engine_t *dst_engine, -@@ -1616,36 +2520,18 @@ status_t jit_uni_reorder_t::pd_t::create(reorder_pd_t **reorder_pd, - status_t prb_init_status = prb_init(prb, *src_md, *dst_md, attr); - if (prb_init_status != status::success) return prb_init_status; - -- DEBUG({ -- printf("init : "); -- prb_dump(prb); -- }); -- // Sort the prb array in increasing sizes of the output stride -- prb_normalize(prb); -- DEBUG({ -- printf("norm : "); -- prb_dump(prb); -- }); -- /* Combine the variables, which appear together on both -- * sides of the reorder */ -- prb_simplify(prb); -- DEBUG({ -- printf("smpl : "); -- prb_dump(prb); -- }); -- - prb_block_for_cache(prb); - DEBUG({ - printf("cache: "); - prb_dump(prb); - }); - -- CHECK(prb_check_blk(prb, *dst_md)); -- -- int ndims_ker_max; -+ int ndims_ker_max {}; - int nthr = dnnl_get_max_threads(); - prb_thread_kernel_balance(prb, ndims_ker_max, nthr); - -+ if (prb.is_tail_present) prb_node_dependency(prb); -+ - tr::kernel_t::desc_t ker_desc; - status_t ker_init_status - = tr::kernel_t::desc_init(ker_desc, prb, ndims_ker_max); -@@ -1663,99 +2549,191 @@ status_t jit_uni_reorder_t::pd_t::create(reorder_pd_t **reorder_pd, - auto _pd = new pd_t( - attr, src_engine->kind(), src_md, dst_engine->kind(), dst_md); - if (_pd == nullptr) return status::out_of_memory; -+ -+ _pd->nthr_ = nthr; -+ _pd->prb_ = prb; -+ _pd->with_groups_ -+ = prb.compensation_mask == tr::prb_t::comp_mask_with_groups; - if (_pd->init(engine, src_engine, dst_engine) != status::success) { - delete _pd; - return status::unimplemented; - } -- _pd->prb_ = prb; - _pd->ker_desc_ = ker_desc; - _pd->init_scratchpad_md(); -- _pd->nthr_ = nthr; -+ - return safe_ptr_assign(*reorder_pd, _pd); - } - --void jit_uni_reorder_t::omp_driver_0d( -- int off, const char *in, char *out, const float *scale) const { -- tr::call_param_t c {in, out, scale, 0}; -- (*kernel_)(&c); -+void jit_uni_reorder_t::omp_driver_0d(int off, const char *in, char *out, -+ const float *scale, int src_zp, int dst_zp, -+ int32_t *compensation_scratch) const { -+ const tr::prb_t &prb = pd()->prb_; -+ -+ tr::call_param_t base_params; -+ base_params.in = in; -+ base_params.out = out; -+ base_params.scale = scale; -+ base_params.src_zp = src_zp; -+ base_params.dst_zp = dst_zp; -+ base_params.compensation_scratch = compensation_scratch; -+ -+ if (prb.is_tail_present) { -+ tr::tail_call_param_t tail_params; -+ tail_params.base_params = base_params; -+ -+ static constexpr int omp_ndims = 0; -+ fill_curr_data_chunks(prb, off, nullptr, omp_ndims, tail_params); -+ (*kernel_)(&tail_params); -+ } else { -+ (*kernel_)(&base_params); -+ } - } - - void jit_uni_reorder_t::omp_driver_1d(int ithr, int nthr, int off, -- const char *in, char *out, const float *scale) const { -- const tr::node_t *ns = pd()->prb_.nodes + off; -+ const char *in, char *out, const float *scale, int src_zp, int dst_zp, -+ int32_t *compensation_scratch) const { -+ const tr::prb_t &prb = pd()->prb_; -+ const tr::node_t *ns = prb.nodes + off; - for_nd(ithr, nthr, (ptrdiff_t)ns[0].n, [&](ptrdiff_t d0) { -- auto c = tr::call_param_t(); -- c.in = in + d0 * ns[0].is * data_type_size(pd()->prb_.itype); -- c.out = out + d0 * ns[0].os * data_type_size(pd()->prb_.otype); -- c.scale = scale + d0 * ns[0].ss; -- c.blk_chunks = d0; -- (*kernel_)(&c); -+ tr::call_param_t base_params; -+ base_params.in = in + d0 * ns[0].is * data_type_size(prb.itype); -+ base_params.out = out + d0 * ns[0].os * data_type_size(prb.otype); -+ base_params.scale = scale + d0 * ns[0].ss; -+ base_params.src_zp = src_zp; -+ base_params.dst_zp = dst_zp; -+ base_params.compensation_scratch = compensation_scratch + d0 * ns[0].cs; -+ -+ if (prb.is_tail_present) { -+ tr::tail_call_param_t tail_params; -+ tail_params.base_params = base_params; -+ -+ static constexpr int omp_ndims = 1; -+ const ptrdiff_t omp_data_chunks[omp_ndims] = {d0}; -+ fill_curr_data_chunks( -+ prb, off, omp_data_chunks, omp_ndims, tail_params); -+ (*kernel_)(&tail_params); -+ } else { -+ (*kernel_)(&base_params); -+ } - }); - } - - void jit_uni_reorder_t::omp_driver_2d(int ithr, int nthr, int off, -- const char *in, char *out, const float *scale) const { -- const tr::node_t *ns = pd()->prb_.nodes + off; -- const int blk_idx_off = pd()->prb_.blk_chunk_idx - off; -+ const char *in, char *out, const float *scale, int src_zp, int dst_zp, -+ int32_t *compensation_scratch) const { -+ const tr::prb_t &prb = pd()->prb_; -+ const tr::node_t *ns = prb.nodes + off; - for_nd(ithr, nthr, (ptrdiff_t)ns[1].n, (ptrdiff_t)ns[0].n, - [&](ptrdiff_t d1, ptrdiff_t d0) { -- auto c = tr::call_param_t(); -- c.in = in -+ tr::call_param_t base_params; -+ base_params.in = in - + (d0 * ns[0].is + d1 * ns[1].is) -- * data_type_size(pd()->prb_.itype); -- c.out = out -+ * data_type_size(prb.itype); -+ base_params.out = out - + (d0 * ns[0].os + d1 * ns[1].os) -- * data_type_size(pd()->prb_.otype); -- c.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss; -- c.blk_chunks = utils::pick(blk_idx_off, d0, d1); -- (*kernel_)(&c); -+ * data_type_size(prb.otype); -+ base_params.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss; -+ base_params.src_zp = src_zp; -+ base_params.dst_zp = dst_zp; -+ base_params.compensation_scratch -+ = compensation_scratch + d0 * ns[0].cs + d1 * ns[1].cs; -+ -+ if (prb.is_tail_present) { -+ tr::tail_call_param_t tail_params; -+ tail_params.base_params = base_params; -+ -+ static constexpr int omp_ndims = 2; -+ const ptrdiff_t omp_data_chunks[omp_ndims] = {d0, d1}; -+ fill_curr_data_chunks( -+ prb, off, omp_data_chunks, omp_ndims, tail_params); -+ -+ (*kernel_)(&tail_params); -+ } else { -+ (*kernel_)(&base_params); -+ } - }); - } - - void jit_uni_reorder_t::omp_driver_3d(int ithr, int nthr, int off, -- const char *in, char *out, const float *scale) const { -- const tr::node_t *ns = pd()->prb_.nodes + off; -- const int blk_idx_off = pd()->prb_.blk_chunk_idx - off; -+ const char *in, char *out, const float *scale, int src_zp, int dst_zp, -+ int32_t *compensation_scratch) const { -+ const tr::prb_t &prb = pd()->prb_; -+ const tr::node_t *ns = prb.nodes + off; - for_nd(ithr, nthr, (ptrdiff_t)ns[2].n, (ptrdiff_t)ns[1].n, - (ptrdiff_t)ns[0].n, [&](ptrdiff_t d2, ptrdiff_t d1, ptrdiff_t d0) { -- auto c = tr::call_param_t(); -- c.in = in -+ tr::call_param_t base_params; -+ base_params.in = in - + (d0 * ns[0].is + d1 * ns[1].is + d2 * ns[2].is) -- * data_type_size(pd()->prb_.itype); -- c.out = out -+ * data_type_size(prb.itype); -+ base_params.out = out - + (d0 * ns[0].os + d1 * ns[1].os + d2 * ns[2].os) -- * data_type_size(pd()->prb_.otype); -- c.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss + d2 * ns[2].ss; -- c.blk_chunks = utils::pick(blk_idx_off, d0, d1, d2); -- (*kernel_)(&c); -+ * data_type_size(prb.otype); -+ base_params.scale -+ = scale + d0 * ns[0].ss + d1 * ns[1].ss + d2 * ns[2].ss; -+ base_params.src_zp = src_zp; -+ base_params.dst_zp = dst_zp; -+ base_params.compensation_scratch = compensation_scratch -+ + d0 * ns[0].cs + d1 * ns[1].cs + d2 * ns[2].cs; -+ -+ if (prb.is_tail_present) { -+ tr::tail_call_param_t tail_params; -+ tail_params.base_params = base_params; -+ -+ static constexpr int omp_ndims = 3; -+ const ptrdiff_t omp_data_chunks[omp_ndims] = {d0, d1, d2}; -+ fill_curr_data_chunks( -+ prb, off, omp_data_chunks, omp_ndims, tail_params); -+ (*kernel_)(&tail_params); -+ } else { -+ (*kernel_)(&base_params); -+ } - }); - } - - void jit_uni_reorder_t::omp_driver_4d(int ithr, int nthr, int off, -- const char *in, char *out, const float *scale) const { -- const tr::node_t *ns = pd()->prb_.nodes + off; -- const int blk_idx_off = pd()->prb_.blk_chunk_idx - off; -+ const char *in, char *out, const float *scale, int src_zp, int dst_zp, -+ int32_t *compensation_scratch) const { -+ const tr::prb_t &prb = pd()->prb_; -+ const tr::node_t *ns = prb.nodes + off; - for_nd(ithr, nthr, (ptrdiff_t)ns[3].n, (ptrdiff_t)ns[2].n, - (ptrdiff_t)ns[1].n, (ptrdiff_t)ns[0].n, - [&](ptrdiff_t d3, ptrdiff_t d2, ptrdiff_t d1, ptrdiff_t d0) { -- auto c = tr::call_param_t(); -- c.in = in -+ tr::call_param_t base_params; -+ base_params.in = in - + (d0 * ns[0].is + d1 * ns[1].is + d2 * ns[2].is - + d3 * ns[3].is) -- * data_type_size(pd()->prb_.itype); -- c.out = out -+ * data_type_size(prb.itype); -+ base_params.out = out - + (d0 * ns[0].os + d1 * ns[1].os + d2 * ns[2].os - + d3 * ns[3].os) -- * data_type_size(pd()->prb_.otype); -- c.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss + d2 * ns[2].ss -- + d3 * ns[3].ss; -- c.blk_chunks = utils::pick(blk_idx_off, d0, d1, d2, d3); -- (*kernel_)(&c); -+ * data_type_size(prb.otype); -+ base_params.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss -+ + d2 * ns[2].ss + d3 * ns[3].ss; -+ base_params.src_zp = src_zp; -+ base_params.dst_zp = dst_zp; -+ base_params.compensation_scratch = compensation_scratch -+ + d0 * ns[0].cs + d1 * ns[1].cs + d2 * ns[2].cs -+ + d3 * ns[3].cs; -+ -+ if (prb.is_tail_present) { -+ tr::tail_call_param_t tail_params; -+ tail_params.base_params = base_params; -+ -+ static constexpr int omp_ndims = 4; -+ const ptrdiff_t omp_data_chunks[omp_ndims] -+ = {d0, d1, d2, d3}; -+ fill_curr_data_chunks( -+ prb, off, omp_data_chunks, omp_ndims, tail_params); -+ (*kernel_)(&tail_params); -+ } else { -+ (*kernel_)(&base_params); -+ } - }); - } - --void jit_uni_reorder_t::omp_driver( -- const char *in, char *out, const float *scale) const { -+void jit_uni_reorder_t::omp_driver(const char *in, char *out, -+ const float *scale, int src_zp, int dst_zp, -+ const memory_tracking::grantor_t &scratchpad) const { - in += pd()->prb_.ioff * data_type_size(pd()->prb_.itype); - out += pd()->prb_.ooff * data_type_size(pd()->prb_.otype); - -@@ -1770,29 +2748,153 @@ void jit_uni_reorder_t::omp_driver( - - int ndims = pd()->prb_.ndims; - int ndims_ker = pd()->ker_desc_.prb.ndims; -+ const bool req_s8s8_comp = pd()->prb_.req_s8s8_comp; -+ const bool req_asymmetric_comp = pd()->prb_.req_asymmetric_comp; -+ const bool req_compensation = req_s8s8_comp || req_asymmetric_comp; - assert(ndims - ndims_ker <= ndims_driver_max); - -+ int32_t *compensation_reduce_scratch = scratchpad.template get( -+ memory_tracking::names::key_reorder_space); -+ -+ const memory_desc_wrapper od(pd()->dst_md()); -+ const auto G = pd()->with_groups_ ? od.padded_dims()[0] : 1; -+ const auto N = od.padded_dims()[pd()->with_groups_ ? 1 : 0]; -+ static constexpr int cache_line_size = 16; -+ const auto wspace_per_thr_size = utils::rnd_up(G * N, cache_line_size); -+ const auto wspace_per_thr_bytes = wspace_per_thr_size * sizeof(int32_t); -+ - if (ndims - ndims_ker == 0) { -- omp_driver_0d(ndims_ker, in, out, scale); -+ if (req_compensation) -+ std::memset(compensation_reduce_scratch, 0, wspace_per_thr_bytes); -+ -+ omp_driver_0d(ndims_ker, in, out, scale, src_zp, dst_zp, -+ compensation_reduce_scratch); - } else { - parallel(pd()->nthr_, [&](const int ithr, const int nthr) { -+ int32_t *compensation_scratch = nullptr; -+ if (req_compensation) { -+ compensation_scratch = &compensation_reduce_scratch[ithr -+ * wspace_per_thr_size]; -+ std::memset(compensation_scratch, 0, wspace_per_thr_bytes); -+ } -+ - switch (ndims - ndims_ker) { - case 1: -- omp_driver_1d(ithr, nthr, ndims_ker, in, out, scale); -+ omp_driver_1d(ithr, nthr, ndims_ker, in, out, scale, src_zp, -+ dst_zp, compensation_scratch); - break; - case 2: -- omp_driver_2d(ithr, nthr, ndims_ker, in, out, scale); -+ omp_driver_2d(ithr, nthr, ndims_ker, in, out, scale, src_zp, -+ dst_zp, compensation_scratch); - break; - case 3: -- omp_driver_3d(ithr, nthr, ndims_ker, in, out, scale); -+ omp_driver_3d(ithr, nthr, ndims_ker, in, out, scale, src_zp, -+ dst_zp, compensation_scratch); - break; - case 4: -- omp_driver_4d(ithr, nthr, ndims_ker, in, out, scale); -+ omp_driver_4d(ithr, nthr, ndims_ker, in, out, scale, src_zp, -+ dst_zp, compensation_scratch); - break; - default: assert(!"unimplemented"); - } - }); - } -+ -+ // Reduction of intermediate compensation results to the final output -+ if (req_compensation) { -+ const int nthr = ndims - ndims_ker == 0 ? 1 : pd()->nthr_; -+ reduce_compensation( -+ out, compensation_reduce_scratch, nthr, wspace_per_thr_size); -+ } -+} -+ -+void jit_uni_reorder_t::reduce_compensation(char *out, -+ const int32_t *compensation_reduce_scratch, const int nthr, -+ const dim_t wspace_per_thr_size) const { -+ -+ const memory_desc_wrapper od(pd()->dst_md()); -+ const size_t offset = od.size() - od.additional_buffer_size(); -+ -+ static constexpr auto comp_dt_size = sizeof(int32_t); -+ static constexpr int32_t comp_s8s8_shift = 128; -+ -+ // Note: We do not need to explicitly zero-out compensation buffer, as the -+ // per_thread buffers are already zeroed out in the padded area. -+ const auto G = pd()->with_groups_ ? od.padded_dims()[0] : 1; -+ const auto N = od.padded_dims()[pd()->with_groups_ ? 1 : 0]; -+ const auto GN = G * N; -+ const bool req_s8s8_comp = pd()->prb_.req_s8s8_comp; -+ const bool req_asymmetric_comp = pd()->prb_.req_asymmetric_comp; -+ const size_t zp_offset -+ = offset + (pd()->prb_.req_s8s8_comp ? GN * comp_dt_size : 0); -+ -+ parallel_nd(GN, [&](int idx) { -+ int32_t acc = 0; -+ for (int ithr = 0; ithr < nthr; ithr++) { -+ acc -= compensation_reduce_scratch[ithr * wspace_per_thr_size -+ + idx]; -+ } -+ if (req_s8s8_comp) { -+ int32_t *out_comp = reinterpret_cast(&out[offset]); -+ out_comp[idx] = comp_s8s8_shift * acc; -+ } -+ if (req_asymmetric_comp) { -+ int32_t *out_asym_comp -+ = reinterpret_cast(&out[zp_offset]); -+ out_asym_comp[idx] = acc; -+ } -+ }); -+} -+ -+void jit_uni_reorder_t::fill_curr_data_chunks(const tr::prb_t &prb, -+ const int off, const ptrdiff_t *omp_data_chunks, const int omp_ndims, -+ tr::tail_call_param_t &c) const { -+ // Chunks are backwards numered i.e: -+ // [0] -> [node_size] -+ // [1] -> [node_size - 1] -+ // ... -+ // [node_size - 1] -> [1] -+ -+ // It is done like this, because it is easier to decrement counter -+ // and check if it is equal to zero than increment and check -+ // if it is equal to node_size in jit kernel. -+ -+ static constexpr int64_t empty_chunk_info = -1; -+ static constexpr int64_t last_chunk = 1; -+ -+ for (int curr_node_id = prb.ndims - 1; curr_node_id >= 0; curr_node_id--) { -+ const int parent_node_id = prb.nodes[curr_node_id].parent_node_id; -+ const bool is_drv_processing_this_node -+ = curr_node_id >= off && curr_node_id <= off + omp_ndims - 1; -+ const bool is_tail_processing -+ = prb.is_tail_in_one_of_child_nodes(curr_node_id) -+ || prb.nodes[curr_node_id].tail_size > 0; -+ -+ if (is_drv_processing_this_node && is_tail_processing) { -+ const int inner_idx = curr_node_id - off; -+ assert(inner_idx < omp_ndims); -+ const int64_t node_size = prb.nodes[curr_node_id].tail_size > 0 -+ ? prb.nodes[curr_node_id].tail_size -+ : prb.nodes[curr_node_id].n; -+ const int64_t data_chunk = node_size - omp_data_chunks[inner_idx]; -+ -+ if (!prb.nodes[curr_node_id].is_parent_empty()) { -+ const bool is_parent_chunk_last -+ = c.curr_data_chunks[parent_node_id] == last_chunk; -+ c.curr_data_chunks[curr_node_id] -+ = is_parent_chunk_last ? data_chunk : empty_chunk_info; -+ c.zeroing_data = static_cast( -+ is_parent_chunk_last && data_chunk <= 0); -+ } else { -+ c.curr_data_chunks[curr_node_id] = data_chunk; -+ c.zeroing_data = static_cast(data_chunk <= 0); -+ } -+ c.skip_kernel_execution = static_cast(c.zeroing_data -+ && !prb.nodes[curr_node_id].is_zero_pad_needed); -+ if (c.zeroing_data || c.skip_kernel_execution) break; -+ } else -+ c.curr_data_chunks[curr_node_id] = empty_chunk_info; -+ } - } - - status_t jit_uni_reorder_t::init(engine_t *engine) { -@@ -1801,13 +2903,98 @@ status_t jit_uni_reorder_t::init(engine_t *engine) { - } - - status_t jit_uni_reorder_t::execute(const exec_ctx_t &ctx) const { -- status_t status = status::success; - auto in = CTX_IN_MEM(const char *, DNNL_ARG_FROM); -- auto out = CTX_OUT_CLEAN_MEM(char *, DNNL_ARG_TO, status); -- CHECK(status); -+ auto out = CTX_OUT_MEM(char *, DNNL_ARG_TO); - DEFINE_SCALES_BUFFER(scales); -+ DEFINE_ZERO_POINT_VALUE(src_zp, DNNL_ARG_FROM); -+ DEFINE_ZERO_POINT_VALUE(dst_zp, DNNL_ARG_TO); -+ const auto &scratchpad = ctx.get_scratchpad_grantor(); -+ -+ omp_driver(in, out, scales, src_zp, dst_zp, scratchpad); -+ -+ return status::success; -+} -+ -+status_t jit_blk_reorder_t::pd_t::create(reorder_pd_t **reorder_pd, -+ engine_t *engine, const primitive_attr_t *attr, engine_t *src_engine, -+ const memory_desc_t *src_md, engine_t *dst_engine, -+ const memory_desc_t *dst_md) { -+ auto prb = tr::prb_t(); -+ -+ status_t prb_init_status = prb_init(prb, *src_md, *dst_md, attr); -+ if (prb_init_status != status::success) return prb_init_status; -+ // only uni_reorder supports tail processing now -+ // TODO: Add tail processing support in blk_reorder -+ if (prb.is_tail_present) return status::unimplemented; -+ -+ prb_tile_normalize(prb); -+ DEBUG({ -+ printf("tile : "); -+ prb_dump(prb); -+ }); -+ -+ if (!tr::jit_single_blk_kernel_t::applicable(prb)) { -+ return status::unimplemented; -+ } - -- omp_driver(in, out, scales); -+ auto _pd = new pd_t( -+ attr, src_engine->kind(), src_md, dst_engine->kind(), dst_md); -+ if (_pd == nullptr) return status::out_of_memory; -+ _pd->prb_ = prb; -+ if (_pd->init(engine, src_engine, dst_engine) != status::success) { -+ delete _pd; -+ return status::unimplemented; -+ } -+ _pd->init_scratchpad_md(); -+ -+ return safe_ptr_assign(*reorder_pd, _pd); -+} -+ -+void jit_blk_reorder_t::pd_t::prb_tile_normalize(tr::prb_t &p) { -+ if (!utils::one_of(p.nodes[0].n, 8ul, 16ul) -+ && utils::one_of(p.nodes[1].n, 8ul, 16ul)) { -+ nstl::swap(p.nodes[0], p.nodes[1]); -+ } -+} -+ -+jit_blk_reorder_t::jit_blk_reorder_t(const pd_t *apd) : primitive_t(apd) {} -+jit_blk_reorder_t::~jit_blk_reorder_t() = default; -+ -+status_t jit_blk_reorder_t::init(engine_t *engine) { -+ kernel_ = utils::make_unique(pd()->prb_); -+ return kernel_->create_kernel(); -+} -+ -+status_t jit_blk_reorder_t::execute(const exec_ctx_t &ctx) const { -+ const auto in = CTX_IN_MEM(const char *, DNNL_ARG_FROM); -+ auto out = CTX_OUT_MEM(char *, DNNL_ARG_TO); -+ DEFINE_ZERO_POINT_VALUE(src_zp, DNNL_ARG_FROM); -+ DEFINE_ZERO_POINT_VALUE(dst_zp, DNNL_ARG_TO); -+ -+ // kernel handle 2-dimension tiles, a tail is possible -+ auto &prb = this->pd()->prb_; -+ ptrdiff_t BH = 1; -+ for (int i = 2; i < prb.ndims; ++i) { -+ BH *= prb.nodes[i].n; -+ } -+ -+ auto block_sz = prb.n(0); -+ auto n1 = prb.n(1); -+ auto i1 = prb.is(1); -+ auto o1 = prb.os(1); -+ auto FL = (n1 + block_sz - 1) / block_sz; -+ auto bh_stride = BH == 1 ? 0 : prb.is(2); -+ -+ auto itype_sz_ = data_type_size(pd()->prb_.itype); -+ auto otype_sz_ = data_type_size(pd()->prb_.otype); -+ -+ parallel_nd(BH, FL, [&](dim_t bh, dim_t fl) { -+ auto fl_b = fl * block_sz; -+ auto bh_b = bh_stride * bh; -+ auto *i = in + (bh_b + fl_b * i1) * itype_sz_; -+ auto *o = out + (bh_b + fl_b * o1) * otype_sz_; -+ (*kernel_)(i, o, n1 - fl_b < block_sz, src_zp, dst_zp); -+ }); - - return status::success; - } -diff --git a/src/cpu/aarch64/jit_uni_reorder.hpp b/src/cpu/aarch64/jit_uni_reorder.hpp -index 2fb6f0f89f3..bf400430ba5 100644 ---- a/src/cpu/aarch64/jit_uni_reorder.hpp -+++ b/src/cpu/aarch64/jit_uni_reorder.hpp -@@ -1,6 +1,6 @@ - /******************************************************************************* --* Copyright 2018-2020 Intel Corporation --* Copyright 2020 FUJITSU LIMITED -+* Copyright 2018-2022 Intel Corporation -+* Copyright 2020-2022 FUJITSU LIMITED - * Copyright 2022 Arm Ltd. and affiliates - * - * Licensed under the Apache License, Version 2.0 (the "License"); -@@ -36,15 +36,76 @@ namespace tr { - constexpr int max_ndims = DNNL_MAX_NDIMS; - - struct node_t { -- size_t n; -- ptrdiff_t is; // input stride -- ptrdiff_t os; // output stride -- ptrdiff_t ss; // scale stride -+ static constexpr int64_t empty_field = -1; -+ -+ size_t n = 0; -+ size_t tail_size = 0; -+ int dim_id = empty_field; -+ int parent_node_id = empty_field; -+ bool is_zero_pad_needed = false; -+ ptrdiff_t is = 0; // input stride -+ ptrdiff_t os = 0; // output stride -+ ptrdiff_t ss = 0; // scale stride -+ ptrdiff_t cs = 0; // compensation stride -+ -+ bool is_dim_id_empty() const { return dim_id == empty_field; } -+ bool is_parent_empty() const { return parent_node_id == empty_field; } - }; - - enum class scale_type_t { NONE, COMMON, MANY }; - - struct prb_t { -+ /* The compensation mask value indicates how big an additional buffer should be. -+ * Possible values for reorder: -+ * 1) standard compensation = 1 = 0b01 -+ * 2) asymmetric compensation = 2 = 0b10 -+ * 3) compensation if tensor contains group = 3 = 0b11 */ -+ static constexpr int invalid_comp_mask = 0; -+ static constexpr int standard_comp_mask = 0b1; -+ static constexpr int asymmetric_comp_mask = 0b10; -+ static constexpr int comp_mask_with_groups -+ = standard_comp_mask + asymmetric_comp_mask; -+ -+ bool is_tail_in_one_of_child_nodes(int parent_node_id) const { -+ for (int i = parent_node_id; i >= 0; i--) { -+ if (nodes[i].parent_node_id == parent_node_id) { -+ if (nodes[i].tail_size != 0) -+ return true; -+ else -+ parent_node_id = i; -+ } -+ } -+ -+ return false; -+ } -+ -+ int tail(int d) const { -+ assert(d < ndims); -+ return static_cast(nodes[d].tail_size); -+ } -+ -+ int n(int d) const { -+ assert(d < ndims); -+ return static_cast(nodes[d].n); -+ } -+ int is(int d) const { -+ assert(d < ndims); -+ return static_cast(nodes[d].is); -+ } -+ int os(int d) const { -+ assert(d < ndims); -+ return static_cast(nodes[d].os); -+ } -+ int ss(int d) const { -+ assert(d < ndims); -+ return static_cast(nodes[d].ss); -+ } -+ -+ int cs(int d) const { -+ assert(d < ndims); -+ return static_cast(nodes[d].cs); -+ } -+ - data_type_t itype; - data_type_t otype; - int ndims; -@@ -54,21 +115,24 @@ struct prb_t { - scale_type_t scale_type; - float beta; - int full_ndims; -- int ip_tail; -- int op_tail; -- int iblock; -- int oblock; -- int blk_chunk_idx; -+ bool is_tail_present = false; -+ float scale_adjust = 1.f; -+ int compensation_mask = invalid_comp_mask; -+ bool req_s8s8_comp = false; -+ bool req_asymmetric_comp = false; -+ bool req_src_zp = false; -+ bool req_dst_zp = false; - }; - - status_t prb_init(prb_t &prb, const memory_desc_t &imd, - const memory_desc_t &omd, const primitive_attr_t *attr); - --status_t prb_check_blk(prb_t &prb, const memory_desc_t &imd); -- - /** sorts the problem nodes so that output strides come in ascending order */ - void prb_normalize(prb_t &p); - -+/** fill parent node info for blocked nodes */ -+void prb_node_dependency(prb_t &p); -+ - /** folds nodes together if possible */ - void prb_simplify(prb_t &p); - -@@ -88,10 +152,24 @@ void prb_node_move(prb_t &p, int d0, int d1); - void prb_dump(const prb_t &p); - - struct call_param_t { -- const void *in; -- void *out; -- const float *scale; -- size_t blk_chunks; -+ const void *in = nullptr; -+ void *out = nullptr; -+ const float *scale = nullptr; -+ int32_t src_zp = 0; -+ int32_t dst_zp = 0; -+ int32_t *compensation_scratch = nullptr; -+}; -+ -+// The additional structure is needed because -+// using a data structure with tail processing -+// data for non-tail cases reduces kernel -+// performance. This is because there is too -+// much data that has to be transferred to the kernel. -+struct tail_call_param_t { -+ call_param_t base_params; -+ int64_t curr_data_chunks[DNNL_MAX_NDIMS] = {-1}; -+ int64_t zeroing_data = static_cast(false); -+ int64_t skip_kernel_execution = static_cast(false); - }; - - struct kernel_t { -@@ -100,8 +178,12 @@ struct kernel_t { - prb_t prb; - }; - -- kernel_t(const desc_t &desc) : desc_(desc) {} -+ kernel_t(const desc_t &desc) -+ : desc_(desc) -+ , compensation_needed_( -+ desc.prb.req_s8s8_comp || desc.prb.req_asymmetric_comp) {} - virtual void operator()(const call_param_t *c) const = 0; -+ virtual void operator()(const tail_call_param_t *c) const = 0; - virtual status_t create_kernel() = 0; - virtual ~kernel_t() {} - -@@ -119,10 +201,13 @@ struct kernel_t { - protected: - const desc_t desc_; - const prb_t &prb_ = desc_.prb; -+ bool compensation_needed_ = false; - }; - - /* TODO: add trans_t class */ - -+struct jit_single_blk_kernel_t; -+ - } // namespace tr - - struct jit_uni_reorder_t : public primitive_t { -@@ -135,8 +220,13 @@ struct jit_uni_reorder_t : public primitive_t { - tr::prb_t prb_; - tr::kernel_t::desc_t ker_desc_; - int nthr_; -+ bool with_groups_ = false; -+ -+ status_t init( -+ engine_t *engine, engine_t *src_engine, engine_t *dst_engine); - - private: -+ void init_scratchpad(); - static status_t create(reorder_pd_t **reorder_pd, engine_t *engine, - const primitive_attr_t *attr, engine_t *src_engine, - const memory_desc_t *src_md, engine_t *dst_engine, -@@ -151,23 +241,66 @@ struct jit_uni_reorder_t : public primitive_t { - enum { ndims_driver_max = 4 }; - - private: -- void omp_driver_0d( -- int off, const char *in, char *out, const float *scale) const; -+ void omp_driver_0d(int off, const char *in, char *out, const float *scale, -+ int src_zp, int dst_zp, int32_t *compensation_scratch) const; - void omp_driver_1d(int ithr, int nthr, int off, const char *in, char *out, -- const float *scale) const; -+ const float *scale, int src_zp, int dst_zp, -+ int32_t *compensation_scratch) const; - void omp_driver_2d(int ithr, int nthr, int off, const char *in, char *out, -- const float *scale) const; -+ const float *scale, int src_zp, int dst_zp, -+ int32_t *compensation_scratch) const; - void omp_driver_3d(int ithr, int nthr, int off, const char *in, char *out, -- const float *scale) const; -+ const float *scale, int src_zp, int dst_zp, -+ int32_t *compensation_scratch) const; - void omp_driver_4d(int ithr, int nthr, int off, const char *in, char *out, -- const float *scale) const; -+ const float *scale, int src_zp, int dst_zp, -+ int32_t *compensation_scratch) const; -+ -+ void omp_driver(const char *in, char *out, const float *scale, int src_zp, -+ int dst_zp, const memory_tracking::grantor_t &scratchpad) const; - -- void omp_driver(const char *in, char *out, const float *scale) const; -+ void fill_curr_data_chunks(const tr::prb_t &prb, const int off, -+ const ptrdiff_t *omp_data_chunks, const int omp_ndims, -+ tr::tail_call_param_t &c) const; -+ -+ void reduce_compensation(char *out, -+ const int32_t *compensation_reduce_scratch, const int nthr, -+ const dim_t wspace_per_thr_size) const; - - const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } - std::unique_ptr kernel_; - }; - -+struct jit_blk_reorder_t : public primitive_t { -+ using primitive_t::primitive_t; -+ struct pd_t : public cpu_reorder_pd_t { -+ using cpu_reorder_pd_t::cpu_reorder_pd_t; -+ DECLARE_COMMON_PD_T("jit:blk", jit_blk_reorder_t); -+ -+ tr::prb_t prb_; -+ -+ private: -+ static status_t create(reorder_pd_t **reorder_pd, engine_t *engine, -+ const primitive_attr_t *attr, engine_t *src_engine, -+ const memory_desc_t *src_md, engine_t *dst_engine, -+ const memory_desc_t *dst_md); -+ -+ // Swap last two nodes, put block 4, 8, 16 nodes to first -+ static void prb_tile_normalize(tr::prb_t &p); -+ friend dnnl::impl::impl_list_item_t; -+ }; -+ -+ status_t init(engine_t *engine) override; -+ status_t execute(const exec_ctx_t &ctx) const override; -+ -+ jit_blk_reorder_t(const pd_t *apd); -+ ~jit_blk_reorder_t(); -+ -+private: -+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } -+ std::unique_ptr kernel_; -+}; -+ - } // namespace aarch64 - } // namespace cpu - } // namespace impl -diff --git a/src/cpu/aarch64/jit_uni_reorder_utils.cpp b/src/cpu/aarch64/jit_uni_reorder_utils.cpp -index 7123811f827..28f36a7e2e7 100644 ---- a/src/cpu/aarch64/jit_uni_reorder_utils.cpp -+++ b/src/cpu/aarch64/jit_uni_reorder_utils.cpp -@@ -1,6 +1,6 @@ - /******************************************************************************* --* Copyright 2018-2021 Intel Corporation --* Copyright 2020 FUJITSU LIMITED -+* Copyright 2018-2022 Intel Corporation -+* Copyright 2020-2022 FUJITSU LIMITED - * Copyright 2022 Arm Ltd. and affiliates - * - * Licensed under the Apache License, Version 2.0 (the "License"); -@@ -25,10 +25,21 @@ - #include "common/nstl.hpp" - #include "common/type_helpers.hpp" - #include "common/utils.hpp" --#include "dnnl_debug.h" -+#include "oneapi/dnnl/dnnl_debug.h" - - #include "cpu/aarch64/jit_uni_reorder.hpp" - -+// #define TR_DEBUG -+#if defined(TR_DEBUG) -+#define DEBUg(...) \ -+ do { \ -+ __VA_ARGS__ \ -+ } while (0) -+#else -+#define DEBUg(...) -+#endif -+#define DEBUG(...) DEBUg(__VA_ARGS__) -+ - using namespace dnnl::impl::types; - using namespace dnnl::impl::status; - -@@ -41,87 +52,45 @@ namespace tr { - - /** ad-hoc structure to describe blocked memory layout */ - struct layout_desc_t { -+ layout_desc_t() -+ : dt(dnnl_data_type_undef) -+ , ndims(0) -+ , id {-1} -+ , dims {0} -+ , tails {0} -+ , is_blk {false} -+ , strides {0} {} - data_type_t dt; - int ndims; - dims_t id; - dims_t dims; -+ dims_t tails; -+ bool is_blk[DNNL_MAX_NDIMS]; - strides_t strides; - }; - --static status_t compute_blk_and_tail( -- const memory_desc_t &md_, const int idx, int &blk, int &tail) { -- const auto md = memory_desc_wrapper(md_); -- const auto &bd = md.blocking_desc(); -- if (tail == 0) return status::success; -- -- const std::set unique_inner_idxs( -- bd.inner_idxs, bd.inner_idxs + bd.inner_nblks); -- std::set dims_with_multiple_blks; -- for (dim_t dim : unique_inner_idxs) { -- if (std::count(bd.inner_idxs, bd.inner_idxs + bd.inner_nblks, dim) > 1) -- dims_with_multiple_blks.insert(dim); -- } -- -- // Dims that have a tail and have multiple blocks are not supported by the jit kernel yet. -- // For example: -- // src_tag = abcd -- // dst_tag = ABcd16b16a4b -- // 16x15x3x3 -- // In this case, 'b' dim has two blocks and has a tail. It is not a supported case. -- if (dims_with_multiple_blks.find(idx) != dims_with_multiple_blks.end()) -- return status::unimplemented; -- -- // Only supports inconsistent padding in single and double blocks -- // and the total block size <= 256 -- for (int iblk = bd.inner_nblks - 1; iblk > 0; --iblk) { -- if (bd.inner_idxs[iblk] == idx) break; -- blk *= bd.inner_blks[iblk]; -- tail *= bd.inner_blks[iblk]; -- } -- if (unique_inner_idxs.size() > 2 || blk > 256) return status::unimplemented; -- -- return status::success; --} -- --static status_t compute_chunk_idx(const prb_t &p, const memory_desc_t &imd_, -- const memory_desc_t &omd_, const int blk_idx, int &chunk_idx) { -- const auto imd = memory_desc_wrapper(imd_); -- const auto omd = memory_desc_wrapper(omd_); -- const auto &ibd = imd.blocking_desc(); -- const auto &obd = omd.blocking_desc(); -- if (p.ip_tail == 0 && p.op_tail == 0) return status::success; -- -- const ptrdiff_t is -- = ibd.strides[blk_idx] * obd.inner_blks[obd.inner_idxs[blk_idx]]; -- const ptrdiff_t os = obd.strides[blk_idx]; -- -- for (int i = blk_idx; i < omd.ndims(); ++i) { -- if (p.nodes[i].os == os && p.nodes[i].is == is) { -- chunk_idx = i; -- return status::success; -- } -- } -- -- return status::invalid_arguments; --} -- - status_t cvt_mem_desc_to_layout_desc(const memory_desc_t &md_, -- layout_desc_t &ld, const dims_t &blocks, const dims_t &ext_padding) { -+ layout_desc_t &ld, const dims_t &blocks, const dims_t &external_padding, -+ const dims_t &tails) { -+ static constexpr bool it_is_blk = true; -+ - const auto md = memory_desc_wrapper(md_); - -- bool ok = true && md.is_blocking_desc() && md.extra().flags == 0; -- if (!ok) return invalid_arguments; -+ if (!md.is_blocking_desc()) return invalid_arguments; - - const auto &bd = md.blocking_desc(); - - ld.ndims = 0; - ld.dt = md.data_type(); - -- auto P = [&ld](int id, int dim, ptrdiff_t stride) { -+ auto add_dim = [&ld](int id, dim_t dim, dim_t tail, bool is_blk, -+ ptrdiff_t stride) { - assert((size_t)ld.ndims < sizeof(ld.dims) / sizeof(ld.dims[0])); - ld.id[ld.ndims] = id; - ld.dims[ld.ndims] = dim; - ld.strides[ld.ndims] = stride; -+ ld.tails[ld.ndims] = tail; -+ ld.is_blk[ld.ndims] = is_blk; - ++ld.ndims; - }; - -@@ -129,12 +98,27 @@ status_t cvt_mem_desc_to_layout_desc(const memory_desc_t &md_, - const int ld_ndims_start = ld.ndims; - if (blocks[d] != 1) { - stride_t stride = 1; -+ int tail = tails[d]; - for (int iblk = bd.inner_nblks - 1; iblk >= 0; --iblk) { -- if (bd.inner_idxs[iblk] == d) P(d, bd.inner_blks[iblk], stride); -+ if (bd.inner_idxs[iblk] == d) { -+ const dim_t inner_tail = tail % bd.inner_blks[iblk]; -+ add_dim(d, bd.inner_blks[iblk], inner_tail, it_is_blk, -+ stride); -+ tail = utils::div_up(tail, bd.inner_blks[iblk]); -+ } - stride *= bd.inner_blks[iblk]; - } - } -- P(d, (md.padded_dims()[d] + ext_padding[d]) / blocks[d], bd.strides[d]); -+ -+ const dim_t dim_with_external_padding -+ = (md.padded_dims()[d] + external_padding[d]) / blocks[d]; -+ const dim_t padded_dim = md.padded_dims()[d] / blocks[d]; -+ const dim_t tail = dim_with_external_padding != padded_dim -+ ? dim_with_external_padding -+ - (dim_with_external_padding - padded_dim) -+ : 0; -+ -+ add_dim(d, dim_with_external_padding, tail, !it_is_blk, bd.strides[d]); - - // TODO: NOW: revisit, do we need a reverse? - // TODO: NOW: consider using strides instead of block sizes in md -@@ -144,12 +128,70 @@ status_t cvt_mem_desc_to_layout_desc(const memory_desc_t &md_, - const int idx1 = ld.ndims - 1 - ld_d; - nstl::swap(ld.dims[idx0], ld.dims[idx1]); - nstl::swap(ld.strides[idx0], ld.strides[idx1]); -+ nstl::swap(ld.tails[idx0], ld.tails[idx1]); -+ nstl::swap(ld.is_blk[idx0], ld.is_blk[idx1]); - } - } - - return success; - } - -+static bool is_with_groups(const memory_desc_t &dst_md) { -+ using namespace memory_extra_flags; -+ auto dst_d = memory_desc_wrapper(dst_md); -+ const int grp_bit = 1 << 1; -+ auto check_flag_and_mask = [&](int flag, int mask) { -+ return (dst_d.extra().flags & flag) && (mask & grp_bit); -+ }; -+ -+ return check_flag_and_mask( -+ compensation_conv_s8s8, dst_d.extra().compensation_mask) -+ || check_flag_and_mask(compensation_conv_asymmetric_src, -+ dst_d.extra().asymm_compensation_mask); -+} -+ -+static inline int get_next_parent_node(node_t *nodes, int ndims, int cur_node) { -+ const int cur_id = nodes[cur_node].dim_id; -+ for (int d = cur_node + 1; d < ndims; ++d) { -+ if (nodes[d].dim_id == cur_id) return d; -+ } -+ return -1; -+} -+ -+static void prb_set_compensation_strides(prb_t &p) { -+ -+ auto require_n_stride = [&](int cur_node) -> bool { -+ const int parent = get_next_parent_node(p.nodes, p.ndims, cur_node); -+ if (parent < 0) return false; -+ -+ const size_t p_n = p.nodes[parent].n; -+ -+ // if 'parent_node.n' is larger than 1, then cur_node stride -+ // is 'cur_node.n' -+ return p_n > size_t(1); -+ }; -+ -+ const auto compensation_needed = p.req_s8s8_comp || p.req_asymmetric_comp; -+ if (!compensation_needed) return; -+ int mask = p.compensation_mask; -+ ptrdiff_t cs = 1; -+ for (int d = 0; d < p.ndims; ++d) { -+ if (mask & (1 << p.nodes[d].dim_id)) { -+ -+ // correct cases when 'cs' exceeds output stride -+ if (cs > p.nodes[d].os) cs = p.nodes[d].os; -+ -+ p.nodes[d].cs = cs; -+ const bool n_stride = require_n_stride(d); -+ if (p.nodes[d].tail_size > 0 && (!p.nodes[d].is_zero_pad_needed) -+ && (!n_stride)) -+ cs *= p.nodes[d].tail_size; -+ else -+ cs *= p.nodes[d].n; -+ } -+ } -+} -+ - status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd, - const primitive_attr_t *attr) { - auto im_d = memory_desc_wrapper(imd); -@@ -157,8 +199,7 @@ status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd, - - auto check_post_ops = [](const primitive_attr_t *attr) { - const auto &po = attr->post_ops_; -- return po.len() == 0 -- || (po.len() == 1 && po.contain(primitive_kind::sum, 0)); -+ return po.len() == 0 || (po.len() == 1 && po.entry_[0].is_sum(false)); - }; - - bool ok = im_d.is_blocking_desc() && om_d.is_blocking_desc() -@@ -166,81 +207,129 @@ status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd, - && !om_d.has_runtime_dims_or_strides() && !om_d.has_zero_dim() - && attr->has_default_values( - primitive_attr_t::skip_mask_t::oscale_runtime -+ | primitive_attr_t::skip_mask_t::zero_points_runtime - | primitive_attr_t::skip_mask_t::post_ops) - && check_post_ops(attr); - if (!ok) return unimplemented; - -- dims_t iblocks, oblocks, ip_padding, op_padding; -+ bool is_tail_present = false; -+ dims_t iblocks, oblocks, i_tails, o_tails, i_paddings, o_paddings; - im_d.compute_blocks(iblocks); - om_d.compute_blocks(oblocks); -- utils::array_set(ip_padding, 0, im_d.ndims()); -- utils::array_set(op_padding, 0, om_d.ndims()); -- -- /* padding_dim consistency check -- * only supports inconsitent padding for src -- * TODO: Add inconsistent padding support for dst */ -- int ip_tail = 0; -- int op_tail = 0; -- int iblk_w_tail = 1; -- int oblk_w_tail = 1; -- int blk_idx = 0; -+ -+ for (int d = 0; d < om_d.ndims(); ++d) { -+ const auto dim = om_d.dims()[d]; -+ const auto pdim = om_d.padded_dims()[d]; -+ const auto cblock = oblocks[d]; -+ // do not allow excess pdim other than required for rounding-up of dim. -+ if (utils::rnd_up(dim, cblock) != pdim) return unimplemented; -+ } -+ -+ utils::array_set(i_tails, 0, im_d.ndims()); -+ utils::array_set(o_tails, 0, om_d.ndims()); -+ utils::array_set(i_paddings, 0, im_d.ndims()); -+ utils::array_set(o_paddings, 0, om_d.ndims()); - - for (int d = 0; d < im_d.ndims(); ++d) { -- const int ip_tmp_dim = im_d.padded_dims()[d]; -- const int op_tmp_dim = om_d.padded_dims()[d]; -- const int ip_tmp_tail = ip_tmp_dim % oblocks[d]; -- const int op_tmp_tail = op_tmp_dim % iblocks[d]; -- -- const bool pdim_consistent = ip_tmp_dim == op_tmp_dim -- && ip_tmp_tail == 0 && op_tmp_tail == 0; -- const bool pdim_tail = ip_tmp_tail > 0 -- && (ip_tmp_dim + oblocks[d] - ip_tmp_tail) == op_tmp_dim -- && op_tmp_tail == 0 && ip_tail == 0; -- if (!pdim_consistent && !pdim_tail) return status::unimplemented; -- if (pdim_tail) { -- blk_idx = d; -- ip_tail = ip_tmp_tail; -- op_tail = op_tmp_tail; -- iblk_w_tail = iblocks[d]; -- oblk_w_tail = oblocks[d]; -- ip_padding[d] = oblocks[d] - ip_tmp_tail; -- op_padding[d] = iblocks[d] - op_tmp_tail; -+ const dim_t i_dim = im_d.dims()[d]; -+ const dim_t o_dim = om_d.dims()[d]; -+ const dim_t i_tail = i_dim % iblocks[d]; -+ const dim_t o_tail = o_dim % oblocks[d]; -+ -+ if (o_tail > 0) { -+ is_tail_present = true; -+ o_tails[d] = o_tail; -+ o_paddings[d] = oblocks[d] - o_tail; -+ } -+ -+ if (i_tail > 0) { -+ is_tail_present = true; -+ i_tails[d] = i_tail; -+ i_paddings[d] = iblocks[d] - i_tail; - } - } -- CHECK(compute_blk_and_tail(omd, blk_idx, oblk_w_tail, ip_tail)); - -+ // To compute input layout description we need to pass output paddings -+ // which will be used to compute input dims rounded up to multiple of -+ // output dims. Analogous applies to output layout description. -+ // This is demanded by the algorithm of nodes creation. -+ // Example: -+ // input: -+ // format: abc -+ // size: 77, 15, 3 -+ // o_padding: 3, 17, 0 -+ // returns ild: 80, 32, 3 -+ // output: -+ // format: ABc16b16a2b -+ // size: 77, 15, 3 -+ // i_padding: 0, 0, 0 -+ // returns old: 5, 16, 1, 16, 2, 3 - layout_desc_t ild, old; -- status_t status -- = cvt_mem_desc_to_layout_desc(imd, ild, iblocks, ip_padding); -- if (status != success) return status; -- status = cvt_mem_desc_to_layout_desc(omd, old, oblocks, op_padding); -- if (status != success) return status; -+ CHECK(cvt_mem_desc_to_layout_desc(imd, ild, iblocks, o_paddings, i_tails)); -+ CHECK(cvt_mem_desc_to_layout_desc(omd, old, oblocks, i_paddings, o_tails)); - - p.itype = ild.dt; - p.otype = old.dt; -- p.ip_tail = ip_tail; -- p.op_tail = op_tail; -- p.iblock = iblk_w_tail; -- p.oblock = oblk_w_tail; -- -+ p.is_tail_present = is_tail_present; -+ p.req_src_zp = !attr->zero_points_.has_default_values(DNNL_ARG_SRC); -+ p.req_dst_zp = !attr->zero_points_.has_default_values(DNNL_ARG_DST); - p.scale_type = attr->output_scales_.has_default_values() - ? scale_type_t::NONE - : (attr->output_scales_.mask_ == 0 ? scale_type_t::COMMON - : scale_type_t::MANY); -+ p.scale_adjust = (om_d.extra().flags & memory_extra_flags::scale_adjust) -+ ? om_d.extra().scale_adjust -+ : 1.f; -+ p.req_s8s8_comp -+ = om_d.extra().flags & memory_extra_flags::compensation_conv_s8s8; -+ p.req_asymmetric_comp = om_d.extra().flags -+ & memory_extra_flags::compensation_conv_asymmetric_src; -+ -+ const bool with_groups = is_with_groups(omd); -+ -+ auto mask_ok = [&](bool check, int mask) { -+ return IMPLICATION(check, mask == (with_groups ? 0x3 : 0x1)); -+ }; -+ -+ if (!mask_ok(p.req_s8s8_comp, om_d.extra().compensation_mask) -+ || !mask_ok(p.req_asymmetric_comp, -+ om_d.extra().asymm_compensation_mask)) -+ return status::unimplemented; - -- ptrdiff_t ss[max_ndims] = {0}; -+ ptrdiff_t ss[max_ndims] = {0}; // scales strides - if (p.scale_type == scale_type_t::MANY) { -- ptrdiff_t last_ss = 1; -+ const int mask = attr->output_scales_.mask_; -+ ptrdiff_t dense_stride = 1; -+ ptrdiff_t last_stride = 1; - for (int d = old.ndims - 1; d >= 0; --d) { - assert((d == 0 || old.id[d - 1] <= old.id[d]) - && "logical dimensions should be in ascending order"); -- if (attr->output_scales_.mask_ & (1 << old.id[d])) { -- ss[d] = last_ss; -- last_ss *= old.dims[d]; -+ if (mask & (1 << old.id[d])) { -+ if ((d + 1) < old.ndims && old.id[d + 1] != old.id[d] -+ && (mask & (1 << old.id[d + 1]))) { -+ dense_stride = dense_stride * imd.dims[old.id[d + 1]]; -+ last_stride = dense_stride; -+ } -+ ss[d] = last_stride; -+ last_stride *= old.dims[d]; - } - } - } - -+ const auto compensation_needed = p.req_s8s8_comp || p.req_asymmetric_comp; -+ if (compensation_needed) { -+ p.compensation_mask = p.req_s8s8_comp -+ ? om_d.extra().compensation_mask -+ : (p.req_asymmetric_comp ? om_d.extra().asymm_compensation_mask -+ : tr::prb_t::invalid_comp_mask); -+ -+ if (p.compensation_mask == tr::prb_t::asymmetric_comp_mask) -+ return unimplemented; -+ -+ assert(p.compensation_mask == tr::prb_t::standard_comp_mask -+ || p.compensation_mask == tr::prb_t::comp_mask_with_groups); -+ } -+ - int ndims = 0; - - int i_pos = 0; /* state for input -- current dimension */ -@@ -254,6 +343,10 @@ status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd, - - if (ild.dims[i_pos] == old.dims[o_pos]) { - p.nodes[ndims].n = ild.dims[i_pos]; -+ p.nodes[ndims].dim_id = old.id[o_pos]; -+ p.nodes[ndims].tail_size = old.tails[o_pos]; -+ p.nodes[ndims].is_zero_pad_needed -+ = old.is_blk[o_pos] && old.tails[o_pos] > 0; - p.nodes[ndims].is = ild.strides[i_pos]; - p.nodes[ndims].os = old.strides[o_pos]; - p.nodes[ndims].ss = ss[o_pos]; -@@ -261,19 +354,45 @@ status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd, - ++i_pos; - ++o_pos; - } else if (ild.dims[i_pos] < old.dims[o_pos]) { -- assert(old.dims[o_pos] % ild.dims[i_pos] == 0); -- int factor = old.dims[o_pos] / ild.dims[i_pos]; -+ // old must be divisible by ild or we will not be -+ // able to create valid nodes. The problem appears -+ // when stag=Acdb48a and dtag=Acdb32a for example. -+ if (ild.dims[i_pos] == 0 || old.dims[o_pos] % ild.dims[i_pos] != 0) -+ return status::unimplemented; -+ -+ dim_t factor = old.dims[o_pos] / ild.dims[i_pos]; -+ -+ const size_t tail_of_upper_dim -+ = utils::div_up(old.tails[o_pos], factor) == ild.dims[i_pos] -+ ? 0 -+ : utils::div_up(old.tails[o_pos], factor); -+ const size_t tail_of_lower_dim = old.tails[o_pos] % factor; -+ - p.nodes[ndims].n = ild.dims[i_pos]; -+ p.nodes[ndims].dim_id = old.id[o_pos]; -+ p.nodes[ndims].tail_size = tail_of_upper_dim; -+ p.nodes[ndims].is_zero_pad_needed -+ = old.is_blk[o_pos] && tail_of_upper_dim > 0; - p.nodes[ndims].is = ild.strides[i_pos]; - p.nodes[ndims].os = old.strides[o_pos] * factor; - p.nodes[ndims].ss = ss[o_pos] * factor; - ++ndims; - ++i_pos; - old.dims[o_pos] = factor; -+ old.tails[o_pos] = tail_of_lower_dim; - } else if (ild.dims[i_pos] > old.dims[o_pos]) { -- assert(ild.dims[i_pos] % old.dims[o_pos] == 0); -- int factor = ild.dims[i_pos] / old.dims[o_pos]; -+ // ild must be divisible by old or we will not be -+ // able to create valid nodes. The problem appears -+ // when stag=Acdb32a and dtag=Acdb48a for example. -+ if (old.dims[o_pos] == 0 || ild.dims[i_pos] % old.dims[o_pos] != 0) -+ return status::unimplemented; -+ -+ dim_t factor = ild.dims[i_pos] / old.dims[o_pos]; - p.nodes[ndims].n = old.dims[o_pos]; -+ p.nodes[ndims].dim_id = old.id[o_pos]; -+ p.nodes[ndims].tail_size = old.tails[o_pos]; -+ p.nodes[ndims].is_zero_pad_needed -+ = old.is_blk[o_pos] && old.tails[o_pos] > 0; - p.nodes[ndims].is = ild.strides[i_pos] * factor; - p.nodes[ndims].os = old.strides[o_pos]; - p.nodes[ndims].ss = ss[o_pos]; -@@ -282,12 +401,9 @@ status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd, - ild.dims[i_pos] = factor; - } - } -- int blk_chunk_idx = ndims; -- CHECK(compute_chunk_idx(p, imd, omd, blk_idx, blk_chunk_idx)); - - p.ndims = ndims; - p.full_ndims = ndims; -- p.blk_chunk_idx = blk_chunk_idx; - - p.ioff = memory_desc_wrapper(imd).offset0(); - p.ooff = memory_desc_wrapper(omd).offset0(); -@@ -295,6 +411,28 @@ status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd, - const int sum_idx = attr->post_ops_.find(primitive_kind::sum); - p.beta = sum_idx == -1 ? 0.f : attr->post_ops_.entry_[sum_idx].sum.scale; - -+ DEBUG({ -+ printf("init : "); -+ prb_dump(prb); -+ }); -+ // Sort the prb array in increasing sizes of the output stride -+ prb_normalize(p); -+ DEBUG({ -+ printf("norm : "); -+ prb_dump(prb); -+ }); -+ -+ // compensation strides require prb_normalized -+ prb_set_compensation_strides(p); -+ -+ /* Combine the variables, which appear together on both -+ * sides of the reorder */ -+ prb_simplify(p); -+ DEBUG({ -+ printf("smpl : "); -+ prb_dump(prb); -+ }); -+ - return success; - } - -@@ -307,28 +445,23 @@ void prb_normalize(prb_t &p) { - && p.nodes[j].n < p.nodes[min_pos].n); - if (new_min) min_pos = j; - } -- if (min_pos != d) { -- nstl::swap(p.nodes[d], p.nodes[min_pos]); -- if (p.blk_chunk_idx == min_pos || p.blk_chunk_idx == d) -- p.blk_chunk_idx = p.blk_chunk_idx == min_pos ? d : min_pos; -- } -+ if (min_pos != d) { nstl::swap(p.nodes[d], p.nodes[min_pos]); } - } - } - --status_t prb_check_blk(prb_t &p, const memory_desc_t &md_) { -- const auto md = memory_desc_wrapper(md_); -- const auto &bd = md.blocking_desc(); -- if (p.ip_tail == 0) return status::success; -- -- // Check if the inner blocks and p.nodes[blk].n in the firsti nblks -- // is equivalent in reverse order when has tail in block layout. -- const int nblk = bd.inner_nblks; -- for (int iblk = 0; iblk < nblk; ++iblk) { -- if (bd.inner_blks[nblk - iblk - 1] -- != static_cast(p.nodes[iblk].n)) -- return status::unimplemented; -+void prb_node_dependency(prb_t &prb) { -+ for (int i = 0; i < prb.ndims; i++) { -+ tr::node_t &node = prb.nodes[i]; -+ node.parent_node_id = node_t::empty_field; -+ for (int j = i + 1; j < prb.ndims; j++) { -+ const tr::node_t &potential_parent_node = prb.nodes[j]; -+ if (!potential_parent_node.is_dim_id_empty() -+ && potential_parent_node.dim_id == node.dim_id) { -+ node.parent_node_id = j; -+ break; -+ } -+ } - } -- return status::success; - } - - void prb_simplify(prb_t &p) { -@@ -338,16 +471,25 @@ void prb_simplify(prb_t &p) { - #pragma GCC diagnostic push - #pragma GCC diagnostic ignored "-Warray-bounds" - #endif -+ -+ const auto skip_dim_combining = [&p](const int node_id) -> bool { -+ return (p.is_tail_in_one_of_child_nodes(node_id) -+ && p.nodes[node_id].n > 1) -+ || p.nodes[node_id].tail_size > 0; -+ }; -+ -+ if (p.is_tail_present) prb_node_dependency(p); -+ - for (int d = 0; d < p.ndims - 1; ++d) { - auto &this_node = p.nodes[d + 0]; - auto &next_node = p.nodes[d + 1]; -- const bool skip_blk_idx = (p.ip_tail > 0 || p.op_tail > 0) -- && (p.blk_chunk_idx == d || p.blk_chunk_idx == d + 1); -+ const bool skip_dims_combining -+ = skip_dim_combining(d) || skip_dim_combining(d + 1); - const bool fold = false - || (next_node.n == static_cast(1) -- && !skip_blk_idx) // trivial case, just drop next node -+ && !skip_dims_combining) // trivial case, just drop next node - || (true // or real folding if possible -- && !skip_blk_idx -+ && !skip_dims_combining - && next_node.is - == static_cast( - this_node.n * this_node.is) -@@ -356,15 +498,20 @@ void prb_simplify(prb_t &p) { - this_node.n * this_node.os) - && next_node.ss - == static_cast( -- this_node.n * this_node.ss)); -+ this_node.n * this_node.ss) -+ && next_node.cs -+ == static_cast( -+ this_node.n * this_node.cs)); - if (fold) { - this_node.n *= next_node.n; -+ this_node.dim_id = node_t::empty_field; -+ this_node.is_zero_pad_needed = false; - for (int j = d + 2; j < p.ndims; ++j) - p.nodes[j - 1] = p.nodes[j]; -- if (d < p.blk_chunk_idx) --p.blk_chunk_idx; - --p.ndims; - --p.full_ndims; - --d; // make another try -+ if (p.is_tail_present) prb_node_dependency(p); - } - } - #if defined(__GNUC__) && __GNUC__ >= 4 -@@ -372,24 +519,42 @@ void prb_simplify(prb_t &p) { - #endif - } - --void prb_node_split(prb_t &p, int dim, size_t n1) { -+void prb_node_split(prb_t &p, int dim, size_t new_node_size) { - assert(dim < p.ndims); - assert(p.ndims < max_ndims); -- assert(p.nodes[dim].n % n1 == 0); -+ assert(p.nodes[dim].n % new_node_size == 0); - - p.ndims += 1; - p.full_ndims += 1; -- if (dim < p.blk_chunk_idx) p.blk_chunk_idx += 1; - - for (int d = p.ndims; d > dim + 1; --d) - p.nodes[d] = p.nodes[d - 1]; - -- p.nodes[dim + 1].n = p.nodes[dim].n / n1; -- p.nodes[dim + 1].is = p.nodes[dim].is * n1; -- p.nodes[dim + 1].os = p.nodes[dim].os * n1; -- p.nodes[dim + 1].ss = p.nodes[dim].ss * n1; -- -- p.nodes[dim].n = n1; -+ const size_t upper_node_size = p.nodes[dim].n / new_node_size; -+ const size_t lower_node_size = new_node_size; -+ p.nodes[dim + 1].n = upper_node_size; -+ p.nodes[dim].n = lower_node_size; -+ -+ const bool is_tail = p.nodes[dim].tail_size > 0; -+ const size_t upper_node_tail -+ = utils::div_up(p.nodes[dim].tail_size, lower_node_size) -+ == upper_node_size -+ ? 0 -+ : utils::div_up(p.nodes[dim].tail_size, lower_node_size); -+ const size_t lower_node_tail = p.nodes[dim].tail_size % lower_node_size; -+ p.nodes[dim].tail_size = is_tail ? lower_node_tail : 0; -+ p.nodes[dim + 1].tail_size = is_tail ? upper_node_tail : 0; -+ -+ p.nodes[dim + 1].is_zero_pad_needed -+ = p.nodes[dim].is_zero_pad_needed && p.nodes[dim + 1].tail_size > 0; -+ p.nodes[dim].is_zero_pad_needed -+ = p.nodes[dim].is_zero_pad_needed && p.nodes[dim].tail_size > 0; -+ -+ p.nodes[dim + 1].dim_id = p.nodes[dim].dim_id; -+ p.nodes[dim + 1].is = p.nodes[dim].is * lower_node_size; -+ p.nodes[dim + 1].os = p.nodes[dim].os * lower_node_size; -+ p.nodes[dim + 1].ss = p.nodes[dim].ss * lower_node_size; -+ p.nodes[dim + 1].cs = p.nodes[dim].cs * lower_node_size; - } - - void prb_node_swap(prb_t &p, int d0, int d1) { -@@ -425,8 +590,11 @@ void prb_dump(const prb_t &p) { - printf("@@@ type:%s:%s ndims:%d ", dnnl_dt2str(p.itype), - dnnl_dt2str(p.otype), p.ndims); - for (int d = 0; d < p.ndims; ++d) -- printf("[%zu:%td:%td:%td]", p.nodes[d].n, p.nodes[d].is, p.nodes[d].os, -- p.nodes[d].ss); -+ printf("[%zu:%zu:%d:%d:%s:%td:%td:%td:%td]", p.nodes[d].n, -+ p.nodes[d].tail_size, p.nodes[d].dim_id, -+ p.nodes[d].parent_node_id, -+ p.nodes[d].is_zero_pad_needed ? "true" : "false", p.nodes[d].is, -+ p.nodes[d].os, p.nodes[d].ss, p.nodes[d].cs); - printf(" off:%zu:%zu\n", p.ioff, p.ooff); - } - -diff --git a/src/cpu/reorder/cpu_reorder_regular_f32_f32.cpp b/src/cpu/reorder/cpu_reorder_regular_f32_f32.cpp -index f51e3c22414..fdefec8a049 100644 ---- a/src/cpu/reorder/cpu_reorder_regular_f32_f32.cpp -+++ b/src/cpu/reorder/cpu_reorder_regular_f32_f32.cpp -@@ -1,5 +1,6 @@ - /******************************************************************************* - * Copyright 2020-2022 Intel Corporation -+* Copyright 2022 FUJITSU LIMITED - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. -@@ -32,6 +33,7 @@ const impl_list_map_t ®ular_f32_f32_impl_list_map() { - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) - -+ DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_blk_reorder_t)) - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t)) - REG_SR(f32, any, f32, any, fmt_order::any, spec::reference) - -@@ -44,6 +46,7 @@ const impl_list_map_t ®ular_f32_f32_impl_list_map() { - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) - -+ DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_blk_reorder_t)) - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t)) - DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, f32, nCw16c)) - DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, f32, nCw8c)) -@@ -75,6 +78,7 @@ const impl_list_map_t ®ular_f32_f32_impl_list_map() { - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) - -+ DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_blk_reorder_t)) - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t)) - - DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, f32, nChw16c)) -@@ -123,6 +127,7 @@ const impl_list_map_t ®ular_f32_f32_impl_list_map() { - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) - -+ DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_blk_reorder_t)) - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t)) - - DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, f32, nCdhw16c)) -@@ -171,6 +176,7 @@ const impl_list_map_t ®ular_f32_f32_impl_list_map() { - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) - -+ DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_blk_reorder_t)) - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t)) - - -diff --git a/src/cpu/reorder/cpu_reorder_regular_f32_s32.cpp b/src/cpu/reorder/cpu_reorder_regular_f32_s32.cpp -index fadbee0ecf8..b1881df80e0 100644 ---- a/src/cpu/reorder/cpu_reorder_regular_f32_s32.cpp -+++ b/src/cpu/reorder/cpu_reorder_regular_f32_s32.cpp -@@ -1,5 +1,6 @@ - /******************************************************************************* - * Copyright 2020-2022 Intel Corporation -+* Copyright 2022 FUJITSU LIMITED - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. -@@ -31,6 +32,7 @@ const impl_list_map_t ®ular_f32_s32_impl_list_map() { - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) - -+ DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_blk_reorder_t)) - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t)) - DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, s32, nChw16c)) - REG_SR(f32, any, s32, any, fmt_order::any, spec::reference) -diff --git a/src/cpu/reorder/cpu_reorder_regular_f32_s8.cpp b/src/cpu/reorder/cpu_reorder_regular_f32_s8.cpp -index b83d47b2d6f..6bd305c7b41 100644 ---- a/src/cpu/reorder/cpu_reorder_regular_f32_s8.cpp -+++ b/src/cpu/reorder/cpu_reorder_regular_f32_s8.cpp -@@ -1,5 +1,6 @@ - /******************************************************************************* - * Copyright 2020-2022 Intel Corporation -+* Copyright 2022 FUJITSU LIMITED - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. -@@ -35,6 +36,7 @@ const impl_list_map_t ®ular_f32_s8_impl_list_map() { - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) - -+ DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_blk_reorder_t)) - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t)) - - DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, s8, nChw16c)) -diff --git a/src/cpu/reorder/cpu_reorder_regular_f32_u8.cpp b/src/cpu/reorder/cpu_reorder_regular_f32_u8.cpp -index 4bae84307e6..d306c3abeb8 100644 ---- a/src/cpu/reorder/cpu_reorder_regular_f32_u8.cpp -+++ b/src/cpu/reorder/cpu_reorder_regular_f32_u8.cpp -@@ -1,5 +1,6 @@ - /******************************************************************************* - * Copyright 2020-2022 Intel Corporation -+* Copyright 2022 FUJITSU LIMITED - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. -@@ -33,6 +34,7 @@ const impl_list_map_t ®ular_f32_u8_impl_list_map() { - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) - -+ DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_blk_reorder_t)) - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t)) - DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, u8, nChw16c)) - REG_SR(f32, any, u8, any, fmt_order::any, spec::reference) -diff --git a/src/cpu/reorder/cpu_reorder_regular_s32.cpp b/src/cpu/reorder/cpu_reorder_regular_s32.cpp -index 54d65661791..a8197402b0a 100644 ---- a/src/cpu/reorder/cpu_reorder_regular_s32.cpp -+++ b/src/cpu/reorder/cpu_reorder_regular_s32.cpp -@@ -1,5 +1,6 @@ - /******************************************************************************* - * Copyright 2020-2022 Intel Corporation -+* Copyright 2022 FUJITSU LIMITED - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. -@@ -34,6 +35,7 @@ const impl_list_map_t ®ular_s32_impl_list_map() { - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) - -+ DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_blk_reorder_t)) - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t)) - - DNNL_NON_X64_ONLY(REG_SR_BIDIR(s32, any, f32, nChw16c)) -diff --git a/src/cpu/reorder/cpu_reorder_regular_s8.cpp b/src/cpu/reorder/cpu_reorder_regular_s8.cpp -index f57d01e2009..ce18dc5caf1 100644 ---- a/src/cpu/reorder/cpu_reorder_regular_s8.cpp -+++ b/src/cpu/reorder/cpu_reorder_regular_s8.cpp -@@ -1,5 +1,6 @@ - /******************************************************************************* - * Copyright 2020-2022 Intel Corporation -+* Copyright 2022 FUJITSU LIMITED - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. -@@ -41,6 +42,7 @@ const impl_list_map_t ®ular_s8_impl_list_map() { - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) - -+ DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_blk_reorder_t)) - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t)) - - DNNL_NON_X64_ONLY(REG_SR_BIDIR(s8, any, f32, nChw16c)) -diff --git a/src/cpu/reorder/cpu_reorder_regular_u8.cpp b/src/cpu/reorder/cpu_reorder_regular_u8.cpp -index 73d731c3b15..87a58872262 100644 ---- a/src/cpu/reorder/cpu_reorder_regular_u8.cpp -+++ b/src/cpu/reorder/cpu_reorder_regular_u8.cpp -@@ -1,5 +1,6 @@ - /******************************************************************************* - * Copyright 2020-2022 Intel Corporation -+* Copyright 2022 FUJITSU LIMITED - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. -@@ -35,6 +36,7 @@ const impl_list_map_t ®ular_u8_impl_list_map() { - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) - -+ DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_blk_reorder_t)) - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t)) - - DNNL_NON_X64_ONLY(REG_SR_BIDIR(u8, any, f32, nChw16c)) diff --git a/third_party/mkl_dnn/onednn_acl_threadpool_scheduler.patch b/third_party/mkl_dnn/onednn_acl_threadpool_scheduler.patch index 0e0cb39e82f1bb..7e3725af270292 100644 --- a/third_party/mkl_dnn/onednn_acl_threadpool_scheduler.patch +++ b/third_party/mkl_dnn/onednn_acl_threadpool_scheduler.patch @@ -1,20 +1,3 @@ - ******************************************************************************* - Copyright 2023 Arm Limited and affiliates. - SPDX-License-Identifier: Apache-2.0 - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ******************************************************************************* - diff --git a/src/cpu/aarch64/acl_threadpool_scheduler.cpp b/src/cpu/aarch64/acl_threadpool_scheduler.cpp index 418d7f30f..439ca862e 100644 --- a/src/cpu/aarch64/acl_threadpool_scheduler.cpp From 3627fb0eeb028e6c75979ab61e60e17dba588402 Mon Sep 17 00:00:00 2001 From: David Majnemer Date: Sat, 15 Jul 2023 18:58:23 -0700 Subject: [PATCH 352/376] [XLA] Bump up the number of inline tiles PiperOrigin-RevId: 548415965 --- tensorflow/compiler/xla/layout.h | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/xla/layout.h b/tensorflow/compiler/xla/layout.h index e8b8378279807e..756bf319b0fb46 100644 --- a/tensorflow/compiler/xla/layout.h +++ b/tensorflow/compiler/xla/layout.h @@ -86,6 +86,8 @@ class Tile { absl::InlinedVector dimensions_; }; +using TileVector = absl::InlinedVector; + // TODO: Rename the `dim_level_types` field to `lvl_types`, so that it // matches `mlir::sparse_tensor::SparseTensorEncodingAttr`. class Layout { @@ -293,7 +295,7 @@ class Layout { return *this; } absl::Span tiles() const { return tiles_; } - absl::InlinedVector* mutable_tiles() { return &tiles_; } + TileVector* mutable_tiles() { return &tiles_; } int64_t element_size_in_bits() const { return element_size_in_bits_; } Layout& set_element_size_in_bits(int64_t value) { @@ -376,7 +378,7 @@ class Layout { DimensionVector minor_to_major_; // The tiles used in tiling-based layout. - absl::InlinedVector tiles_; + TileVector tiles_; // The primitive type to use for sparse array indices and pointers. Each of // these must either be INVALID, or an unsigned integer type. From 45829c3e1fcdcf71c2bd59d44409d388d9d9ab23 Mon Sep 17 00:00:00 2001 From: Hyeontaek Lim Date: Sat, 15 Jul 2023 20:27:27 -0700 Subject: [PATCH 353/376] [IFRT] Make Sharding deserialization to use a function to lookup devices `Sharding` deserialization now only requires a function that looks up `ifrt::Device`s instead of a pointer to the full `Client` instance, as the deserialization only needs to get `Device`s rather than knowing a `Client`. PiperOrigin-RevId: 548425224 --- tensorflow/compiler/xla/python/ifrt/BUILD | 4 +-- tensorflow/compiler/xla/python/ifrt/device.cc | 4 +-- tensorflow/compiler/xla/python/ifrt/device.h | 11 +++++--- .../xla/python/ifrt/sharding_serdes.cc | 24 ++++++++++-------- .../xla/python/ifrt/sharding_serdes.h | 12 ++++++--- .../xla/python/ifrt/sharding_serdes_test.cc | 25 ++++++++++--------- .../compiler/xla/python/pjrt_ifrt/BUILD | 2 +- .../python/pjrt_ifrt/xla_sharding_serdes.cc | 7 +++--- .../pjrt_ifrt/xla_sharding_serdes_test.cc | 7 +++--- 9 files changed, 55 insertions(+), 41 deletions(-) diff --git a/tensorflow/compiler/xla/python/ifrt/BUILD b/tensorflow/compiler/xla/python/ifrt/BUILD index 9e6d600f4105cf..d2f06b3e175df3 100644 --- a/tensorflow/compiler/xla/python/ifrt/BUILD +++ b/tensorflow/compiler/xla/python/ifrt/BUILD @@ -66,7 +66,6 @@ cc_library( ], deps = [ ":serdes", - ":sharding_proto_cc", ":types_proto_cc", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:statusor", @@ -74,10 +73,10 @@ cc_library( "//tensorflow/compiler/xla/pjrt:pjrt_client", "//tensorflow/compiler/xla/python/ifrt/ir", "//tensorflow/tsl/platform:logging", - "//tensorflow/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -350,6 +349,7 @@ xla_cc_test( ":serdes", ":sharding_serdes", ":sharding_test_util", + "@com_google_absl//absl/functional:bind_front", "@com_google_googletest//:gtest_main", ], ) diff --git a/tensorflow/compiler/xla/python/ifrt/device.cc b/tensorflow/compiler/xla/python/ifrt/device.cc index a549a811de6de6..629afe4d515854 100644 --- a/tensorflow/compiler/xla/python/ifrt/device.cc +++ b/tensorflow/compiler/xla/python/ifrt/device.cc @@ -24,12 +24,12 @@ limitations under the License. namespace xla { namespace ifrt { -StatusOr DeviceList::FromProto(Client* client, +StatusOr DeviceList::FromProto(LookupDeviceFunc lookup_device, const DeviceListProto& proto) { DeviceList::Devices devices; devices.reserve(proto.device_ids_size()); for (int device_id : proto.device_ids()) { - TF_ASSIGN_OR_RETURN(Device * device, client->LookupDevice(device_id)); + TF_ASSIGN_OR_RETURN(Device * device, lookup_device(device_id)); devices.push_back(device); } return DeviceList(std::move(devices)); diff --git a/tensorflow/compiler/xla/python/ifrt/device.h b/tensorflow/compiler/xla/python/ifrt/device.h index d54afa190deaa9..89f123f5e93869 100644 --- a/tensorflow/compiler/xla/python/ifrt/device.h +++ b/tensorflow/compiler/xla/python/ifrt/device.h @@ -20,6 +20,7 @@ limitations under the License. #include #include "absl/container/inlined_vector.h" +#include "absl/functional/function_ref.h" #include "tensorflow/compiler/xla/pjrt/pjrt_client.h" #include "tensorflow/compiler/xla/python/ifrt/types.pb.h" @@ -43,11 +44,15 @@ class DeviceList { // better performance. using Devices = absl::InlinedVector; + // Function that matches the semantics of `Client::LookupDevice()`. + using LookupDeviceFunc = absl::FunctionRef(int)>; + explicit DeviceList(Devices devices) : devices_(std::move(devices)) {} - // Constructs `DeviceList` from `DeviceListProto`. Device ids in the proto - // must be consistent with the devices owned by `client'. - static StatusOr FromProto(Client* client, + // Constructs `DeviceList` from `DeviceListProto`. Devices are looked up using + // `lookup_device`. Device ids in the proto must be consistent with the + // devices returned by `lookup_device`. + static StatusOr FromProto(LookupDeviceFunc lookup_device, const DeviceListProto& proto); // Returns a `DeviceListProto` representation. diff --git a/tensorflow/compiler/xla/python/ifrt/sharding_serdes.cc b/tensorflow/compiler/xla/python/ifrt/sharding_serdes.cc index d9ade8d6a62b96..5d4881499a09c9 100644 --- a/tensorflow/compiler/xla/python/ifrt/sharding_serdes.cc +++ b/tensorflow/compiler/xla/python/ifrt/sharding_serdes.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/python/ifrt/client.h" #include "tensorflow/compiler/xla/python/ifrt/device.h" #include "tensorflow/compiler/xla/python/ifrt/serdes.h" #include "tensorflow/compiler/xla/python/ifrt/shape.h" @@ -63,7 +62,7 @@ class SingleDeviceShardingSerDes } TF_ASSIGN_OR_RETURN( Device * device, - deserialize_sharding_options->client->LookupDevice(proto.device_id())); + deserialize_sharding_options->lookup_device(proto.device_id())); return SingleDeviceSharding::Create(device); } @@ -96,9 +95,10 @@ class OpaqueShardingSerDes return absl::InvalidArgumentError( "Failed to parse serialized OpaqueSharding"); } - TF_ASSIGN_OR_RETURN(auto devices, DeviceList::FromProto( - deserialize_sharding_options->client, - proto.devices())); + TF_ASSIGN_OR_RETURN( + auto devices, + DeviceList::FromProto(deserialize_sharding_options->lookup_device, + proto.devices())); return OpaqueSharding::Create(std::move(devices)); } @@ -136,9 +136,10 @@ class ConcreteShardingSerDes return absl::InvalidArgumentError( "Failed to parse serialized ConcreteSharding"); } - TF_ASSIGN_OR_RETURN(auto devices, DeviceList::FromProto( - deserialize_sharding_options->client, - proto.devices())); + TF_ASSIGN_OR_RETURN( + auto devices, + DeviceList::FromProto(deserialize_sharding_options->lookup_device, + proto.devices())); TF_ASSIGN_OR_RETURN(auto shape, Shape::FromProto(proto.shape())); std::vector shard_shapes; shard_shapes.reserve(proto.shard_shapes_size()); @@ -183,9 +184,10 @@ class ConcreteEvenShardingSerDes return absl::InvalidArgumentError( "Failed to parse serialized ConcreteEvenSharding"); } - TF_ASSIGN_OR_RETURN(auto devices, DeviceList::FromProto( - deserialize_sharding_options->client, - proto.devices())); + TF_ASSIGN_OR_RETURN( + auto devices, + DeviceList::FromProto(deserialize_sharding_options->lookup_device, + proto.devices())); TF_ASSIGN_OR_RETURN(auto shape, Shape::FromProto(proto.shape())); TF_ASSIGN_OR_RETURN(auto shard_shape, Shape::FromProto(proto.shard_shape())); diff --git a/tensorflow/compiler/xla/python/ifrt/sharding_serdes.h b/tensorflow/compiler/xla/python/ifrt/sharding_serdes.h index 965670bcbc3401..7ba47d87df9aac 100644 --- a/tensorflow/compiler/xla/python/ifrt/sharding_serdes.h +++ b/tensorflow/compiler/xla/python/ifrt/sharding_serdes.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "llvm/Support/ExtensibleRTTI.h" +#include "tensorflow/compiler/xla/python/ifrt/device.h" #include "tensorflow/compiler/xla/python/ifrt/serdes.h" #include "tensorflow/compiler/xla/statusor.h" @@ -27,15 +28,18 @@ namespace ifrt { class Client; -// Options for deserializing shardings. +// Options for deserializing shardings. Function referenced by `lookup_device` +// must remain valid during deserialization. struct DeserializeShardingOptions : llvm::RTTIExtends { - explicit DeserializeShardingOptions(Client* client) : client(client) {} + explicit DeserializeShardingOptions( + DeviceList::LookupDeviceFunc lookup_device) + : lookup_device(lookup_device) {} static char ID; // NOLINT - // The client whose devices will be used by deserialized shardings. - Client* client; + // Function that converts device ids to devices. + DeviceList::LookupDeviceFunc lookup_device; }; // Casts `DeserializeOptions` into `DeserializeShardingOptions`. diff --git a/tensorflow/compiler/xla/python/ifrt/sharding_serdes_test.cc b/tensorflow/compiler/xla/python/ifrt/sharding_serdes_test.cc index 3ea3aad6478fa6..5cbd445da77a75 100644 --- a/tensorflow/compiler/xla/python/ifrt/sharding_serdes_test.cc +++ b/tensorflow/compiler/xla/python/ifrt/sharding_serdes_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "absl/functional/bind_front.h" #include "tensorflow/compiler/xla/python/ifrt/serdes.h" #include "tensorflow/compiler/xla/python/ifrt/sharding.h" #include "tensorflow/compiler/xla/python/ifrt/sharding_test_util.h" @@ -39,11 +40,11 @@ TEST_P(ShardingSerDesTest, SingleDeviceShardingRoundTrip) { TF_ASSERT_OK_AND_ASSIGN(auto serialized, Serialize(*sharding)); - auto deserialized_options = - std::make_unique(client()); TF_ASSERT_OK_AND_ASSIGN( auto deserialized, - Deserialize(serialized, std::move(deserialized_options))); + Deserialize(serialized, + std::make_unique( + absl::bind_front(&Client::LookupDevice, client())))); const auto* out_sharding = llvm::dyn_cast(deserialized.get()); @@ -56,11 +57,11 @@ TEST_P(ShardingSerDesTest, OpaqueShardingRoundTrip) { TF_ASSERT_OK_AND_ASSIGN(auto serialized, Serialize(*sharding)); - auto deserialized_options = - std::make_unique(client()); TF_ASSERT_OK_AND_ASSIGN( auto deserialized, - Deserialize(serialized, std::move(deserialized_options))); + Deserialize(serialized, + std::make_unique( + absl::bind_front(&Client::LookupDevice, client())))); const auto* out_sharding = llvm::dyn_cast(deserialized.get()); ASSERT_NE(out_sharding, nullptr); @@ -75,11 +76,11 @@ TEST_P(ShardingSerDesTest, ConcreteShardingRoundTrip) { TF_ASSERT_OK_AND_ASSIGN(auto serialized, Serialize(*sharding)); - auto deserialized_options = - std::make_unique(client()); TF_ASSERT_OK_AND_ASSIGN( auto deserialized, - Deserialize(serialized, std::move(deserialized_options))); + Deserialize(serialized, + std::make_unique( + absl::bind_front(&Client::LookupDevice, client())))); const auto* out_sharding = llvm::dyn_cast(deserialized.get()); @@ -97,11 +98,11 @@ TEST_P(ShardingSerDesTest, ConcreteEvenShardingRoundTrip) { TF_ASSERT_OK_AND_ASSIGN(auto serialized, Serialize(*sharding)); - auto deserialized_options = - std::make_unique(client()); TF_ASSERT_OK_AND_ASSIGN( auto deserialized, - Deserialize(serialized, std::move(deserialized_options))); + Deserialize(serialized, + std::make_unique( + absl::bind_front(&Client::LookupDevice, client())))); const auto* out_sharding = llvm::dyn_cast(deserialized.get()); diff --git a/tensorflow/compiler/xla/python/pjrt_ifrt/BUILD b/tensorflow/compiler/xla/python/pjrt_ifrt/BUILD index d9f0338ae29a35..a7023e0b73459d 100644 --- a/tensorflow/compiler/xla/python/pjrt_ifrt/BUILD +++ b/tensorflow/compiler/xla/python/pjrt_ifrt/BUILD @@ -126,7 +126,6 @@ cc_library( ":xla_ifrt", ":xla_sharding_proto_cc", "//tensorflow/compiler/xla/hlo/ir:hlo", - "//tensorflow/compiler/xla/python/ifrt", "//tensorflow/compiler/xla/python/ifrt:serdes", "//tensorflow/compiler/xla/python/ifrt:sharding_serdes", ], @@ -142,6 +141,7 @@ xla_cc_test( "//tensorflow/compiler/xla/hlo/ir:hlo", "//tensorflow/compiler/xla/python/ifrt:sharding_serdes", "//tensorflow/compiler/xla/python/ifrt:sharding_test_util", + "@com_google_absl//absl/functional:bind_front", "@com_google_googletest//:gtest_main", ], ) diff --git a/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_serdes.cc b/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_serdes.cc index c3d8d2470600b9..daff5e2149ff6c 100644 --- a/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_serdes.cc +++ b/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_serdes.cc @@ -54,9 +54,10 @@ class HloShardingSerDes : public llvm::RTTIExtends { return absl::InvalidArgumentError( "Failed to parse serialized HloSharding"); } - TF_ASSIGN_OR_RETURN(auto devices, DeviceList::FromProto( - deserialize_sharding_options->client, - proto.devices())); + TF_ASSIGN_OR_RETURN( + auto devices, + DeviceList::FromProto(deserialize_sharding_options->lookup_device, + proto.devices())); TF_ASSIGN_OR_RETURN(auto xla_hlo_sharding, xla::HloSharding::FromProto(proto.xla_op_sharding())); return HloSharding::Create(std::move(devices), std::move(xla_hlo_sharding)); diff --git a/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_serdes_test.cc b/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_serdes_test.cc index 95b90c21fbba9a..a98a0271a03a41 100644 --- a/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_serdes_test.cc +++ b/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_serdes_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/functional/bind_front.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_sharding.h" #include "tensorflow/compiler/xla/python/ifrt/sharding_serdes.h" #include "tensorflow/compiler/xla/python/ifrt/sharding_test_util.h" @@ -39,11 +40,11 @@ TEST_P(XlaShardingSerDesTest, HloShardingRoundTrip) { TF_ASSERT_OK_AND_ASSIGN(auto serialized, Serialize(*sharding)); - auto deserialized_options = - std::make_unique(client()); TF_ASSERT_OK_AND_ASSIGN( auto deserialized, - Deserialize(serialized, std::move(deserialized_options))); + Deserialize(serialized, + std::make_unique( + absl::bind_front(&Client::LookupDevice, client())))); const auto* out_sharding = llvm::dyn_cast(deserialized.get()); ASSERT_NE(out_sharding, nullptr); From d554f77eacb88cff9ca802503294e4deec2f477b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 16 Jul 2023 02:02:07 -0700 Subject: [PATCH 354/376] compat: Update forward compatibility horizon to 2023-07-16 PiperOrigin-RevId: 548462349 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 70b6e1971d498e..88b73342c3202f 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 7, 15) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 7, 16) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From 61302d8dd536910b621fdb56c08f76cf59750595 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 16 Jul 2023 02:02:11 -0700 Subject: [PATCH 355/376] Update GraphDef version to 1559. PiperOrigin-RevId: 548462357 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index d319fe9e9a52cc..963b5d9169e7c0 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1558 // Updated: 2023/7/15 +#define TF_GRAPH_DEF_VERSION 1559 // Updated: 2023/7/16 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From af20981148cc96fced153f7a4863ce7860121f38 Mon Sep 17 00:00:00 2001 From: Deqiang Chen Date: Sun, 16 Jul 2023 06:19:08 -0700 Subject: [PATCH 356/376] Add runtime config when creating KernelFallbackCompatRequest PiperOrigin-RevId: 548487262 --- .../kernel/kernel_fallback_compat_request_state.cc | 5 ++++- .../kernel/kernel_fallback_compat_request_state.h | 3 ++- .../runtime/runtime_fallback_batch_tf_opkernels.cc | 3 ++- tensorflow/core/runtime_fallback/util/fallback_test_util.cc | 3 ++- tensorflow/core/tfrt/mlrt/kernel/batch_kernel.cc | 2 ++ 5 files changed, 12 insertions(+), 4 deletions(-) diff --git a/tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.cc b/tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.cc index f9a404d560e7b3..16ef9cd3eefa91 100644 --- a/tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.cc +++ b/tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/platform/threadpool_interface.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/tfrt/graph_executor/config.h" #include "tensorflow/core/tfrt/utils/fallback_tensor.h" #include "tfrt/host_context/resource_context.h" // from @tf_runtime #include "tfrt/support/pointer_util.h" // from @tf_runtime @@ -158,7 +159,8 @@ Status SetUpKernelFallbackCompatRequestContext( std::function)>* runner, tfrt_stub::CostRecorder* cost_recorder, tfrt::ResourceContext* client_graph_resource_context, - tensorflow::CancellationManager* cancellation_manager) { + tensorflow::CancellationManager* cancellation_manager, + const tensorflow::tfrt_stub::RuntimeConfig* runtime_config) { DCHECK(builder); DCHECK(device_manager); DCHECK(pflr); @@ -175,6 +177,7 @@ Status SetUpKernelFallbackCompatRequestContext( fallback_request_state.set_client_graph_resource_context( client_graph_resource_context); fallback_request_state.set_cancellation_manager(cancellation_manager); + fallback_request_state.set_runtime_config(runtime_config); return OkStatus(); } diff --git a/tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h b/tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h index a37c772b978c43..201eae2e1c6f5d 100644 --- a/tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h +++ b/tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h @@ -240,7 +240,8 @@ Status SetUpKernelFallbackCompatRequestContext( std::function)>* runner, tfrt_stub::CostRecorder* cost_recorder, tfrt::ResourceContext* client_graph_resource_context, - tensorflow::CancellationManager* cancellation_manager); + tensorflow::CancellationManager* cancellation_manager, + const tensorflow::tfrt_stub::RuntimeConfig* runtime_config); } // namespace tfd } // namespace tensorflow diff --git a/tensorflow/core/runtime_fallback/runtime/runtime_fallback_batch_tf_opkernels.cc b/tensorflow/core/runtime_fallback/runtime/runtime_fallback_batch_tf_opkernels.cc index 1facf07d621831..2309310a4d3a26 100644 --- a/tensorflow/core/runtime_fallback/runtime/runtime_fallback_batch_tf_opkernels.cc +++ b/tensorflow/core/runtime_fallback/runtime/runtime_fallback_batch_tf_opkernels.cc @@ -271,7 +271,8 @@ Status SetUpKernelFallbackCompatRequestContextForBatch( src_fallback_request_state->runner(), src_fallback_request_state->cost_recorder(), src_fallback_request_state->client_graph_resource_context(), - src_fallback_request_state->cancellation_manager()); + src_fallback_request_state->cancellation_manager(), + src_fallback_request_state->runtime_config()); } StatusOr> SetUpRequestContext( diff --git a/tensorflow/core/runtime_fallback/util/fallback_test_util.cc b/tensorflow/core/runtime_fallback/util/fallback_test_util.cc index af9bba7079fd3f..43d617cca46664 100644 --- a/tensorflow/core/runtime_fallback/util/fallback_test_util.cc +++ b/tensorflow/core/runtime_fallback/util/fallback_test_util.cc @@ -72,7 +72,8 @@ tfrt::ExecutionContext CreateFallbackTestExecutionContext( user_intra_op_threadpool, /*model_metadata=*/std::nullopt, /*runner=*/nullptr, /*cost_recorder=*/nullptr, /*client_graph_resource_context=*/resource_context, - /*cancellation_manager=*/nullptr); + /*cancellation_manager=*/nullptr, + /*runtime_config=*/nullptr); TF_DCHECK_OK(status); TF_DCHECK_OK(status); diff --git a/tensorflow/core/tfrt/mlrt/kernel/batch_kernel.cc b/tensorflow/core/tfrt/mlrt/kernel/batch_kernel.cc index 796a484dfa4665..20fd6d69a3ef2f 100644 --- a/tensorflow/core/tfrt/mlrt/kernel/batch_kernel.cc +++ b/tensorflow/core/tfrt/mlrt/kernel/batch_kernel.cc @@ -324,6 +324,8 @@ void MlrtBatchResource::ProcessFuncBatchImpl( fallback_request_state.set_cancellation_manager( caller_fallback_request_state.cancellation_manager()); + fallback_request_state.set_runtime_config( + caller_fallback_request_state.runtime_config()); tensorflow::profiler::TraceMeProducer activity( // To TraceMeConsumers in WorkQueue. From 3e69e9824fd02c21e58fe318738bdc8f2807c2f2 Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Sun, 16 Jul 2023 12:33:20 -0700 Subject: [PATCH 357/376] [XLA:GPU] Re-enable fusion of broadcasts of scalar constants in Triton GEMM. An issue with constants in the Triton compiler was resolved. PiperOrigin-RevId: 548520833 --- .../xla/service/gpu/gemm_rewriter_triton.cc | 8 +++---- .../service/gpu/gemm_rewriter_triton_test.cc | 6 ++--- .../xla/service/gpu/ir_emitter_triton_test.cc | 23 +++++++++---------- 3 files changed, 17 insertions(+), 20 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.cc b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.cc index d1acb13ce8d759..6fc4470cb90439 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.cc @@ -629,11 +629,9 @@ FusionDecision CanFuse(const HloInstruction& hlo, bool as_input, if (!IsSupportedDataType(hlo.shape().element_type(), gpu_version)) { return "Unsupported output data type."; } - if (hlo.IsConstant()) { - return "Not fusing a constant."; - } - if (hlo.opcode() == HloOpcode::kBroadcast) { - return "Not fusing a broadcast."; + if (hlo.opcode() == HloOpcode::kBroadcast && + !hlo_query::IsScalarConstant(hlo.operand(0))) { + return "Skipping unsupported broadcast."; } if (as_input) { if (hlo.GetModule() diff --git a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton_test.cc b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton_test.cc index 151199150b3e1f..d0a57b84d396cc 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton_test.cc @@ -94,7 +94,7 @@ ENTRY e { GmockMatch(m::Fusion(m::Parameter(), m::Parameter()))); } -TEST_F(GemmRewriterTritonTest, DoNotFuseConstants) { +TEST_F(GemmRewriterTritonTest, DoNotFuseVectorConstants) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(R"( HloModule m @@ -102,8 +102,8 @@ HloModule m ENTRY e { p0 = s8[60,5] parameter(0) c0 = f16[60,5] convert(p0) - cst1 = f16[] constant(1234) - r1 = f16[5,120] broadcast(cst1) + cst1 = f16[5] constant({...}) + r1 = f16[5,120] broadcast(cst1), dimensions={0} ROOT d = f16[60,120] dot(c0, r1), lhs_contracting_dims={1}, rhs_contracting_dims={0} })")); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_triton_test.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_triton_test.cc index 9f97c1c0bc5597..51183398f4bef9 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_triton_test.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_triton_test.cc @@ -648,7 +648,9 @@ ENTRY entry { .status()); } -TEST_F(TritonGemmTest, TritonCompilerCanFailOnConstants) { +// Triton compiler used to have an issue with reordering constants: +// https://github.com/openai/triton/issues/1864 +TEST_F(TritonGemmTest, TritonCompilerDoesNotFailOnConstants) { TF_CHECK_OK(GetOptimizedModule(R"( HloModule m, is_scheduled=true @@ -850,7 +852,7 @@ ENTRY e { EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2})); } -TEST_F(TritonGemmLevel2Test, BroadcastOfConstantIsNotFused) { +TEST_F(TritonGemmLevel2Test, BroadcastOfScalarConstantIsFused) { const std::string kHloText = R"( HloModule m @@ -859,19 +861,16 @@ ENTRY e { p0c = f32[70,30] convert(p0) constant_3663 = f32[] constant(4321) bc0 = f32[30,5] broadcast(constant_3663) - p1 = f32[30,5] parameter(1) - a = f32[30,5] add(p1, bc0) - ROOT d = f32[70,5] dot(p0c, a), + ROOT d = f32[70,5] dot(p0c, bc0), lhs_contracting_dims={1}, rhs_contracting_dims={0} })"; - MatchOptimizedHlo(kHloText, R"( -; CHECK: ENTRY -; CHECK: constant -; CHECK: broadcast -; CHECK: fusion -; CHECK-SAME: kind=kCustom -)"); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetOptimizedModule(kHloText)); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch(m::Fusion(m::Parameter()) + .WithFusionKind(HloInstruction::FusionKind::kCustom))); EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/2e-3, /*arel=*/2e-3})); } From 0dc240282c5cacd1e3e958a4b2667e18fa0feebf Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Sun, 16 Jul 2023 15:55:21 -0700 Subject: [PATCH 358/376] [XLA:GPU] Implement dimension analysis of output fusions in Triton GEMMs. This will not affect the existing fusions yet: the emitter is currently using fixed simplified logic to handle output fusions. PiperOrigin-RevId: 548538621 --- tensorflow/compiler/xla/service/gpu/BUILD | 1 + .../xla/service/gpu/gemm_rewriter_triton.cc | 116 +++++++++++------- .../service/gpu/gemm_rewriter_triton_test.cc | 99 +++++++++++++++ 3 files changed, 173 insertions(+), 43 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 7d491f8063173b..c3f4bf4ddaeef4 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -1208,6 +1208,7 @@ xla_cc_test( ":gemm_rewriter_triton", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/hlo/ir:hlo", + "//tensorflow/compiler/xla/service:pattern_matcher", "//tensorflow/compiler/xla/service:pattern_matcher_gmock", "//tensorflow/compiler/xla/stream_executor:device_description", "//tensorflow/compiler/xla/tests:hlo_test_base", diff --git a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.cc b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.cc index 6fc4470cb90439..97c54e52a91626 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.cc @@ -1186,6 +1186,65 @@ Status MakeDotComputationSplitKBatch( return OkStatus(); } +// Propagate dimension orders in consumer->producer direction starting at +// `origin` with input `origin_dim_order` till parameters of the computation. +// Store the found parameters and their iteration specs. +Status PropagateDimensionOrdersToParameters( + const HloInstruction& origin, const DimensionOrder& origin_dim_order, + absl::flat_hash_set& parameters, + absl::flat_hash_map& + iter_specs) { + absl::flat_hash_set visited; + std::queue to_process; + // Dimension orders describing inputs of corresponding instructions. + absl::flat_hash_map dim_orders; + TF_RET_CHECK(RequireTritonGemmSupportedDimOrder(origin_dim_order)); + dim_orders.insert({&origin, origin_dim_order}); + visited.insert(&origin); + to_process.push(&origin); + while (!to_process.empty()) { + const HloInstruction* hlo = to_process.front(); + to_process.pop(); + if (hlo->opcode() == HloOpcode::kParameter) { + // One parameter corresponds to one iteration spec in the results of the + // analysis. This describes well situations when a parameter has one or + // more elementwise users - they share the same tiling. Situations when + // one instruction is read differently by different users in the same + // scope of the dot are currently prevented during the fusion. + TF_RET_CHECK(parameters.insert(hlo).second); + VLOG(5) << hlo->ToString(); + } + for (const HloInstruction* operand : hlo->operands()) { + if (!visited.insert(operand).second) { + continue; + } + if (operand->opcode() == HloOpcode::kDot) { + // Encountering the dot itself happens during the processing of the + // output fusion. The propagation should stop at it. + continue; + } + // Operand's output is described by its consumer's input. + auto [it, inserted] = + dim_orders.insert({operand, DimensionOrder(dim_orders.at(hlo))}); + TF_RET_CHECK(inserted); + DimensionOrder& hlo_operand_dim_order = it->second; + TF_RET_CHECK(hlo_operand_dim_order.HandleInstruction( + operand, DimensionOrder::TransformDirection::kOutputToInput)) + << operand->ToString(); + TF_RET_CHECK(RequireTritonGemmSupportedDimOrder(hlo_operand_dim_order)); + to_process.push(operand); + } + } + // For now all parameters of one scope have to use the same tiling. + for (const HloInstruction* parameter : parameters) { + TF_RET_CHECK(dim_orders.at(parameter).IsPhysicallyEquivalent( + dim_orders.at(*parameters.cbegin()))); + iter_specs[parameter] = + DimensionOrderToTensorIterationSpec(dim_orders.at(parameter)); + } + return OkStatus(); +} + } // anonymous namespace // BF16 is supported in a sense that all operations on it are implemented @@ -1312,52 +1371,14 @@ DotFusionAnalysis::DotFusionAnalysis(const HloComputation* dot_computation, for (const Scope scope : {Scope::LHS, Scope::RHS}) { const int operand_number = static_cast(scope); - const HloInstruction* dot_operand = dot->operand(operand_number); - absl::flat_hash_set visited; - std::queue to_process; - // Dimension orders describing inputs of corresponding instructions. - absl::flat_hash_map dim_orders; + const HloInstruction* operand = dot->operand(operand_number); DimensionOrder dot_operand_dim_order = DimensionOrder::FromDotOperand(*dot, operand_number, split_k); CHECK(dot_operand_dim_order.HandleInstruction( - dot_operand, DimensionOrder::TransformDirection::kOutputToInput)); - CHECK(RequireTritonGemmSupportedDimOrder(dot_operand_dim_order)) - << dot_computation->ToString(); - dim_orders.insert({dot_operand, dot_operand_dim_order}); - visited.insert(dot_operand); - to_process.push(dot_operand); - while (!to_process.empty()) { - const HloInstruction* hlo = to_process.front(); - to_process.pop(); - if (hlo->opcode() == HloOpcode::kParameter) { - CHECK(parameters_[scope].insert(hlo).second); - VLOG(5) << hlo->ToString(); - } - for (const HloInstruction* hlo_operand : hlo->operands()) { - if (!visited.insert(hlo_operand).second) { - continue; - } - // Operand's output is described by its consumer's input. - auto [it, inserted] = dim_orders.insert( - {hlo_operand, DimensionOrder(dim_orders.at(hlo))}); - CHECK(inserted); - DimensionOrder& hlo_operand_dim_order = it->second; - CHECK(hlo_operand_dim_order.HandleInstruction( - hlo_operand, DimensionOrder::TransformDirection::kOutputToInput)); - CHECK(RequireTritonGemmSupportedDimOrder(hlo_operand_dim_order)) - << " " << dot_computation->ToString(); - to_process.push(hlo_operand); - } - } - - // For now all parameters of one scope have to use the same tiling. - for (const HloInstruction* parameter : parameters_[scope]) { - CHECK(dim_orders.at(parameter).IsPhysicallyEquivalent( - dim_orders.at(*parameters_[scope].cbegin()))) - << dot_computation->ToString(); - iter_specs_[scope][parameter] = - DimensionOrderToTensorIterationSpec(dim_orders.at(parameter)); - } + operand, DimensionOrder::TransformDirection::kOutputToInput)); + CHECK_OK(PropagateDimensionOrdersToParameters( + *operand, dot_operand_dim_order, parameters_[scope], + iter_specs_[scope])); } int64_t lhs_nc_split_major_part_size = -1; @@ -1373,6 +1394,7 @@ DotFusionAnalysis::DotFusionAnalysis(const HloComputation* dot_computation, *dot, split_k, lhs_nc_split_major_part_size); const HloInstruction* output = dot; // Currently supported is one fusion output and one path from dot to it. + // Propagate dimension order from dot to root. while (!output->IsRoot()) { CHECK_EQ(output->user_count(), 1); output = output->users()[0]; @@ -1383,6 +1405,14 @@ DotFusionAnalysis::DotFusionAnalysis(const HloComputation* dot_computation, CHECK(iter_specs_[Scope::OUTPUT] .insert({output, DimensionOrderToTensorIterationSpec(dim_order)}) .second); + if (output != dot) { + // Propagate back to parameters of the output fusion. + CHECK(dim_order.HandleInstruction( + output, DimensionOrder::TransformDirection::kOutputToInput)); + CHECK_OK(PropagateDimensionOrdersToParameters(*output, dim_order, + parameters_[Scope::OUTPUT], + iter_specs_[Scope::OUTPUT])); + } } const DimIterationSpec* DotFusionAnalysis::IterSpec( diff --git a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton_test.cc b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton_test.cc index d0a57b84d396cc..f256ef6b0c3a15 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_opcode.h" #include "tensorflow/compiler/xla/service/gpu/cublas_padding_requirements.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" #include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/stream_executor/device_description.h" @@ -452,6 +453,52 @@ ENTRY e { /*subfragments=*/ElementsAre(3)))); } +TEST_F(TritonDotAnalysisTest, OutputParameterIsHandled) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +HloModule t + +triton_dot { + p0 = bf16[24,4]{1,0} parameter(0) + p1 = bf16[4,3]{1,0} parameter(1) + dot = bf16[24,3]{1,0} dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + p2 = f16[3,24]{1,0} parameter(2) + p2t = f16[24,3]{1,0} transpose(p2), dimensions={1,0} + p2tc = bf16[24,3]{1,0} convert(p2t) + ROOT r = bf16[24,3]{1,0} divide(p2tc, dot) +} + +ENTRY e { + p0 = bf16[24,4]{1,0} parameter(0) + p1 = bf16[4,3]{1,0} parameter(1) + p2 = f16[3,24]{1,0} parameter(2) + ROOT r = bf16[24,3]{1,0} fusion(p0, p1, p2), kind=kCustom, + calls=triton_dot +})")); + const HloComputation* dot_computation = + module->entry_computation()->root_instruction()->called_computations()[0]; + const HloInstruction* output_param = + dot_computation->parameter_instruction(2); + const DotFusionAnalysis analysis(dot_computation); + EXPECT_EQ( + analysis.IterSpec(DotFusionAnalysis::Scope::OUTPUT, output_param, 0) + ->size(), + 1); + EXPECT_THAT( + *analysis.IterSpec(DotFusionAnalysis::Scope::OUTPUT, output_param, 0), + ElementsAre(FieldsAre(/*stride=*/1, /*count=*/24, + /*subfragments=*/ElementsAre(24)))); + EXPECT_EQ( + analysis.IterSpec(DotFusionAnalysis::Scope::OUTPUT, output_param, 1) + ->size(), + 1); + EXPECT_THAT( + *analysis.IterSpec(DotFusionAnalysis::Scope::OUTPUT, output_param, 1), + ElementsAre(FieldsAre(/*stride=*/24, /*count=*/3, + /*subfragments=*/ElementsAre(3)))); +} + TEST_F(TritonDotAnalysisTest, InputBroadcastFromScalarIsHandled) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(R"( @@ -1142,6 +1189,58 @@ ENTRY e { DotFusionAnalysis::kMaxParameterPerScope * 2); } +TEST_F(GemmRewriterTritonLevel2Test, ParameterUsedElementwiseTwiceIsFused) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +HloModule t + +ENTRY e { + p0 = f32[1,35] parameter(0) + p0n = f32[1,35] negate(p0) + p0e = f32[1,35] exponential(p0) + a = f32[1,35] add(p0e, p0n) + p1 = f16[35,1] parameter(1) + p1c = f32[35,1] convert(p1) + ROOT dot = f32[1,1] dot(a, p1c), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +})")); + EXPECT_TRUE(GemmRewriterTriton(se::CudaComputeCapability{ + se::CudaComputeCapability::VOLTA, 0}) + .Run(module.get()) + .value()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch((m::Fusion(m::Parameter(), m::Parameter())))); + const DotFusionAnalysis analysis(module->entry_computation() + ->root_instruction() + ->called_computations()[0]); + EXPECT_EQ(analysis.ScopeParameters(DotFusionAnalysis::Scope::LHS).size(), 1); + EXPECT_EQ(analysis.ScopeParameters(DotFusionAnalysis::Scope::RHS).size(), 1); +} + +TEST_F(GemmRewriterTritonLevel2Test, + ParameterUsedNonElementwiseTwiceIsFusedOnlyOnOnePath) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +HloModule t + +ENTRY e { + p0 = f32[4,4] parameter(0) + p0t = f32[4,4] transpose(p0), dimensions={1,0} + a = f32[4,4] add(p0, p0t) + p1 = f16[4,5] parameter(1) + p1c = f32[4,5] convert(p1) + ROOT dot = f32[4,5] dot(a, p1c), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +})")); + EXPECT_TRUE(GemmRewriterTriton(se::CudaComputeCapability{ + se::CudaComputeCapability::VOLTA, 0}) + .Run(module.get()) + .value()); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch((m::Fusion(m::Parameter(), m::Transpose(), m::Parameter())))); +} + } // namespace } // namespace gpu } // namespace xla From 52dc79628e5c3806fde55f1865ac684872242a9f Mon Sep 17 00:00:00 2001 From: Rahul Joshi Date: Sun, 16 Jul 2023 17:02:35 -0700 Subject: [PATCH 359/376] [XLA] Add option to pattern match true only scalars in offset computation. - This option will force reduce-scatter pattern matching to only recognize true scalar computation for DUS offset computation. PiperOrigin-RevId: 548544857 --- .../xla/service/reduce_scatter_utils.cc | 29 ++++++++++++------- .../xla/service/reduce_scatter_utils.h | 3 +- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/tensorflow/compiler/xla/service/reduce_scatter_utils.cc b/tensorflow/compiler/xla/service/reduce_scatter_utils.cc index c9a55b6d1af4ec..d6508c93c875a4 100644 --- a/tensorflow/compiler/xla/service/reduce_scatter_utils.cc +++ b/tensorflow/compiler/xla/service/reduce_scatter_utils.cc @@ -141,7 +141,8 @@ bool IsPerIdOffsets(absl::Span offsets, // Returns if `offset` == shard_size * id. bool IsPerIdOffset(const HloInstruction* offset, int64_t shard_size, const MapIdToTableOffset& map_id, int64_t group_size, - const HloAllReduceInstruction* ar) { + const HloAllReduceInstruction* ar, + bool true_scalar_for_offset_computation) { const bool iota_group = ar->replica_groups().empty() || (ar->IsCrossModuleAllReduce() && !ar->use_global_device_ids()); @@ -152,6 +153,10 @@ bool IsPerIdOffset(const HloInstruction* offset, int64_t shard_size, VLOG(2) << "Offset is not a scalar " << offset->ToString(); return false; } + if (true_scalar_for_offset_computation && offset->shape().rank() != 0) { + VLOG(2) << "Offset is not a true scalar " << offset->ToString(); + return false; + } int64_t const_operand = -1; if (offset->operand(0)->IsConstant()) { const_operand = 0; @@ -168,7 +173,8 @@ bool IsPerIdOffset(const HloInstruction* offset, int64_t shard_size, return false; } return IsPerIdOffset(offset->operand(1 - const_operand), - shard_size / *multiplier, map_id, group_size, ar); + shard_size / *multiplier, map_id, group_size, ar, + true_scalar_for_offset_computation); } if (shard_size == 1 && iota_group) { bool id_mapping_is_identity = true; @@ -186,16 +192,16 @@ bool IsPerIdOffset(const HloInstruction* offset, int64_t shard_size, if (offset->opcode() == HloOpcode::kBitcast || offset->opcode() == HloOpcode::kReshape || offset->opcode() == HloOpcode::kCopy) { - return IsPerIdOffset(offset->operand(0), shard_size, map_id, group_size, - ar); + return IsPerIdOffset(offset->operand(0), shard_size, map_id, group_size, ar, + true_scalar_for_offset_computation); } if (offset->opcode() == HloOpcode::kConvert && offset->operand(0)->shape().IsInteger() && primitive_util::BitWidth(offset->operand(0)->shape().element_type()) <= primitive_util::BitWidth(offset->shape().element_type())) { - return IsPerIdOffset(offset->operand(0), shard_size, map_id, group_size, - ar); + return IsPerIdOffset(offset->operand(0), shard_size, map_id, group_size, ar, + true_scalar_for_offset_computation); } if (offset->opcode() == HloOpcode::kClamp) { @@ -207,8 +213,8 @@ bool IsPerIdOffset(const HloInstruction* offset, int64_t shard_size, << offset->ToString(); return false; } - return IsPerIdOffset(offset->operand(1), shard_size, map_id, group_size, - ar); + return IsPerIdOffset(offset->operand(1), shard_size, map_id, group_size, ar, + true_scalar_for_offset_computation); } const int64_t num_groups = iota_group ? 1 : ar->replica_groups().size(); @@ -266,7 +272,8 @@ std::optional MatchReduceScatter( const HloAllReduceInstruction* ar, int64_t num_partitions, int64_t num_replicas, bool allow_multiple_split_dims, bool allow_intervening_reshape, int64_t min_rank, - HloPredicate match_partition_id, HloPredicate match_replica_id) { + HloPredicate match_partition_id, HloPredicate match_replica_id, + bool true_scalar_for_offset_computation) { if (!ar->shape().IsArray() || ar->constrain_layout() || (ar->IsCrossModuleAllReduce() && !ar->GetModule()->config().use_spmd_partitioning())) { @@ -470,8 +477,8 @@ std::optional MatchReduceScatter( } else { if (!IsPerIdOffset(user->operand(spec.split_dim + 1), user->dynamic_slice_sizes()[spec.split_dim], map_id, - group_size, ar)) { - VLOG(2) << "IsPerIdOffsets() failed " << ar->ToString(); + group_size, ar, true_scalar_for_offset_computation)) { + VLOG(2) << "IsPerIdOffset() failed " << ar->ToString(); return std::nullopt; } } diff --git a/tensorflow/compiler/xla/service/reduce_scatter_utils.h b/tensorflow/compiler/xla/service/reduce_scatter_utils.h index bbebe3e2ac132d..711ddd3456f019 100644 --- a/tensorflow/compiler/xla/service/reduce_scatter_utils.h +++ b/tensorflow/compiler/xla/service/reduce_scatter_utils.h @@ -38,7 +38,8 @@ std::optional MatchReduceScatter( int64_t num_replicas, bool allow_multiple_split_dims = false, bool allow_intervening_reshape = false, int64_t min_rank = 1, HloPredicate match_partition_id = HloPredicateIsOp, - HloPredicate match_replica_id = HloPredicateIsOp); + HloPredicate match_replica_id = HloPredicateIsOp, + bool true_scalar_for_offset_computation = false); } // namespace xla From 95d702022e3c4e7bd75724893b510328dc3a991a Mon Sep 17 00:00:00 2001 From: Son Tuan Vu Date: Sun, 16 Jul 2023 20:52:43 -0700 Subject: [PATCH 360/376] [XLA:GPU][NFC] Add FindHeroReduction util function PiperOrigin-RevId: 548569074 --- .../xla/service/gpu/hlo_fusion_analysis.cc | 28 ++++++++----------- .../xla/service/gpu/hlo_fusion_analysis.h | 2 +- .../xla/service/gpu/ir_emission_utils.cc | 10 +++++++ .../xla/service/gpu/ir_emission_utils.h | 6 ++++ 4 files changed, 29 insertions(+), 17 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/hlo_fusion_analysis.cc b/tensorflow/compiler/xla/service/gpu/hlo_fusion_analysis.cc index 89483a71a85e15..39619ad46c96ec 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_fusion_analysis.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_fusion_analysis.cc @@ -334,15 +334,11 @@ const ReductionCodegenInfo* HloFusionAnalysis::GetReductionCodegenInfo() { return &reduction_codegen_info_.value(); } - HloInstruction* first_reduce = - *absl::c_find_if(fusion_roots_, [](HloInstruction* instr) { - return IsReductionFromOrToContiguousDimensions(*instr); - }); + HloInstruction* hero_reduction = + FindHeroReduction(absl::Span(fusion_roots_)); + CHECK_NE(hero_reduction, nullptr); - // We always use the first reduce as representative to construct - // ReductionCodegenInfo, since all the reductions are required to have the - // same shape and layout as verified by `IsFusedReductionOutputConsistent()`. - auto reduction_codegen_info = ComputeReductionCodegenInfo(first_reduce); + auto reduction_codegen_info = ComputeReductionCodegenInfo(hero_reduction); reduction_codegen_info_.emplace(std::move(reduction_codegen_info)); return &reduction_codegen_info_.value(); } @@ -662,10 +658,10 @@ int HloFusionAnalysis::CalculateVirtualThreadScalingFactorForReduction( } ReductionCodegenInfo HloFusionAnalysis::ComputeReductionCodegenInfo( - HloInstruction* first_reduce) const { - Shape input_shape = first_reduce->operand(0)->shape(); + HloInstruction* hero_reduction) const { + Shape input_shape = hero_reduction->operand(0)->shape(); ReductionDimensions reduction_dimensions = - GetReductionKindAndContiguousComponents(*first_reduce); + GetReductionKindAndContiguousComponents(*hero_reduction); VLOG(10) << "is_row_reduction " << reduction_dimensions.is_row_reduction << " " << reduction_dimensions.dimensions[0] << " " << reduction_dimensions.dimensions[1] << " " @@ -683,9 +679,9 @@ ReductionCodegenInfo HloFusionAnalysis::ComputeReductionCodegenInfo( // Use 512 as default block size (threads per block) for row reductions. // For multi-output fusions, reduce the block size further to decrease // register pressure when multiple outputs are computed by each thread. - int64_t max_block_size = - std::max(MinThreadsXRowReduction(first_reduce->GetModule()->config()), - static_cast(512LL / NearestPowerOfTwo(fan_out))); + int64_t max_block_size = std::max( + MinThreadsXRowReduction(hero_reduction->GetModule()->config()), + static_cast(512LL / NearestPowerOfTwo(fan_out))); return std::min(max_block_size, RoundUpTo(CeilOfRatio(reduction_dimensions.dimensions[2], reduction_tiling[2]), @@ -702,7 +698,7 @@ ReductionCodegenInfo HloFusionAnalysis::ComputeReductionCodegenInfo( ProjectedShmemUsageBytes(reduction_dimensions, instr_index_groups); const int64_t shmem_budget = device_info_->shared_memory_per_block; bool reduction_is_race_free = ReductionIsRaceFree( - first_reduce->GetModule()->config(), reduction_dimensions); + hero_reduction->GetModule()->config(), reduction_dimensions); bool vectorize = // Vectorization might cause us to run out of budget. (shmem_usage * 2 <= shmem_budget) && @@ -764,7 +760,7 @@ ReductionCodegenInfo HloFusionAnalysis::ComputeReductionCodegenInfo( virtual_thread_scaling_factor); return ReductionCodegenInfo( tiling_scheme, num_partial_results, reduction_dimensions.is_row_reduction, - reduction_is_race_free, std::move(instr_index_groups), first_reduce); + reduction_is_race_free, std::move(instr_index_groups), hero_reduction); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/hlo_fusion_analysis.h b/tensorflow/compiler/xla/service/gpu/hlo_fusion_analysis.h index 225523e6901a02..035916f8282fc7 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_fusion_analysis.h +++ b/tensorflow/compiler/xla/service/gpu/hlo_fusion_analysis.h @@ -123,7 +123,7 @@ class HloFusionAnalysis { int CalculateVirtualThreadScalingFactorForReduction( const ReductionDimensions& reduction_dimensions) const; ReductionCodegenInfo ComputeReductionCodegenInfo( - HloInstruction* first_reduce) const; + HloInstruction* hero_reduction) const; bool HasConsistentTransposeHeros() const; const HloFusionInstruction* fusion_; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 7a8db3574b7a16..7f9f8e6593fd06 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -997,6 +997,16 @@ bool HasAnyUnnestedReductionRoot(HloComputation* computation) { }); } +HloInstruction* FindHeroReduction(absl::Span roots) { + auto it = absl::c_find_if(roots, [](HloInstruction* instr) { + return IsReductionFromOrToContiguousDimensions(*instr); + }); + if (it == roots.end()) { + return nullptr; + } + return *it; +} + void LogAndVerify(const llvm::Module* m) { if (VLOG_IS_ON(5)) { XLA_VLOG_LINES(5, llvm_ir::DumpToString(m)); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h index 4df85639667806..29fb463544d373 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h @@ -227,6 +227,12 @@ std::vector GetFusionRoots(HloComputation* computation); // reduction emitter. bool HasAnyUnnestedReductionRoot(HloComputation* computation); +// Returns the hero reduction of the computation. +// We always use the first reduce root that triggers unnested reduction emitter +// as the hero reduction, since all the reductions are required to have the same +// shape and layout as verified by `IsFusedReductionOutputConsistent()`. +HloInstruction* FindHeroReduction(absl::Span roots); + const HloInstruction& FindNonTrivialHero(const HloInstruction& instr); // Whether there is a fusion root triggering transposition emitter. From ccc45ff41dffc8074f0f738a9128941e6b5fcfff Mon Sep 17 00:00:00 2001 From: Dan Suh Date: Mon, 17 Jul 2023 00:28:49 -0700 Subject: [PATCH 361/376] Modify the signature of `QuantizeModel` to accept the buffer directly instead of `ModelT` Instead of accepting `ModelT` for quantization, it uses the raw model buffer as a parameter to `QuantizeModel` in order to support the case where the "offset buffer" exists. Offset buffers are enabled by `_experimental_use_buffer_offset`. This case happens usually when the model size is larger than 2GB for which the concept of offset buffer is introduced to avoid the flatbuffer's serialization limit of 2GB (see the previous commit that introduced this support for compiler import / export: https://github.com/tensorflow/tensorflow/commit/909b28382c31e3ada3872c11a767e64144a56db6) PiperOrigin-RevId: 548599861 --- .../mlir/lite/quantization/lite/BUILD | 3 +- .../lite/quantization/lite/quantize_model.cc | 28 +-- .../lite/quantization/lite/quantize_model.h | 32 ++- .../quantization/lite/quantize_model_test.cc | 222 +++++++++--------- .../lite/quantization/lite/tfl_quantizer.cc | 33 +-- tensorflow/lite/toco/python/BUILD | 1 + .../lite/toco/python/toco_python_api.cc | 25 +- 7 files changed, 161 insertions(+), 183 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD index 20346adde817c7..93c50fed86f77a 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD @@ -30,18 +30,17 @@ cc_library( "//tensorflow/compiler/mlir/lite:common", "//tensorflow/compiler/mlir/lite:flatbuffer_translate_lib", "//tensorflow/compiler/mlir/lite:tensorflow_lite", - "//tensorflow/compiler/mlir/lite:tensorflow_lite_quantize", "//tensorflow/compiler/mlir/lite:tf_tfl_passes", "//tensorflow/compiler/mlir/lite/quantization:quantization_config", "//tensorflow/compiler/mlir/tensorflow:error_util", "//tensorflow/core:protos_all_cc", "//tensorflow/lite:framework", + "//tensorflow/lite/c:c_api_types", "//tensorflow/lite/core/api", "//tensorflow/lite/schema:schema_fbs", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", - "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", ], diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc index 7581b5c78cfbcd..3d7503cd64128b 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc @@ -20,9 +20,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" -#include "llvm/ADT/SmallVector.h" #include "llvm/Support/Debug.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project @@ -38,6 +36,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/schema/schema_generated.h" namespace mlir { @@ -50,12 +49,11 @@ std::string TfLiteToMlir(const absl::string_view tflite_op_name) { // TODO(fengliuai): check the result for `fully_quantize` flag. TfLiteStatus QuantizeModel( - const tflite::ModelT& input_model, const tflite::TensorType& input_type, + const absl::string_view model_buffer, const tflite::TensorType& input_type, const tflite::TensorType& output_type, const tflite::TensorType& inference_type, const std::unordered_set& operator_names, - bool disable_per_channel, bool fully_quantize, - flatbuffers::FlatBufferBuilder* builder, + bool disable_per_channel, bool fully_quantize, std::string& output_buffer, tflite::ErrorReporter* error_reporter, bool verify_numeric, bool whole_model_verify, bool legacy_float_scale, const absl::flat_hash_set& denylisted_ops, @@ -73,18 +71,8 @@ TfLiteStatus QuantizeModel( StatusScopedDiagnosticHandler statusHandler(&context, /*propagate=*/true); - // Import input_model to a MLIR module - flatbuffers::FlatBufferBuilder input_builder; - flatbuffers::Offset input_model_location = - tflite::Model::Pack(input_builder, &input_model); - tflite::FinishModelBuffer(input_builder, input_model_location); - - std::string serialized_model( - reinterpret_cast(input_builder.GetBufferPointer()), - input_builder.GetSize()); - OwningOpRef module = tflite::FlatBufferToMlir( - serialized_model, &context, UnknownLoc::get(&context)); + model_buffer, &context, UnknownLoc::get(&context)); if (!module) { error_reporter->Report("Couldn't import flatbuffer to MLIR."); return kTfLiteError; @@ -130,20 +118,16 @@ TfLiteStatus QuantizeModel( return kTfLiteError; } - // Export the results to the builder - std::string result; + // Export the results. tflite::FlatbufferExportOptions options; options.toco_flags.set_force_select_tf_ops(false); options.toco_flags.set_enable_select_tf_ops(true); options.toco_flags.set_allow_custom_ops(true); if (!tflite::MlirToFlatBufferTranslateFunction(module.get(), options, - &result)) { + &output_buffer)) { error_reporter->Report("Failed to export MLIR to flatbuffer."); return kTfLiteError; } - builder->PushFlatBuffer(reinterpret_cast(result.data()), - result.size()); - return kTfLiteOk; } diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h index 243af219da689b..d85aba47811675 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h @@ -15,39 +15,47 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_QUANTIZE_MODEL_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_QUANTIZE_MODEL_H_ -#include #include #include #include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/core/api/error_reporter.h" -#include "tensorflow/lite/model.h" #include "tensorflow/lite/schema/schema_generated.h" namespace mlir { namespace lite { -// Quantize the `input_model` and write the result to a flatbuffer `builder`. -// The `input_type`, `output_type` and `inference_type` can be -// float32/qint8/int8/int16. -// Return partially quantized model if `fully_quantize` is false. +// Quantizes the input model represented as `model_buffer` and writes the result +// to the `output_buffer`. Both `model_buffer` and `output_buffer` should be a +// valid FlatBuffer format for Model supported by TFLite. +// +// The `input_type`, `output_type` and `inference_type` can be float32 / qint8 / +// int8 / int16. +// +// Returns a partially quantized model if `fully_quantize` is false. Returns a +// non-OK status if the quantization fails. +// // When `verify_numeric` is true, the model will have it's original float ops // and NumericVerify ops to compare output values from the quantized and float -// ops. When `legacy_float_scale` is true, the quantizer will use float scale -// instead of double, and call TOCO's quantization routines to maintain -// bit-exactness of the values with the TOCO quantizer. +// ops. +// +// When `legacy_float_scale` is true, the quantizer will use float scale instead +// of double, and call TOCO's quantization routines to maintain bit-exactness of +// the values with the TOCO quantizer. TfLiteStatus QuantizeModel( - const tflite::ModelT& input_model, const tflite::TensorType& input_type, + absl::string_view model_buffer, const tflite::TensorType& input_type, const tflite::TensorType& output_type, const tflite::TensorType& inference_type, const std::unordered_set& operator_names, - bool disable_per_channel, bool fully_quantize, - flatbuffers::FlatBufferBuilder* builder, + bool disable_per_channel, bool fully_quantize, std::string& output_buffer, tflite::ErrorReporter* error_reporter, bool verify_numeric = false, bool whole_model_verify = false, bool legacy_float_scale = true, const absl::flat_hash_set& denylisted_ops = {}, const absl::flat_hash_set& denylisted_nodes = {}, bool enable_variable_quantization = false); + } // namespace lite } // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc index 798e011dec247d..ee9c5e7852aea9 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc @@ -67,74 +67,77 @@ ModelT UnPackFlatBufferModel(const Model& flatbuffer_model) { } TfLiteStatus QuantizeModel( - flatbuffers::FlatBufferBuilder* builder, ModelT* model, - const TensorType& input_type, const TensorType& output_type, - bool allow_float, const std::unordered_set& operator_names, + ModelT* model, const TensorType& input_type, const TensorType& output_type, + const bool allow_float, const std::unordered_set& operator_names, const TensorType& activations_type, ErrorReporter* error_reporter, - bool disable_per_channel = false, + std::string& output_buffer, const bool disable_per_channel = false, const absl::flat_hash_set& blocked_ops = {}, const absl::flat_hash_set& blocked_nodes = {}) { TensorType inference_tensor_type = activations_type; - bool fully_quantize = !allow_float; + const bool fully_quantize = !allow_float; + + flatbuffers::FlatBufferBuilder input_builder; + tflite::FinishModelBuffer(input_builder, + tflite::Model::Pack(input_builder, model)); + + const std::string input_buffer( + reinterpret_cast(input_builder.GetBufferPointer()), + input_builder.GetSize()); auto status = mlir::lite::QuantizeModel( - *model, input_type, output_type, inference_tensor_type, - /*operator_names=*/{}, disable_per_channel, fully_quantize, builder, + input_buffer, input_type, output_type, inference_tensor_type, + /*operator_names=*/{}, disable_per_channel, fully_quantize, output_buffer, error_reporter, /*verify_numeric=*/false, /*whole_model_verify=*/false, /*legacy_float_scale=*/true, blocked_ops, blocked_nodes); if (status != kTfLiteOk) { return status; } - std::string buffer( - reinterpret_cast(builder->GetCurrentBufferPointer()), - builder->GetSize()); - auto flatbuffer_model = - FlatBufferModel::BuildFromBuffer(buffer.c_str(), buffer.size()); + auto flatbuffer_model = FlatBufferModel::BuildFromBuffer( + output_buffer.data(), output_buffer.size()); *model = UnPackFlatBufferModel(*flatbuffer_model->GetModel()); return kTfLiteOk; } -TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder, - ModelT* model, const TensorType& input_type, +TfLiteStatus QuantizeModel(ModelT* model, const TensorType& input_type, const TensorType& output_type, bool allow_float, - ErrorReporter* error_reporter) { - return QuantizeModel(builder, model, input_type, output_type, allow_float, - /*operator_names=*/{}, TensorType_INT8, error_reporter); + ErrorReporter* error_reporter, + std::string& output_buffer) { + return QuantizeModel(model, input_type, output_type, allow_float, + /*operator_names=*/{}, TensorType_INT8, error_reporter, + output_buffer); } -TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder, - ModelT* model, const TensorType& input_type, +TfLiteStatus QuantizeModel(ModelT* model, const TensorType& input_type, const TensorType& output_type, - ErrorReporter* error_reporter) { - return QuantizeModel(builder, model, input_type, output_type, - /*allow_float=*/false, error_reporter); + ErrorReporter* error_reporter, + std::string& output_buffer) { + return QuantizeModel(model, input_type, output_type, + /*allow_float=*/false, error_reporter, output_buffer); } -TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder, - ModelT* model, ErrorReporter* error_reporter) { - return QuantizeModel(builder, model, TensorType_FLOAT32, TensorType_FLOAT32, - /*allow_float=*/true, error_reporter); +TfLiteStatus QuantizeModel(ModelT* model, ErrorReporter* error_reporter, + std::string& output_buffer) { + return QuantizeModel(model, TensorType_FLOAT32, TensorType_FLOAT32, + /*allow_float=*/true, error_reporter, output_buffer); } TfLiteStatus QuantizeModelAllOperators( - flatbuffers::FlatBufferBuilder* builder, ModelT* model, - const TensorType& input_type, const TensorType& output_type, + ModelT* model, const TensorType& input_type, const TensorType& output_type, bool allow_float, const TensorType& activations_type, - bool disable_per_channel, ErrorReporter* error_reporter) { - return QuantizeModel(builder, model, input_type, output_type, allow_float, + bool disable_per_channel, ErrorReporter* error_reporter, + std::string& output_buffer) { + return QuantizeModel(model, input_type, output_type, allow_float, /*operator_names=*/{}, activations_type, error_reporter, - disable_per_channel); + output_buffer, disable_per_channel); } -TfLiteStatus QuantizeModelAllOperators(flatbuffers::FlatBufferBuilder* builder, - ModelT* model, - const TensorType& input_type, - const TensorType& output_type, - bool allow_float, - const TensorType& activations_type, - ErrorReporter* error_reporter) { - return QuantizeModel(builder, model, input_type, output_type, allow_float, - /*operator_names=*/{}, activations_type, error_reporter); +TfLiteStatus QuantizeModelAllOperators( + ModelT* model, const TensorType& input_type, const TensorType& output_type, + bool allow_float, const TensorType& activations_type, + ErrorReporter* error_reporter, std::string& output_buffer) { + return QuantizeModel(model, input_type, output_type, allow_float, + /*operator_names=*/{}, activations_type, error_reporter, + output_buffer); } std::unique_ptr ReadModel(const string& model_name) { @@ -180,8 +183,8 @@ class QuantizeModelTest : public testing::Test { std::unique_ptr input_model_; const Model* readonly_model_; tflite::ModelT model_; - flatbuffers::FlatBufferBuilder builder_; internal::FailOnErrorReporter error_reporter_; + std::string output_buffer_; // Raw buffer for quantized output model. }; void ExpectEqualTensor(TensorT* tensor, TensorT* expected_tensor) { @@ -279,20 +282,21 @@ INSTANTIATE_TEST_SUITE_P(QuantizeConvModelTestInst, QuantizeConvModelTest, testing::ValuesIn({TensorType_INT8})); TEST_P(QuantizeConvModelTest, QuantizationSucceeds) { - auto status = QuantizeModelAllOperators( - &builder_, &model_, tensor_type_, tensor_type_, /*allow_float=*/false, - tensor_type_, &error_reporter_); + auto status = QuantizeModelAllOperators(&model_, tensor_type_, tensor_type_, + /*allow_float=*/false, tensor_type_, + &error_reporter_, output_buffer_); EXPECT_THAT(status, Eq(kTfLiteOk)); - const uint8_t* buffer = builder_.GetBufferPointer(); - const Model* output_model = GetModel(buffer); + + const Model* output_model = GetModel(output_buffer_.data()); ASSERT_TRUE(output_model); } TEST_P(QuantizeConvModelTest, SkipUnspecifiedLayer) { - auto status = QuantizeModel( - &builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32, - /*allow_float=*/true, /*operator_names=*/{}, TensorType_FLOAT32, - &error_reporter_, /*disable_per_channel=*/false, {"CONV_2D"}); + auto status = + QuantizeModel(&model_, TensorType_FLOAT32, TensorType_FLOAT32, + /*allow_float=*/true, /*operator_names=*/{}, + TensorType_FLOAT32, &error_reporter_, output_buffer_, + /*disable_per_channel=*/false, {"CONV_2D"}); EXPECT_THAT(status, Eq(kTfLiteOk)); ModelT expected_model; @@ -302,11 +306,11 @@ TEST_P(QuantizeConvModelTest, SkipUnspecifiedLayer) { } TEST_P(QuantizeConvModelTest, SkipUnspecifiedLayerByName) { - auto status = QuantizeModel( - &builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32, - /*allow_float=*/true, /*operator_names=*/{}, TensorType_FLOAT32, - &error_reporter_, /*disable_per_channel=*/false, /*blocked_ops=*/{}, - {"output"}); + auto status = QuantizeModel(&model_, TensorType_FLOAT32, TensorType_FLOAT32, + /*allow_float=*/true, /*operator_names=*/{}, + TensorType_FLOAT32, &error_reporter_, + output_buffer_, /*disable_per_channel=*/false, + /*blocked_ops=*/{}, {"output"}); EXPECT_THAT(status, Eq(kTfLiteOk)); ModelT expected_model; @@ -316,9 +320,9 @@ TEST_P(QuantizeConvModelTest, SkipUnspecifiedLayerByName) { } TEST_P(QuantizeConvModelTest, GraphIsFullyQuantized) { - auto status = QuantizeModelAllOperators( - &builder_, &model_, tensor_type_, tensor_type_, - /*allow_float=*/false, tensor_type_, &error_reporter_); + auto status = QuantizeModelAllOperators(&model_, tensor_type_, tensor_type_, + /*allow_float=*/false, tensor_type_, + &error_reporter_, output_buffer_); EXPECT_THAT(status, Eq(kTfLiteOk)); for (const auto& subgraph : model_.subgraphs) { @@ -340,11 +344,10 @@ class QuantizeConvNoBiasModelTest : public QuantizeModelTest { TEST_F(QuantizeConvNoBiasModelTest, QuantizationSucceeds) { auto status = QuantizeModelAllOperators( - &builder_, &model_, TensorType_INT8, TensorType_INT8, - /*allow_float=*/false, TensorType_INT8, &error_reporter_); + &model_, TensorType_INT8, TensorType_INT8, + /*allow_float=*/false, TensorType_INT8, &error_reporter_, output_buffer_); EXPECT_THAT(status, Eq(kTfLiteOk)); - const uint8_t* buffer = builder_.GetBufferPointer(); - const Model* output_model = GetModel(buffer); + const Model* output_model = GetModel(output_buffer_.data()); ASSERT_TRUE(output_model); } @@ -361,8 +364,8 @@ class QuantizeSplitModelTest : public QuantizeModelTest { // should have the scales be hardcodes to the input scale value. TEST_F(QuantizeSplitModelTest, QuantizeSplit) { auto status = QuantizeModelAllOperators( - &builder_, &model_, TensorType_INT8, TensorType_INT8, - /*allow_float=*/false, TensorType_INT8, &error_reporter_); + &model_, TensorType_INT8, TensorType_INT8, /*allow_float=*/false, + TensorType_INT8, &error_reporter_, output_buffer_); EXPECT_THAT(status, Eq(kTfLiteOk)); // There is only one subgraph. @@ -458,9 +461,9 @@ INSTANTIATE_TEST_SUITE_P(QuantizeConvModel2TestInst, QuantizeConvModel2Test, testing::ValuesIn({TensorType_INT8})); TEST_P(QuantizeConvModel2Test, VerifyConvQuantization) { - auto status = QuantizeModelAllOperators( - &builder_, &model_, tensor_type_, tensor_type_, /*allow_float=*/false, - tensor_type_, &error_reporter_); + auto status = QuantizeModelAllOperators(&model_, tensor_type_, tensor_type_, + /*allow_float=*/false, tensor_type_, + &error_reporter_, output_buffer_); ASSERT_THAT(status, Eq(kTfLiteOk)); const auto& subgraph = model_.subgraphs[0]; auto conv_op = subgraph->operators[0].get(); @@ -566,8 +569,8 @@ TEST_P(QuantizeConvModel2Test, VerifyConvQuantization) { TEST_P(QuantizeConvModel2Test, VerifyConvDisablePerChannelQuantization) { auto status = QuantizeModelAllOperators( - &builder_, &model_, tensor_type_, tensor_type_, /*allow_float=*/false, - tensor_type_, /*disable_per_channel=*/true, &error_reporter_); + &model_, tensor_type_, tensor_type_, /*allow_float=*/false, tensor_type_, + /*disable_per_channel=*/true, &error_reporter_, output_buffer_); ASSERT_THAT(status, Eq(kTfLiteOk)); const auto& subgraph = model_.subgraphs[0]; auto conv_op = subgraph->operators[0].get(); @@ -684,8 +687,8 @@ class QuantizeSoftmaxTest : public QuantizeModelTest { TEST_F(QuantizeSoftmaxTest, VerifySoftmaxQuantization) { auto status = QuantizeModelAllOperators( - &builder_, &model_, TensorType_INT8, TensorType_INT8, - /*allow_float=*/false, TensorType_INT8, &error_reporter_); + &model_, TensorType_INT8, TensorType_INT8, /*allow_float=*/false, + TensorType_INT8, &error_reporter_, output_buffer_); ASSERT_THAT(status, Eq(kTfLiteOk)); const auto& subgraph = model_.subgraphs[0]; @@ -748,8 +751,8 @@ class QuantizeAvgPoolTest : public QuantizeModelTest { TEST_F(QuantizeAvgPoolTest, VerifyAvgPoolQuantization) { auto status = QuantizeModelAllOperators( - &builder_, &model_, TensorType_INT8, TensorType_INT8, - /*allow_float=*/false, TensorType_INT8, &error_reporter_); + &model_, TensorType_INT8, TensorType_INT8, /*allow_float=*/false, + TensorType_INT8, &error_reporter_, output_buffer_); ASSERT_THAT(status, Eq(kTfLiteOk)); const auto& subgraph = model_.subgraphs[0]; @@ -809,8 +812,8 @@ class QuantizeMultiInputAddWithReshapeTest : public QuantizeModelTest { TEST_F(QuantizeMultiInputAddWithReshapeTest, VerifyReshapeQuantization) { auto status = QuantizeModelAllOperators( - &builder_, &model_, TensorType_INT8, TensorType_INT8, - /*allow_float=*/false, TensorType_INT8, &error_reporter_); + &model_, TensorType_INT8, TensorType_INT8, /*allow_float=*/false, + TensorType_INT8, &error_reporter_, output_buffer_); ASSERT_THAT(status, Eq(kTfLiteOk)); @@ -861,8 +864,8 @@ TEST_F(QuantizeMultiInputAddWithReshapeTest, VerifyReshapeQuantization) { TEST_F(QuantizeMultiInputAddWithReshapeTest, VerifyAddQuantization) { auto status = QuantizeModelAllOperators( - &builder_, &model_, TensorType_INT8, TensorType_INT8, - /*allow_float=*/false, TensorType_INT8, &error_reporter_); + &model_, TensorType_INT8, TensorType_INT8, /*allow_float=*/false, + TensorType_INT8, &error_reporter_, output_buffer_); ASSERT_THAT(status, Eq(kTfLiteOk)); // Verify ADD is quantized. @@ -935,10 +938,9 @@ INSTANTIATE_TEST_SUITE_P(QuantizeConstInputTestInst, QuantizeConstInputTest, testing::ValuesIn({TensorType_INT8})); TEST_P(QuantizeConstInputTest, VerifyConstOpInput) { - auto status = - QuantizeModelAllOperators( - &builder_, &model_, tensor_type_, tensor_type_, /*allow_float=*/false, - tensor_type_, &error_reporter_); + auto status = QuantizeModelAllOperators(&model_, tensor_type_, tensor_type_, + /*allow_float=*/false, tensor_type_, + &error_reporter_, output_buffer_); ASSERT_THAT(status, Eq(kTfLiteOk)); // Verify ConstOp is quantized. @@ -981,8 +983,8 @@ class QuantizeArgMaxTest : public QuantizeModelTest { TEST_F(QuantizeArgMaxTest, VerifyArgMax) { auto status = QuantizeModelAllOperators( - &builder_, &model_, TensorType_INT8, TensorType_INT8, - /*allow_float=*/false, TensorType_INT8, &error_reporter_); + &model_, TensorType_INT8, TensorType_INT8, /*allow_float=*/false, + TensorType_INT8, &error_reporter_, output_buffer_); ASSERT_THAT(status, Eq(kTfLiteOk)); const auto& subgraph = model_.subgraphs[0]; @@ -1026,10 +1028,9 @@ class QuantizeLSTMTest : public QuantizeModelTest { }; TEST_F(QuantizeLSTMTest, VerifyLSTM) { - // Quantize model. auto status = QuantizeModelAllOperators( - &builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32, true, - TensorType_INT8, &error_reporter_); + &model_, TensorType_FLOAT32, TensorType_FLOAT32, /*allow_float=*/true, + TensorType_INT8, &error_reporter_, output_buffer_); ASSERT_THAT(status, Eq(kTfLiteOk)); // Read expected model. @@ -1053,8 +1054,8 @@ class QuantizeLSTM2Test : public QuantizeModelTest { TEST_F(QuantizeLSTM2Test, VerifyLSTM) { // Quantize model. auto status = QuantizeModelAllOperators( - &builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32, - /*allow_float=*/false, TensorType_INT8, &error_reporter_); + &model_, TensorType_FLOAT32, TensorType_FLOAT32, + /*allow_float=*/false, TensorType_INT8, &error_reporter_, output_buffer_); ASSERT_THAT(status, Eq(kTfLiteOk)); // Read expected model. @@ -1077,10 +1078,9 @@ class QuantizeUnidirectionalSequenceLSTMTest : public QuantizeModelTest { TEST_F(QuantizeUnidirectionalSequenceLSTMTest, VerifyUnidirectionalSequenceLSTM) { - // Quantize model. auto status = QuantizeModelAllOperators( - &builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32, - /*allow_float=*/false, TensorType_INT8, &error_reporter_); + &model_, TensorType_FLOAT32, TensorType_FLOAT32, /*allow_float=*/false, + TensorType_INT8, &error_reporter_, output_buffer_); ASSERT_THAT(status, Eq(kTfLiteOk)); // Read expected model. @@ -1105,8 +1105,8 @@ class QuantizeSVDFTest : public QuantizeModelTest { TEST_F(QuantizeSVDFTest, VerifySVDF) { // Quantize model. auto status = QuantizeModelAllOperators( - &builder_, &model_, TensorType_INT8, TensorType_INT8, - /*allow_float=*/false, TensorType_INT8, &error_reporter_); + &model_, TensorType_INT8, TensorType_INT8, /*allow_float=*/false, + TensorType_INT8, &error_reporter_, output_buffer_); ASSERT_THAT(status, Eq(kTfLiteOk)); // Read expected model. @@ -1129,8 +1129,8 @@ class QuantizeFCTest : public QuantizeModelTest { TEST_F(QuantizeFCTest, VerifyFC8x8) { auto status = QuantizeModelAllOperators( - &builder_, &model_, TensorType_INT8, TensorType_INT8, - /*allow_float=*/false, TensorType_INT8, &error_reporter_); + &model_, TensorType_INT8, TensorType_INT8, /*allow_float=*/false, + TensorType_INT8, &error_reporter_, output_buffer_); ASSERT_THAT(status, Eq(kTfLiteOk)); const auto& subgraph = model_.subgraphs[0]; @@ -1182,8 +1182,8 @@ TEST_F(QuantizeFCTest, VerifyFC8x8) { TEST_F(QuantizeFCTest, VerifyFCFor16x8) { auto status = QuantizeModelAllOperators( - &builder_, &model_, TensorType_INT8, TensorType_INT8, - /*allow_float=*/false, TensorType_INT16, &error_reporter_); + &model_, TensorType_INT8, TensorType_INT8, /*allow_float=*/false, + TensorType_INT16, &error_reporter_, output_buffer_); ASSERT_THAT(status, Eq(kTfLiteOk)); const std::unique_ptr& subgraph = model_.subgraphs[0]; @@ -1247,9 +1247,9 @@ class QuantizeCustomOpTest }; TEST_P(QuantizeCustomOpTest, VerifyMixedQuantization) { - auto status = QuantizeModelAllOperators( - &builder_, &model_, GetParam(), GetParam(), - /*allow_float=*/true, GetParam(), &error_reporter_); + auto status = QuantizeModelAllOperators(&model_, GetParam(), GetParam(), + /*allow_float=*/true, GetParam(), + &error_reporter_, output_buffer_); ASSERT_THAT(status, Eq(kTfLiteOk)); const auto& subgraph = model_.subgraphs[0]; auto float_graph = readonly_model_->subgraphs()->Get(0); @@ -1286,7 +1286,7 @@ class QuantizePackTest : public QuantizeModelTest { }; TEST_F(QuantizePackTest, VerifyPack) { - auto status = QuantizeModel(&builder_, &model_, &error_reporter_); + auto status = QuantizeModel(&model_, &error_reporter_, output_buffer_); ASSERT_THAT(status, Eq(kTfLiteOk)); @@ -1350,7 +1350,7 @@ class QuantizeMinimumMaximumTest }; TEST_P(QuantizeMinimumMaximumTest, VerifyMinimumMaximum) { - auto status = QuantizeModel(&builder_, &model_, &error_reporter_); + auto status = QuantizeModel(&model_, &error_reporter_, output_buffer_); ASSERT_THAT(status, Eq(kTfLiteOk)); const auto& subgraph = model_.subgraphs[0]; // Check that the first op is Quantize and the last is Dequant. @@ -1413,7 +1413,7 @@ class QuantizeUnpackTest : public QuantizeModelTest { }; TEST_F(QuantizeUnpackTest, VerifyUnpack) { - auto status = QuantizeModel(&builder_, &model_, &error_reporter_); + auto status = QuantizeModel(&model_, &error_reporter_, output_buffer_); ASSERT_THAT(status, Eq(kTfLiteOk)); @@ -1470,9 +1470,9 @@ INSTANTIATE_TEST_SUITE_P(QuantizeBroadcastToModelTestInst, testing::ValuesIn({TensorType_INT8})); TEST_P(QuantizeBroadcastToModelTest, VerifyBroadcastToQuantization) { - auto status = QuantizeModelAllOperators( - &builder_, &model_, tensor_type_, tensor_type_, /*allow_float=*/false, - tensor_type_, &error_reporter_); + auto status = QuantizeModelAllOperators(&model_, tensor_type_, tensor_type_, + /*allow_float=*/false, tensor_type_, + &error_reporter_, output_buffer_); EXPECT_THAT(status, Eq(kTfLiteOk)); // There is only one subgraph. @@ -1537,9 +1537,9 @@ INSTANTIATE_TEST_SUITE_P(QuantizeGatherNDModelTestInst, testing::ValuesIn({TensorType_INT8})); TEST_P(QuantizeGatherNDModelTest, QuantizeGatherND) { - auto status = QuantizeModelAllOperators( - &builder_, &model_, tensor_type_, tensor_type_, /*allow_float=*/false, - tensor_type_, &error_reporter_); + auto status = QuantizeModelAllOperators(&model_, tensor_type_, tensor_type_, + /*allow_float=*/false, tensor_type_, + &error_reporter_, output_buffer_); EXPECT_THAT(status, Eq(kTfLiteOk)); // There is only one subgraph. @@ -1596,8 +1596,8 @@ TEST_F(QuantizeWhereModelTest, QuantizeWhere) { // Where operator takes a BOOL tensor as input // and outputs INT64 indices, both of which // should not be quantized - auto status = QuantizeModel(&builder_, &model_, TensorType_BOOL, - TensorType_INT64, &error_reporter_); + auto status = QuantizeModel(&model_, TensorType_BOOL, TensorType_INT64, + &error_reporter_, output_buffer_); EXPECT_THAT(status, Eq(kTfLiteOk)); // There is only one subgraph. diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/tfl_quantizer.cc b/tensorflow/compiler/mlir/lite/quantization/lite/tfl_quantizer.cc index fe5ca2ca8f1d47..4bf154e892bcdb 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/tfl_quantizer.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/tfl_quantizer.cc @@ -14,10 +14,7 @@ limitations under the License. ==============================================================================*/ #include -#include -#include -#include "absl/strings/string_view.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/InitLLVM.h" #include "llvm/Support/MemoryBuffer.h" @@ -36,21 +33,14 @@ static opt inputFileName(llvm::cl::Positional, namespace mlir { namespace { -TfLiteStatus QuantizeAnnotatedModel(llvm::StringRef buffer, - flatbuffers::FlatBufferBuilder* builder) { - auto model_ptr = tflite::FlatBufferModel::VerifyAndBuildFromBuffer( - buffer.data(), buffer.size()); - if (nullptr == model_ptr) { - return TfLiteStatus::kTfLiteError; - } - std::unique_ptr model(model_ptr->GetModel()->UnPack()); +TfLiteStatus QuantizeAnnotatedModel(llvm::StringRef buffer, + std::string& output_buffer) { tflite::StderrReporter error_reporter; return mlir::lite::QuantizeModel( - *model, tflite::TensorType_INT8, tflite::TensorType_INT8, - tflite::TensorType_INT8, {}, - /*disable_per_channel=*/false, - /*fully_quantize=*/true, builder, &error_reporter); + buffer, tflite::TensorType_INT8, tflite::TensorType_INT8, + tflite::TensorType_INT8, {}, /*disable_per_channel=*/false, + /*fully_quantize=*/true, output_buffer, &error_reporter); } } // namespace @@ -66,16 +56,13 @@ int main(int argc, char** argv) { return 1; } auto buffer = file_or_err->get(); - flatbuffers::FlatBufferBuilder builder; - auto status = - mlir::QuantizeAnnotatedModel(buffer->getBuffer().str(), &builder); - if (status != kTfLiteOk) { + std::string output_buffer; + if (auto status = mlir::QuantizeAnnotatedModel(buffer->getBuffer().str(), + output_buffer); + status != kTfLiteOk) { return 1; } - std::cout << std::string( - reinterpret_cast(builder.GetBufferPointer()), - builder.GetSize()) - << "\n"; + std::cout << output_buffer << "\n"; return 0; } diff --git a/tensorflow/lite/toco/python/BUILD b/tensorflow/lite/toco/python/BUILD index 3a172ea0613691..a75d7351a8d593 100644 --- a/tensorflow/lite/toco/python/BUILD +++ b/tensorflow/lite/toco/python/BUILD @@ -47,6 +47,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/lite:model_builder", "//tensorflow/lite/core/api", "//tensorflow/lite/core/c:common", "//tensorflow/lite/python/interpreter_wrapper:python_error_reporter", diff --git a/tensorflow/lite/toco/python/toco_python_api.cc b/tensorflow/lite/toco/python/toco_python_api.cc index c6339e81cb0080..48af2bdce7cb9a 100644 --- a/tensorflow/lite/toco/python/toco_python_api.cc +++ b/tensorflow/lite/toco/python/toco_python_api.cc @@ -15,7 +15,6 @@ limitations under the License. #include "tensorflow/lite/toco/python/toco_python_api.h" #include -#include #include #include #include @@ -32,9 +31,9 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/sparsity/sparsify_model.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_def.pb.h" -#include "tensorflow/core/platform/logging.h" #include "tensorflow/lite/core/api/error_reporter.h" #include "tensorflow/lite/core/c/common.h" +#include "tensorflow/lite/model_builder.h" #include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h" #include "tensorflow/lite/python/interpreter_wrapper/python_utils.h" #include "tensorflow/lite/schema/schema_generated.h" @@ -327,28 +326,28 @@ PyObject* MlirQuantizeModel(PyObject* data, bool disable_per_channel, auto tflite_model = std::make_unique(); model->GetModel()->UnPackTo(tflite_model.get(), nullptr); - tflite::TensorType inference_tensor_type = + const tflite::TensorType inference_tensor_type = FromTocoDataTypeToTflitToTensorType(inference_type); - tflite::TensorType input_type = + const tflite::TensorType input_type = FromTocoDataTypeToTflitToTensorType(input_data_type); - tflite::TensorType output_type = + const tflite::TensorType output_type = FromTocoDataTypeToTflitToTensorType(output_data_type); - flatbuffers::FlatBufferBuilder builder; + std::string output_model; + const absl::string_view input_model_buffer(buf, length); auto status = mlir::lite::QuantizeModel( - *tflite_model, input_type, output_type, inference_tensor_type, {}, - disable_per_channel, fully_quantize, &builder, error_reporter.get(), - enable_numeric_verify, enable_whole_model_verify, + input_model_buffer, input_type, output_type, inference_tensor_type, + /*operator_names=*/{}, disable_per_channel, fully_quantize, output_model, + error_reporter.get(), enable_numeric_verify, enable_whole_model_verify, /*legacy_float_scale=*/true, denylisted_ops, denylisted_nodes, enable_variable_quantization); - if (status != kTfLiteOk) { error_reporter->exception(); return nullptr; } - return tflite::python_utils::ConvertToPyString( - reinterpret_cast(builder.GetCurrentBufferPointer()), - builder.GetSize()); + + return tflite::python_utils::ConvertToPyString(output_model.data(), + output_model.size()); } PyObject* MlirSparsifyModel(PyObject* data) { From c35ac7f2e984cec48c811da7e76882bd813553f2 Mon Sep 17 00:00:00 2001 From: Berkin Ilbeyi Date: Mon, 17 Jul 2023 00:56:58 -0700 Subject: [PATCH 362/376] [XLA] Fix a bug in required assignment matching that caused buffers illegally being put in the alternate mem when inefficient allocations are enabled. The combination of redundant eviction elimination and inefficient allocation features caused a corner case bug. We had an overly conservative logic to find previous allocations in AllocateSegment. When inefefficient allocation detection inserted a required allocation in default memory, this logic sometimes failed to find the correct previous allocation in the default memory. This CL makes this logic less conservative. This CL also adds a mechanism for tests to inject custom logic into inefficient allocation detection in order to test complex scenarios like in this bug. PiperOrigin-RevId: 548604568 --- .../xla/service/memory_space_assignment.cc | 16 ++- .../xla/service/memory_space_assignment.h | 6 + .../service/memory_space_assignment_test.cc | 109 ++++++++++++++++++ 3 files changed, 128 insertions(+), 3 deletions(-) diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.cc b/tensorflow/compiler/xla/service/memory_space_assignment.cc index 6fd0645a9a45b5..5b2168c1fa0c8e 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment.cc +++ b/tensorflow/compiler/xla/service/memory_space_assignment.cc @@ -3715,6 +3715,18 @@ std::vector AlternateMemoryBestFitHeap::GetInefficientAllocationSites( absl::Span allocation_values) const { + // The logic below is used mostly for testing, allowing a test case to inject + // some custom logic for this method. + if (options_.get_inefficient_allocation_sites_fn) { + std::vector defining_positions; + defining_positions.reserve(allocation_values.size()); + for (const AllocationValue& value : allocation_values) { + defining_positions.push_back(value.defining_position()); + } + return options_.get_inefficient_allocation_sites_fn( + absl::MakeSpan(defining_positions)); + } + if (!options_.cost_analysis || options_.inefficient_use_to_copy_ratio == 0.0) { return {}; @@ -5313,9 +5325,7 @@ AlternateMemoryBestFitHeap::Result AlternateMemoryBestFitHeap::AllocateSegment( auto prev_allocation_it = std::find_if( allocation_sequence->rbegin(), allocation_sequence->rend(), [&](const auto& allocation) { - return allocation->memory_space() == - required_memory_space_at_start && - allocation->defining_position() == defining_position; + return allocation->memory_space() == required_memory_space_at_start; }); if (prev_allocation_it != allocation_sequence->rend()) { (*prev_allocation_it)->Extend(request.start_time); diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.h b/tensorflow/compiler/xla/service/memory_space_assignment.h index be40b81099f187..fd3f3f449949df 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment.h +++ b/tensorflow/compiler/xla/service/memory_space_assignment.h @@ -1488,6 +1488,12 @@ struct Options { // case copy_bytes would be twice the size of the tensor. float inefficient_use_to_copy_ratio = 0.0; + // This is mostly used for testing, it allows a test case to inject its own + // logic for AlternateMemoryBestFitHeap::GetInefficientAllocationSites. + std::function>( + absl::Span)> + get_inefficient_allocation_sites_fn = nullptr; + // The window size used to calculate the pipeline overhead when HLO accesses // the default memory, in MiB. float pipeline_overhead_window_size_mib = 0; diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc index 28169ac0205ccd..86f02ec4aaffe8 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc +++ b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc @@ -27,6 +27,7 @@ limitations under the License. #include #include #include +#include #include #include "absl/algorithm/container.h" @@ -5390,6 +5391,114 @@ TEST_P(MemorySpaceAssignmentTest, } } +TEST_P(MemorySpaceAssignmentTest, + WhileRedundantEvictionWithInefficientAllocationBug) { + absl::string_view hlo_string = R"( + HloModule module, is_scheduled=true + + while_cond { + p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0) + ROOT gte = pred[] get-tuple-element(p0), index=2 + } + + while_body { + p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0) + gte0 = f32[3]{0} get-tuple-element(p0), index=0 + gte1 = f32[3]{0} get-tuple-element(p0), index=1 + tanh = f32[3]{0} tanh(gte1) + gte2 = pred[] get-tuple-element(p0), index=2 + negate0 = f32[3]{0} negate(gte0) + negate1 = f32[3]{0} negate(negate0) + add = f32[3]{0} add(negate1, tanh) + ROOT tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(add, gte1, gte2) + } + + while_cond1 { + p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0) + ROOT gte = pred[] get-tuple-element(p0), index=2 + } + + while_body1 { + p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0) + gte0 = f32[3]{0} get-tuple-element(p0), index=0 + gte2 = pred[] get-tuple-element(p0), index=2 + negate0 = f32[3]{0} negate(gte0) + negate1 = f32[3]{0} negate(negate0) + negate2 = f32[3]{0} negate(negate1) + negate3 = f32[3]{0} negate(negate2) + negate4 = f32[3]{0} negate(negate3) + negate5 = f32[3]{0} negate(negate4) + negate6 = f32[3]{0} negate(negate5) + negate7 = f32[3]{0} negate(negate6) + negate8 = f32[3]{0} negate(negate7) + negate9 = f32[3]{0} negate(negate8) + negate10 = f32[3]{0} negate(negate9) + negate11 = f32[3]{0} negate(negate10) + negate12 = f32[3]{0} negate(negate11) + negate13 = f32[3]{0} negate(negate12) + negate14 = f32[3]{0} negate(negate13) + gte1 = f32[3]{0} get-tuple-element(p0), index=1 + tanh = f32[3]{0} tanh(gte1) + add = f32[3]{0} add(negate14, tanh) + ROOT tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(add, gte1, gte2) + } + + ENTRY entry { + p0 = f32[3]{0} parameter(0) + p1 = pred[] parameter(1) + p2 = f32[3]{0} parameter(2) + copy = f32[3]{0} copy(p0) + tuple1 = (f32[3]{0}, f32[3]{0}, pred[]) tuple(copy, p0, p1) + while1 = (f32[3]{0}, f32[3]{0}, pred[]) while(tuple1), condition=while_cond, body=while_body + gte0 = f32[3]{0} get-tuple-element(while1), index=0 + gte1 = f32[3]{0} get-tuple-element(while1), index=1 + negate0_entry = f32[3]{0} negate(gte1) + gte2 = pred[] get-tuple-element(while1), index=2 + tuple2 = (f32[3]{0}, f32[3]{0}, pred[]) tuple(gte0, gte1, gte2) + while2 = (f32[3]{0}, f32[3]{0}, pred[]) while(tuple2), condition=while_cond1, body=while_body1 + negate1 = f32[3]{0} negate(negate0_entry) + negate2 = f32[3]{0} negate(negate1) + negate3 = f32[3]{0} negate(negate2) + negate4 = f32[3]{0} negate(negate3) + negate5 = f32[3]{0} negate(negate4) + negate6 = f32[3]{0} negate(negate5) + negate7 = f32[3]{0} negate(negate6) + negate8 = f32[3]{0} negate(negate7) + negate9 = f32[3]{0} negate(negate8) + negate10 = f32[3]{0} negate(negate9) + negate11 = f32[3]{0} negate(negate10) + negate12 = f32[3]{0} negate(negate11) + negate13 = f32[3]{0} negate(negate12) + negate14 = f32[3]{0} negate(negate13) + gte = f32[3]{0} get-tuple-element(while2), index=1 + ROOT add = f32[3]{0} add(gte, negate14) + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + Options options = DefaultMemorySpaceOptions(); + // Inject GetInefficientAllocationSites to mark negate0_entry use as + // inefficient. This triggers a corner case bug where allocating for while2{1} + // in the retry allocation fails to find the previous required allocation in + // default memory, and creates a new one which is wrong. + bool marked_inefficient = false; + options.get_inefficient_allocation_sites_fn = + [&](absl::Span defining_positions) + -> std::vector> { + if (absl::c_find(defining_positions, + HloPosition{FindInstruction(module.get(), "while1"), + {1}}) != defining_positions.end() && + !marked_inefficient) { + LOG(INFO) << "Marking the use inefficient."; + marked_inefficient = true; + return {HloUse{FindInstruction(module.get(), "negate0_entry"), 0}}; + } + return {}; + }; + AssignMemorySpace(module.get(), options); +} + TEST_P(MemorySpaceAssignmentTest, BitcastRoot) { // Tests against a bug where the root of entry computation is a bitcast // instruction and it ends up getting an allocation in the alternate memory. From bdfbde0707fd86e97c171684c0d0fc6b56f6a063 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 17 Jul 2023 02:02:11 -0700 Subject: [PATCH 363/376] compat: Update forward compatibility horizon to 2023-07-17 PiperOrigin-RevId: 548620774 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 88b73342c3202f..a8256da81c5d6c 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 7, 16) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 7, 17) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From 58482014ed2e6b309c98836cd78d019a7b8e53f1 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 17 Jul 2023 02:02:11 -0700 Subject: [PATCH 364/376] Update GraphDef version to 1560. PiperOrigin-RevId: 548620777 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 963b5d9169e7c0..ef972445c7fc58 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1559 // Updated: 2023/7/16 +#define TF_GRAPH_DEF_VERSION 1560 // Updated: 2023/7/17 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From 396f8abd77b02c88e08dc13f9be7304bec881de2 Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Mon, 17 Jul 2023 02:48:08 -0700 Subject: [PATCH 365/376] Support all fusion kinds except Triton in GetLaunchDimensions. PiperOrigin-RevId: 548629503 --- .../xla/service/gpu/hlo_fusion_analysis.cc | 30 ++++++++++++++++++- .../xla/service/gpu/hlo_fusion_analysis.h | 4 +-- .../xla/service/gpu/ir_emitter_unnested.cc | 28 +++++++---------- .../xla/service/gpu/ir_emitter_unnested.h | 6 ++-- .../xla/service/gpu/launch_dimensions.cc | 10 +++---- .../xla/service/gpu/launch_dimensions.h | 4 +-- 6 files changed, 53 insertions(+), 29 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/hlo_fusion_analysis.cc b/tensorflow/compiler/xla/service/gpu/hlo_fusion_analysis.cc index 39619ad46c96ec..58215f84138c9e 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_fusion_analysis.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_fusion_analysis.cc @@ -324,7 +324,35 @@ StatusOr HloFusionAnalysis::GetLaunchDimensions( return LaunchDimensions(tiling_scheme->GetNumberOfBlocksPhysical(), tiling_scheme->GetNumThreadsPerBlockPhysical()); } - default: + case EmitterFusionKind::kInputSlices: { + auto* root = + fusion_->fused_instructions_computation()->root_instruction(); + xla::Shape shape; + if (root->opcode() == HloOpcode::kSlice) { + shape = root->operands()[0]->shape(); + } else { + CHECK_EQ(root->opcode(), HloOpcode::kTuple); + // We already verified that the shapes are compatible in + // `GetEmitterFusionKind`. + shape = root->operands()[0]->operands()[0]->shape(); + } + constexpr int kUnrollFactor = 1; + return CalculateLaunchDimensions( + shape, *device_info_, use_experimental_block_size, {kUnrollFactor}); + } + case EmitterFusionKind::kScatter: { + const auto& root_shape = fusion_->fused_instructions_computation() + ->root_instruction() + ->shape(); + int64_t num_elements = ShapeUtil::ElementsIn(root_shape); + int unroll_factor = num_elements % 4 == 0 ? 4 + : num_elements % 2 == 0 ? 2 + : 1; + return CalculateLaunchDimensions(root_shape, *device_info_, + use_experimental_block_size, + {unroll_factor, /*few_waves=*/false}); + } + case EmitterFusionKind::kTriton: return Unimplemented("GetLaunchDimensions"); } } diff --git a/tensorflow/compiler/xla/service/gpu/hlo_fusion_analysis.h b/tensorflow/compiler/xla/service/gpu/hlo_fusion_analysis.h index 035916f8282fc7..0d771a8b897d3c 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_fusion_analysis.h +++ b/tensorflow/compiler/xla/service/gpu/hlo_fusion_analysis.h @@ -76,8 +76,8 @@ class HloFusionAnalysis { // Determines the fusion type for the emitter. EmitterFusionKind GetEmitterFusionKind() const; - // Determines the launch dimensions for the fusion. The fusion kind must be - // one of `kLoop`, `kReduction` or `kTranspose`. + // Determines the launch dimensions for the fusion. The fusion kind must not + // be `kTriton`. StatusOr GetLaunchDimensions( bool use_experimental_block_size = false); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 299dd9fdb10785..f11ae1aa4ce1bd 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -2077,9 +2077,9 @@ Status IrEmitterUnnested::EmitFusion(mlir::Operation* op) { case HloFusionAnalysis::EmitterFusionKind::kTranspose: return EmitUnnestedTranspose(fusion_op, fusion_analysis); case HloFusionAnalysis::EmitterFusionKind::kInputSlices: - return EmitInputFusibleNonStridedSlices(op); + return EmitInputFusibleNonStridedSlices(op, fusion_analysis); case HloFusionAnalysis::EmitterFusionKind::kScatter: - return EmitScatter(fusion_op, fused_computation); + return EmitScatter(fusion_op, fused_computation, fusion_analysis); case HloFusionAnalysis::EmitterFusionKind::kLoop: { // Special case: DUS bool is_single = IsSingleInstructionFusion(fusion_op); @@ -5043,25 +5043,19 @@ Status IrEmitterUnnested::EmitElementForInputFusibleSlices( } Status IrEmitterUnnested::EmitInputFusibleNonStridedSlices( - mlir::Operation* op) { + mlir::Operation* op, HloFusionAnalysis& fusion_analysis) { auto fusion = mlir::cast(op); - constexpr int unroll_factor = 1; - TF_ASSIGN_OR_RETURN(const HloComputation* fused_computation, GetOrCreateSubComputationFromRegion(&fusion.getRegion(), /*is_fusion=*/true)); - TF_ASSIGN_OR_RETURN(Shape element_shape, - GetConsistentInputShapeForRootSlices(fused_computation)); bool use_experimental_block_size = hlo_module_config_.debug_options() .xla_gpu_enable_experimental_block_size(); - - TF_ASSIGN_OR_RETURN(LaunchDimensions launch_dimensions, - CalculateLaunchDimensions( - element_shape, ir_emitter_context_->gpu_device_info(), - use_experimental_block_size, {unroll_factor})); + TF_ASSIGN_OR_RETURN( + LaunchDimensions launch_dimensions, + fusion_analysis.GetLaunchDimensions(use_experimental_block_size)); TF_ASSIGN_OR_RETURN( std::optional> opt_ir_arrays, @@ -5072,6 +5066,8 @@ Status IrEmitterUnnested::EmitInputFusibleNonStridedSlices( } std::vector& ir_arrays = opt_ir_arrays.value(); + TF_ASSIGN_OR_RETURN(Shape element_shape, + GetConsistentInputShapeForRootSlices(fused_computation)); return ParallelLoopEmitter( [&](const llvm_ir::IrArray::Index index) -> Status { return EmitElementForInputFusibleSlices(fused_computation, @@ -5178,7 +5174,8 @@ Status IrEmitterUnnested::EmitDynamicUpdateSlice( } Status IrEmitterUnnested::EmitScatter(mlir::lmhlo::FusionOp fusion_op, - const HloComputation* fused_computation) { + const HloComputation* fused_computation, + HloFusionAnalysis& fusion_analysis) { auto* root = fused_computation->root_instruction(); // The initialization from 'operand' is using different loop bounds, so @@ -5190,12 +5187,9 @@ Status IrEmitterUnnested::EmitScatter(mlir::lmhlo::FusionOp fusion_op, TF_RETURN_IF_ERROR([&] { auto unroll_factor = ComputeMaxUnrollFactor(fusion_op); - const Shape& element_shape = root->shape(); TF_ASSIGN_OR_RETURN( LaunchDimensions launch_dimensions, - CalculateLaunchDimensions( - element_shape, ir_emitter_context_->gpu_device_info(), - use_experimental_block_size, {unroll_factor, /*few_waves=*/false})); + fusion_analysis.GetLaunchDimensions(use_experimental_block_size)); TF_ASSIGN_OR_RETURN( std::optional> opt_ir_arrays, diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index 60cfd3cf2f302f..cadd1c72bc81af 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -505,7 +505,8 @@ class IrEmitterUnnested : public IrEmitter { // different. On the other hand, the input ranges of slices can be // overlapping. Further generalization/specialization when the needs are seen // in the future. - Status EmitInputFusibleNonStridedSlices(mlir::Operation* op); + Status EmitInputFusibleNonStridedSlices(mlir::Operation* op, + HloFusionAnalysis& fusion_analysis); Status EmitElementForInputFusibleSlices( const HloComputation* fused_computation, @@ -556,7 +557,8 @@ class IrEmitterUnnested : public IrEmitter { const LaunchDimensions& launch_dimensions); Status EmitScatter(mlir::lmhlo::FusionOp fusion_op, - const HloComputation* fused_computation); + const HloComputation* fused_computation, + HloFusionAnalysis& fusion_analysis); Status EmitDynamicUpdateSlice(mlir::lmhlo::FusionOp fusion_op, const HloComputation* fused_computation); diff --git a/tensorflow/compiler/xla/service/gpu/launch_dimensions.cc b/tensorflow/compiler/xla/service/gpu/launch_dimensions.cc index cc4f7f34e4c722..90ead8a3f7c815 100644 --- a/tensorflow/compiler/xla/service/gpu/launch_dimensions.cc +++ b/tensorflow/compiler/xla/service/gpu/launch_dimensions.cc @@ -33,7 +33,7 @@ std::ostream& operator<<(std::ostream& out, return out; } -static int64_t ThreadsPerBlockLimit(GpuDeviceInfo gpu_device_info) { +static int64_t ThreadsPerBlockLimit(const GpuDeviceInfo& gpu_device_info) { int64_t threads_per_block = gpu_device_info.threads_per_block_limit; if (threads_per_block <= 0) { static std::atomic log_count{0}; @@ -53,7 +53,7 @@ static int64_t ThreadsPerBlockLimit(GpuDeviceInfo gpu_device_info) { } int64_t ThreadsPerBlockRowVectorized(const Shape& shape, - GpuDeviceInfo gpu_device_info, + const GpuDeviceInfo& gpu_device_info, LaunchDimensionsConfig dim_config) { if (shape.dimensions().empty()) { return -1; @@ -75,7 +75,7 @@ int64_t ThreadsPerBlockRowVectorized(const Shape& shape, } StatusOr CalculateLaunchDimensionsImplExperimental( - const Shape& shape, GpuDeviceInfo gpu_device_info, + const Shape& shape, const GpuDeviceInfo& gpu_device_info, LaunchDimensionsConfig dim_config) { int64_t num_elements = ShapeUtil::ElementsIn(shape); if (num_elements <= 1) { @@ -97,7 +97,7 @@ StatusOr CalculateLaunchDimensionsImplExperimental( } StatusOr CalculateLaunchDimensionsImpl( - const Shape& shape, GpuDeviceInfo gpu_device_info, + const Shape& shape, const GpuDeviceInfo& gpu_device_info, LaunchDimensionsConfig dim_config) { int64_t num_elements = ShapeUtil::ElementsIn(shape); if (num_elements <= 1) { @@ -200,7 +200,7 @@ StatusOr CalculateLaunchDimensionsImpl( } StatusOr CalculateLaunchDimensions( - const Shape& shape, GpuDeviceInfo gpu_device_info, + const Shape& shape, const GpuDeviceInfo& gpu_device_info, bool use_experimental_block_size, LaunchDimensionsConfig dim_config) { if (use_experimental_block_size) { VLOG(2) << "Experimental block size is enabled"; diff --git a/tensorflow/compiler/xla/service/gpu/launch_dimensions.h b/tensorflow/compiler/xla/service/gpu/launch_dimensions.h index 05ce2b1be70411..95228825403b8f 100644 --- a/tensorflow/compiler/xla/service/gpu/launch_dimensions.h +++ b/tensorflow/compiler/xla/service/gpu/launch_dimensions.h @@ -130,12 +130,12 @@ struct LaunchDimensionsConfig { // Returns -1 if the shape doesn't allow the row vectorization code path. // If supported, return the number of threads to use in that case. int64_t ThreadsPerBlockRowVectorized(const Shape& shape, - GpuDeviceInfo gpu_device_info, + const GpuDeviceInfo& gpu_device_info, LaunchDimensionsConfig dim_config); // Calculates the launch dimensions used to invoke `hlo`. StatusOr CalculateLaunchDimensions( - const Shape& shape, GpuDeviceInfo gpu_device_info, + const Shape& shape, const GpuDeviceInfo& gpu_device_info, bool use_experimental_block_size, LaunchDimensionsConfig dim_config = {}); } // namespace gpu From 1a66ca361e5c10dc9b0e71da58b1a41fd60a603c Mon Sep 17 00:00:00 2001 From: Matt Kreileder Date: Mon, 17 Jul 2023 03:10:58 -0700 Subject: [PATCH 366/376] Extend c_api_opaque to support reading and writing strings values from and to an opaque tensor. PiperOrigin-RevId: 548633760 --- tensorflow/lite/core/c/BUILD | 4 + tensorflow/lite/core/c/c_api_opaque.cc | 39 +++++ tensorflow/lite/core/c/c_api_opaque.h | 58 ++++++++ tensorflow/lite/core/c/c_api_test.cc | 191 +++++++++++++++++++++++++ 4 files changed, 292 insertions(+) diff --git a/tensorflow/lite/core/c/BUILD b/tensorflow/lite/core/c/BUILD index 972df7b7677582..28fd64d1021052 100644 --- a/tensorflow/lite/core/c/BUILD +++ b/tensorflow/lite/core/c/BUILD @@ -182,6 +182,7 @@ cc_test( ":c_api_types", ":common", "//tensorflow/core/platform:resource_loader", + "//tensorflow/lite:string_util", "//tensorflow/lite/c:c_api_internal", "//tensorflow/lite/core:subgraph", "//tensorflow/lite/delegates:delegate_test_util", @@ -208,6 +209,7 @@ cc_test( ":c_api_without_op_resolver_without_alwayslink", ":common", "//tensorflow/core/platform:resource_loader", + "//tensorflow/lite:string_util", "//tensorflow/lite/c:c_api_internal", "//tensorflow/lite/c:selectively_built_c_api_test_lib", "//tensorflow/lite/core:subgraph", @@ -349,6 +351,7 @@ tflite_cc_library_with_c_headers_test( "//tensorflow/lite:builtin_ops", "//tensorflow/lite:framework", "//tensorflow/lite:kernel_api", + "//tensorflow/lite:string_util", "//tensorflow/lite/c:c_api_internal", "//tensorflow/lite/c:c_api_opaque_internal", "//tensorflow/lite/core:framework", @@ -390,6 +393,7 @@ tflite_cc_library_with_c_headers_test( ":common", "//tensorflow/lite:framework", "//tensorflow/lite:kernel_api", + "//tensorflow/lite:string_util", "//tensorflow/lite/c:c_api_internal", "//tensorflow/lite/c:c_api_opaque_internal_without_alwayslink", "//tensorflow/lite/core:framework", diff --git a/tensorflow/lite/core/c/c_api_opaque.cc b/tensorflow/lite/core/c/c_api_opaque.cc index 631412b7aea59a..f889f5a5899899 100644 --- a/tensorflow/lite/core/c/c_api_opaque.cc +++ b/tensorflow/lite/core/c/c_api_opaque.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/core/subgraph.h" #include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/string_util.h" namespace { @@ -36,6 +37,13 @@ const TfLiteTensor* Convert(const TfLiteOpaqueTensor* opaque_tensor) { return reinterpret_cast(opaque_tensor); } +TfLiteTensor* Convert(TfLiteOpaqueTensor* opaque_tensor) { + // The following cast is safe only because this code is part of the + // TF Lite runtime implementation. Apps using TF Lite should not rely on + // TfLiteOpaqueTensor and TfLiteTensor being equivalent. + return reinterpret_cast(opaque_tensor); +} + const TfLiteNode* Convert(const TfLiteOpaqueNode* opaque_node) { // The following cast is safe only because this code is part of the // TF Lite runtime implementation. Apps using TF Lite should not rely on @@ -168,6 +176,37 @@ TfLiteStatus TfLiteOpaqueTensorCopyToBuffer( output_data_size); } +int TfLiteOpaqueTensorGetStringCount(const TfLiteOpaqueTensor* tensor) { + return tflite::GetStringCount(Convert(tensor)); +} + +TfLiteStatus TfLiteOpaqueTensorGetString(const TfLiteOpaqueTensor* tensor, + int index, const char** str, + int* len) { + tflite::StringRef str_ref = tflite::GetString(Convert(tensor), index); + *str = str_ref.str; + *len = str_ref.len; + return kTfLiteOk; +} + +TfLiteStatus TfLiteOpaqueTensorWriteStrings(TfLiteOpaqueTensor* tensor, + const char* const* str_array, + int str_array_len, + const int* str_n_len) { + tflite::DynamicBuffer buf; + for (int i = 0; i < str_array_len; ++i) { + buf.AddString(str_array[i], str_n_len[i]); + } + buf.WriteToTensorAsVector(Convert(tensor)); + return kTfLiteOk; +} + +TfLiteStatus TfLiteOpaqueTensorWriteString(TfLiteOpaqueTensor* tensor, + const char* str, const int len) { + TfLiteOpaqueTensorWriteStrings(tensor, &str, 1, &len); + return kTfLiteOk; +} + const TfLiteOpaqueTensor* TfLiteOpaqueNodeGetInput( const TfLiteOpaqueContext* opaque_context, const TfLiteOpaqueNode* opaque_node, int index) { diff --git a/tensorflow/lite/core/c/c_api_opaque.h b/tensorflow/lite/core/c/c_api_opaque.h index ff70c4304401e4..44c95a51d4ece0 100644 --- a/tensorflow/lite/core/c/c_api_opaque.h +++ b/tensorflow/lite/core/c/c_api_opaque.h @@ -113,6 +113,64 @@ TFL_CAPI_EXPORT extern TfLiteStatus TfLiteOpaqueTensorCopyToBuffer( const TfLiteOpaqueTensor* opaque_tensor, void* output_data, size_t output_data_size); +// Returns the number of strings stored in the provided 'tensor'. Returns -1 in +// case of failure. +int TfLiteOpaqueTensorGetStringCount(const TfLiteOpaqueTensor* tensor); + +// Stores the address of the n-th (denoted by the provided 'index') string +// contained in the provided 'tensor' in the provided '*str' pointer. Stores +// the length of the string in the provided '*len' argument. +// +// Returns 'kTfLiteOk' if '*str' and '*len' have been set successfully. Any +// other return value indicates a failure, which leaves '*str' and '*len' in an +// unspecified state. +// +// The range of valid indices is defined by the half open interval [0, N), +// where N == TfLiteOpaqueTensorGetStringCount(tensor). +// +// Note that 'str' is not guaranteed to be null-terminated. Also note that this +// function will not create a copy of the underlying string data. The data is +// owned by the 'tensor'. +TfLiteStatus TfLiteOpaqueTensorGetString(const TfLiteOpaqueTensor* tensor, + int index, const char** str, int* len); + +// Writes the array of strings specified by 'str_array' into +// the specified 'tensor'. The strings provided via the 'str_array' are being +// copied into the 'tensor'. Returns 'kTfLiteOk' in case of success. Any other +// return value indicates a failure. +// +// The provided 'str_array_len' must denote the length of 'str_array' +// and 'str_n_len[i]' must denote the length of the i-th string. +// +// The provided strings don't need to be null terminated and may contain +// embedded null characters. The amount of bytes copied into the 'tensor' is +// entirely determined by 'str_n_len[i]' and it is the caller's responsibility +// to set this value correctly to avoid undefined behavior. +// +// Also note that calling 'TfLiteOpaqueTensorWriteStrings' deallocates any +// previously stored data in the 'tensor'. +TfLiteStatus TfLiteOpaqueTensorWriteStrings(TfLiteOpaqueTensor* tensor, + const char* const* str_array, + int str_array_len, + const int* str_n_len); + +// Writes the string pointed to by the provided 'str' pointer of length 'len' +// into the provided 'tensor'. The string provided via 'str' is +// copied into the 'tensor'. Returns 'kTfLiteOk' in case of success. Any +// other return value indicates a failure. +// +// Note that calling 'TfLiteOpaqueTensorWriteString' deallocates any +// previously stored data in the 'tensor'. E.g. suppose 't' denotes a +// 'TfLiteOpaqueTensor*', then calling 'TfLiteOpaqueTensorWriteString(t, "AB", +// 2)' followed by a call to 'TfLiteOpaqueTensorWriteString(t, "CD", 2)' will +// lead to 't' containing 'CD', not 'ABCD'. +// +// 'TfLiteOpaqueTensorWriteString' is a convenience function for the use case +// of writing a single string to a tensor and its effects are identical to +// calling 'TfLiteOpaqueTensorWriteStrings' with an array of a single string. +TfLiteStatus TfLiteOpaqueTensorWriteString(TfLiteOpaqueTensor* tensor, + const char* str, int len); + // -------------------------------------------------------------------------- // Accessors for TfLiteOpaqueNode. diff --git a/tensorflow/lite/core/c/c_api_test.cc b/tensorflow/lite/core/c/c_api_test.cc index 152f0a66c4fa26..e17611860a76f9 100644 --- a/tensorflow/lite/core/c/c_api_test.cc +++ b/tensorflow/lite/core/c/c_api_test.cc @@ -37,6 +37,7 @@ limitations under the License. #include "tensorflow/lite/core/subgraph.h" #include "tensorflow/lite/delegates/delegate_test_util.h" #include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/string_util.h" #include "tensorflow/lite/testing/util.h" namespace { @@ -1648,6 +1649,196 @@ TEST(CApiSimple, OpaqueApiAccessors) { EXPECT_TRUE(delegate_kernel_invoked); } +TEST(CApiSimple, OpaqueApiAccessorsStrings) { + ::tflite::Interpreter interpreter; + interpreter.AddTensors(3); + std::vector dims = {1}; + TfLiteQuantizationParams quant{}; + interpreter.SetTensorParametersReadWrite(0, kTfLiteString, "a", dims, quant, + /*is_variable=*/false); + interpreter.SetTensorParametersReadWrite(1, kTfLiteString, "b", dims, quant, + /*is_variable=*/false); + interpreter.SetTensorParametersReadWrite(2, kTfLiteString, "c", dims, quant, + /*is_variable=*/false); + + interpreter.SetInputs({0, 1}); + interpreter.SetOutputs({2}); + const char* initial_data = ""; + tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates resolver; + TfLiteAddParams* builtin_data = + reinterpret_cast(malloc(sizeof(TfLiteAddParams))); + builtin_data->activation = kTfLiteActNone; + builtin_data->pot_scale_int16 = false; + const TfLiteRegistration* registration = + resolver.FindOp(::tflite::BuiltinOperator_ADD, 1); + interpreter.AddNodeWithParameters({0, 1}, {2}, initial_data, 0, builtin_data, + registration); + + TfLiteOpaqueDelegateBuilder opaque_delegate_builder{}; + opaque_delegate_builder.flags = kTfLiteDelegateFlagsAllowDynamicTensors; + bool delegate_kernel_invoked = false; + opaque_delegate_builder.data = &delegate_kernel_invoked; + opaque_delegate_builder.Prepare = [](TfLiteOpaqueContext* context, + TfLiteOpaqueDelegate* delegate, + void* data) -> TfLiteStatus { + TfLiteRegistrationExternal* registration = TfLiteRegistrationExternalCreate( + kTfLiteBuiltinDelegate, "my delegate", 123); + TfLiteRegistrationExternalSetInit( + registration, + [](TfLiteOpaqueContext* opaque_context, const char* buffer, + size_t length) -> void* { + const TfLiteOpaqueDelegateParams* params = + reinterpret_cast(buffer); + EXPECT_EQ(2, params->input_tensors->size); + TfLiteOpaqueTensor* opaque_input_tensor = + TfLiteOpaqueContextGetOpaqueTensor( + opaque_context, params->input_tensors->data[0]); + EXPECT_EQ(1, TfLiteOpaqueTensorNumDims(opaque_input_tensor)); + EXPECT_EQ(1, TfLiteOpaqueTensorDim(opaque_input_tensor, 0)); + EXPECT_EQ(kTfLiteDynamic, + TfLiteOpaqueTensorGetAllocationType(opaque_input_tensor)); + + bool* delegate_kernel_invoked = + static_cast(params->delegate_data); + *delegate_kernel_invoked = true; + return nullptr; + }); + + TfLiteRegistrationExternalSetPrepare( + registration, + [](TfLiteOpaqueContext* context, + TfLiteOpaqueNode* node) -> TfLiteStatus { return kTfLiteOk; }); + + TfLiteRegistrationExternalSetInvoke( + registration, + [](TfLiteOpaqueContext* context, + TfLiteOpaqueNode* node) -> TfLiteStatus { + const TfLiteOpaqueTensor* input0 = + TfLiteOpaqueNodeGetInput(context, node, 0); + + EXPECT_EQ(TfLiteOpaqueTensorGetStringCount(input0), 4); + const char* input0_string2 = nullptr; + int input0_string2_len = -1; + EXPECT_EQ(kTfLiteOk, + TfLiteOpaqueTensorGetString(input0, 2, &input0_string2, + &input0_string2_len)); + EXPECT_EQ(std::string(input0_string2, input0_string2_len), "F"); + EXPECT_EQ(1, input0_string2_len); + + const TfLiteOpaqueTensor* input1 = + TfLiteOpaqueNodeGetInput(context, node, 1); + EXPECT_EQ(TfLiteOpaqueTensorGetStringCount(input1), 1); + const char* input1_string0 = nullptr; + int input1_string0_len = -1; + EXPECT_EQ(kTfLiteOk, + TfLiteOpaqueTensorGetString(input1, 0, &input1_string0, + &input1_string0_len)); + EXPECT_EQ(std::string(input1_string0, input1_string0_len), "XYZ"); + EXPECT_EQ(3, input1_string0_len); + + TfLiteOpaqueTensor* opaque_output0 = + TfLiteOpaqueNodeGetOutput(context, node, 0); + + // + // First use 'TfLiteOpaqueTensorWriteString' to check that we can copy + // a string from an input tensor to an output tensor. + // + EXPECT_EQ(kTfLiteOk, + TfLiteOpaqueTensorWriteString( + opaque_output0, input0_string2, input0_string2_len)); + const char* output_str_from_opaque_tensor = nullptr; + int output_str_from_opaque_tensor_len = -1; + EXPECT_EQ(kTfLiteOk, + TfLiteOpaqueTensorGetString( + opaque_output0, 0, &output_str_from_opaque_tensor, + &output_str_from_opaque_tensor_len)); + EXPECT_EQ(std::string(output_str_from_opaque_tensor, + output_str_from_opaque_tensor_len), + "F"); + EXPECT_EQ(1, output_str_from_opaque_tensor_len); + + // + // Then perform the 'actual' ADD operation of adding the input tensor + // string to the output tensor. + // + std::vector str_array; + std::vector str_array_len; + for (int i = 0; i < TfLiteOpaqueTensorGetStringCount(input0); ++i) { + const char* input_string = nullptr; + int input_string_len = -1; + EXPECT_EQ(kTfLiteOk, + TfLiteOpaqueTensorGetString(input0, i, &input_string, + &input_string_len)); + str_array.push_back(input_string); + str_array_len.push_back(input_string_len); + } + str_array.push_back(input1_string0); + str_array_len.push_back(input1_string0_len); + + EXPECT_EQ(kTfLiteOk, TfLiteOpaqueTensorWriteStrings( + opaque_output0, str_array.data(), + str_array.size(), str_array_len.data())); + return kTfLiteOk; + }); + + TfLiteIntArray* execution_plan{}; + TfLiteOpaqueContextGetExecutionPlan(context, &execution_plan); + TfLiteOpaqueContextReplaceNodeSubsetsWithDelegateKernels( + context, registration, execution_plan, delegate); + return kTfLiteOk; + }; + + TfLiteDelegate my_delegate{}; + my_delegate.opaque_delegate_builder = &opaque_delegate_builder; + EXPECT_EQ(kTfLiteOk, interpreter.ModifyGraphWithDelegate(&my_delegate)); + EXPECT_TRUE(delegate_kernel_invoked); + EXPECT_EQ(kTfLiteOk, interpreter.AllocateTensors()); + + // + // Load input tensors with string data. + // + TfLiteTensor* t0 = interpreter.tensor(0); + tflite::DynamicBuffer buf0; + const char* raw_buf_with_embedded_null = "DDD\0EEE"; + const char* raw_buf_without_embedded_null = "12345678"; + std::vector t0_strings{ + "ABC", + std::string(raw_buf_with_embedded_null, raw_buf_with_embedded_null + 6), + "F", + std::string(raw_buf_without_embedded_null, + raw_buf_without_embedded_null + 4)}; + for (const std::string& s : t0_strings) { + ASSERT_EQ(buf0.AddString(s.data(), s.size()), kTfLiteOk); + } + buf0.WriteToTensorAsVector(t0); + + TfLiteTensor* t1 = interpreter.tensor(1); + char s1[] = "XYZ"; + tflite::DynamicBuffer buf1; + ASSERT_EQ(buf1.AddString(s1, 3), kTfLiteOk); + buf1.WriteToTensorAsVector(t1); + + // + // Invoke the interpreter, so that the input tensor strings get copied to the + // output tensor. + // + EXPECT_EQ(kTfLiteOk, interpreter.Invoke()); + + // + // Check that the output tensor stores the combination of the input strings. + // + const std::vector expected_strings{ + "ABC", + std::string(raw_buf_with_embedded_null, raw_buf_with_embedded_null + 6), + "F", "1234", "XYZ"}; + TfLiteTensor* t2 = interpreter.tensor(2); + EXPECT_EQ(tflite::GetStringCount(t2), expected_strings.size()); + for (int i = 0; i < tflite::GetStringCount(t2); ++i) { + tflite::StringRef str_ref = tflite::GetString(t2, i); + EXPECT_EQ(std::string(str_ref.str, str_ref.len), expected_strings[i]); + } +} + void AddNode( tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates* resolver, ::tflite::Interpreter* interpreter) { From fa791fcf536a1c1cf8fdab83d051e5624cd115e4 Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Mon, 17 Jul 2023 03:24:57 -0700 Subject: [PATCH 367/376] [XLA:GPU] Switch Triton GEMM to block pointers. PiperOrigin-RevId: 548636217 --- .../xla/service/gpu/ir_emitter_triton.cc | 631 ++++++++++-------- 1 file changed, 341 insertions(+), 290 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_triton.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_triton.cc index 6065cf3236dcb2..f9ddf728322bc0 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_triton.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_triton.cc @@ -92,6 +92,7 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/tsl/platform/errors.h" #include "tensorflow/tsl/platform/path.h" +#include "tensorflow/tsl/platform/statusor.h" #include "tensorflow/tsl/platform/tensor_float_32_utils.h" #include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h" #include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h" @@ -412,24 +413,16 @@ Value EmitElementwise(mlir::ImplicitLocOpBuilder& b, } } -Value EmitParameter(mlir::ImplicitLocOpBuilder& b, - const HloInstruction& parameter, mlir::triton::FuncOp fn, - Value load_offsets, Value load_mask) { - Value param = fn.getArgument(parameter.parameter_number()); - mlir::ArrayRef tile_shape = - load_offsets.dyn_cast().getType().getShape(); - if (load_mask != nullptr) { - Value zeros_like = CreateConst( - b, TritonType(b, parameter.shape().element_type()), 0, tile_shape); - return b.create( - AddPtr(b, Splat(b, param, tile_shape), load_offsets), load_mask, - zeros_like, mt::CacheModifier::NONE, mt::EvictionPolicy::NORMAL, - /*isVolatile=*/false); +Value EmitParameterLoad(mlir::ImplicitLocOpBuilder& b, Value tensor_pointer, + mlir::ArrayRef boundary_checks) { + std::optional padding; + if (!boundary_checks.empty()) { + padding = mt::PaddingOption::PAD_ZERO; } - return b.create( - AddPtr(b, Splat(b, param, tile_shape), load_offsets), - mt::CacheModifier::NONE, mt::EvictionPolicy::NORMAL, - /*isVolatile=*/false); + return b.create(tensor_pointer, boundary_checks, padding, + mt::CacheModifier::NONE, + mt::EvictionPolicy::NORMAL, + /*isVolatile=*/false); } Value EmitConstant(mlir::ImplicitLocOpBuilder& b, @@ -457,16 +450,16 @@ Value EmitBroadcast(mlir::ImplicitLocOpBuilder& b, return input; } -Value EmitScope(mlir::ImplicitLocOpBuilder& b, absl::string_view libdevice_path, - mlir::triton::FuncOp fn, - absl::Span instructions, - absl::flat_hash_map& values, - Value load_offsets, Value load_mask); +StatusOr EmitScope( + mlir::ImplicitLocOpBuilder& b, absl::string_view libdevice_path, + absl::Span instructions, + absl::flat_hash_map& values, + mlir::ArrayRef tile_shape, Value tile_mask); -Value EmitReduce(mlir::ImplicitLocOpBuilder& b, - const HloInstruction& hlo_reduce, - absl::string_view libdevice_path, mlir::triton::FuncOp fn, - Value input, Value tile_mask) { +StatusOr EmitReduce(mlir::ImplicitLocOpBuilder& b, + const HloInstruction& hlo_reduce, + absl::string_view libdevice_path, Value input, + Value tile_mask) { llvm::ArrayRef input_shape = input.cast().getType().getShape(); @@ -490,12 +483,14 @@ Value EmitReduce(mlir::ImplicitLocOpBuilder& b, // reduction is computed correctly, since it is the neutral value with regards // to the reducer. Value neutral = EmitConstant(b, *hlo_reduce.operand(1)); - Value masked_input = - b.create(tile_mask, input, Splat(b, neutral, input_shape)); + if (tile_mask) { + input = b.create(tile_mask, input, + Splat(b, neutral, input_shape)); + } // Triton actually only performs reductions on float32 inputs, and we must // thus upcast/downcast our input if its data type is different. - Value casted_input = Cast(b, masked_input, b.getF32Type()); + Value casted_input = Cast(b, input, b.getF32Type()); mt::ReduceOp reduction = b.create( SmallVector({casted_input}), (int)input_shape.size() - 1); @@ -525,8 +520,9 @@ Value EmitReduce(mlir::ImplicitLocOpBuilder& b, CHECK(!to_emit.empty()); b.setInsertionPointToStart(reducer); - Value result = - EmitScope(b, libdevice_path, fn, to_emit, region_values, {}, {}); + TF_ASSIGN_OR_RETURN(Value result, + EmitScope(b, libdevice_path, to_emit, region_values, + /*tile_shape=*/{}, /*tile_mask=*/{})); b.create(SmallVector({result})); b.setInsertionPointAfter(reduction); } @@ -537,24 +533,25 @@ Value EmitReduce(mlir::ImplicitLocOpBuilder& b, // Emit sequence of instructions using compatible tiling ordered producers // before consumers. -Value EmitScope(mlir::ImplicitLocOpBuilder& b, absl::string_view libdevice_path, - mlir::triton::FuncOp fn, - absl::Span instructions, - absl::flat_hash_map& values, - Value load_offsets, Value load_mask) { +StatusOr EmitScope( + mlir::ImplicitLocOpBuilder& b, absl::string_view libdevice_path, + absl::Span instructions, + absl::flat_hash_map& values, + mlir::ArrayRef tile_shape, Value tile_mask) { for (const HloInstruction* hlo : instructions) { Value result; if (hlo->opcode() == HloOpcode::kParameter) { - result = EmitParameter(b, *hlo, fn, load_offsets, load_mask); + // Parameter loads are handled outside EmitScope. + TF_RET_CHECK(values.contains(hlo)) << hlo->ToString(); + continue; } else if (hlo->opcode() == HloOpcode::kConstant) { result = EmitConstant(b, *hlo); } else if (hlo->opcode() == HloOpcode::kBroadcast) { - mlir::ArrayRef tile_shape = - load_offsets.dyn_cast().getType().getShape(); result = EmitBroadcast(b, *hlo, values[hlo->operand(0)], tile_shape); } else if (hlo->opcode() == HloOpcode::kReduce) { - result = EmitReduce(b, *hlo, libdevice_path, fn, values[hlo->operand(0)], - load_mask); + TF_ASSIGN_OR_RETURN( + result, EmitReduce(b, *hlo, libdevice_path, values[hlo->operand(0)], + tile_mask)); } else if (hlo->IsElementwise()) { std::vector operands; operands.reserve(hlo->operands().size()); @@ -563,14 +560,14 @@ Value EmitScope(mlir::ImplicitLocOpBuilder& b, absl::string_view libdevice_path, } result = EmitElementwise(b, libdevice_path, *hlo, operands); } else if (hlo->opcode() == HloOpcode::kTuple) { - CHECK(hlo->IsRoot()) << hlo->ToString(); + TF_RET_CHECK(hlo->IsRoot()) << hlo->ToString(); } else if (hlo->opcode() == HloOpcode::kBitcast || hlo->opcode() == HloOpcode::kReshape) { result = values[hlo->operand(0)]; } else { LOG(FATAL) << hlo->ToString(); } - CHECK(values.insert({hlo, result}).second) << hlo->ToString(); + TF_RET_CHECK(values.insert({hlo, result}).second) << hlo->ToString(); VLOG(8) << "Emitted " << hlo->ToString(); } return values[instructions.back()]; @@ -582,6 +579,7 @@ void CreateTritonPipeline(mlir::OpPassManager& pm, const int ccAsInt = cc.major * 10 + cc.minor; // Based on optimize_ttir() in // @triton//:python/triton/compiler/compiler.py + pm.addPass(mt::createRewriteTensorPointerPass()); pm.addPass(mlir::createInlinerPass()); pm.addPass(mt::createCombineOpsPass()); pm.addPass(mlir::createCanonicalizerPass()); @@ -713,11 +711,12 @@ StatusOr MatMulImpl( auto loc = mlir::NameLoc::get(builder.getStringAttr(dot_instr->name())); mlir::ImplicitLocOpBuilder b(loc, builder); Type i32_ty = b.getI32Type(); + Type i64_ty = b.getI64Type(); Type int_ty; if constexpr (std::is_same_v) { - int_ty = b.getI64Type(); + int_ty = i64_ty; } else { - int_ty = b.getI32Type(); + int_ty = i32_ty; } const DotDimensionNumbers& dims = dot_instr->dot_dimension_numbers(); const DotFusionAnalysis analysis(dot_instr->parent(), config.split_k()); @@ -746,12 +745,12 @@ StatusOr MatMulImpl( const bool have_batch = dims.lhs_batch_dimensions_size() - have_split_k; CHECK_EQ(dot_instr->operand(0)->shape().rank(), 2 + have_split_k + have_batch); - const int64_t lhs_noncontracting_dim_idx = + const int lhs_noncontracting_dim_idx = GetNonContractingDims(dot_instr->operand(0)->shape(), dims.lhs_batch_dimensions(), dims.lhs_contracting_dimensions()) .value()[0]; - const int64_t rhs_noncontracting_dim_idx = + const int rhs_noncontracting_dim_idx = GetNonContractingDims(dot_instr->operand(1)->shape(), dims.rhs_batch_dimensions(), dims.rhs_contracting_dimensions()) @@ -786,8 +785,6 @@ StatusOr MatMulImpl( bool lhs_nc_split = false; // Either batch size or upper part of the length of a split nc dimension. int batch_size = 1; - IndexT stride_lhs_m = 0; - IndexT stride_lhs_k = 0; IndexT stride_lhs_batch = 0; IndexT stride_rhs_batch = 0; if (!analysis.ScopeParameters(DotFusionAnalysis::Scope::LHS).empty()) { @@ -828,20 +825,12 @@ StatusOr MatMulImpl( dims.lhs_contracting_dimensions(0)) ->size(), 1); - stride_lhs_m = lhs_nc_iter_spec->at(0).stride; - stride_lhs_k = analysis - .IterSpec(DotFusionAnalysis::Scope::LHS, lhs_param0, - dims.lhs_contracting_dimensions(0)) - ->at(0) - .stride; // Just the fastest-varying part of it if the dimension is split. m = lhs_nc_iter_spec->at(0).count; } CHECK_GE(m, 1); - IndexT stride_rhs_k = 0; - IndexT stride_rhs_n = 0; if (!analysis.ScopeParameters(DotFusionAnalysis::Scope::RHS).empty()) { const HloInstruction* rhs_param0 = *analysis.ScopeParameters(DotFusionAnalysis::Scope::RHS).begin(); @@ -851,16 +840,6 @@ StatusOr MatMulImpl( rhs_noncontracting_dim_idx) ->size(), 1); - stride_rhs_k = analysis - .IterSpec(DotFusionAnalysis::Scope::RHS, rhs_param0, - dims.rhs_contracting_dimensions(0)) - ->at(0) - .stride; - stride_rhs_n = analysis - .IterSpec(DotFusionAnalysis::Scope::RHS, rhs_param0, - rhs_noncontracting_dim_idx) - ->at(0) - .stride; if (have_batch) { const int64_t rhs_batch_dim_idx = *(dims.rhs_batch_dimensions().cend() - 1); @@ -875,57 +854,11 @@ StatusOr MatMulImpl( constexpr int group_m = 8; - IndexT stride_out_m = - analysis.IterSpec(DotFusionAnalysis::Scope::OUTPUT, root, lhs_nc_out_idx) - ->at(0) - .stride; - const int64_t n = + const int n = analysis.IterSpec(DotFusionAnalysis::Scope::OUTPUT, root, rhs_nc_out_idx) ->at(0) .count; CHECK_GE(n, 1); - IndexT stride_out_n = - analysis.IterSpec(DotFusionAnalysis::Scope::OUTPUT, root, rhs_nc_out_idx) - ->at(0) - .stride; - IndexT stride_out_split_k = 0; - if (have_split_k) { - stride_out_split_k = - analysis - .IterSpec(DotFusionAnalysis::Scope::OUTPUT, root, split_k_out_idx) - ->at(0) - .stride; - CHECK_GE(stride_out_split_k, 1); - } - IndexT stride_out_batch = 0; - if (have_batch) { - stride_out_batch = - analysis - .IterSpec(DotFusionAnalysis::Scope::OUTPUT, root, batch_out_idx) - ->at(0) - .stride; - CHECK_GE(stride_out_batch, 1); - } - { - const TensorIterationSpec::DimIterationSpec* spec = analysis.IterSpec( - DotFusionAnalysis::Scope::OUTPUT, root, lhs_nc_out_idx); - if (spec->size() > 1) { - CHECK_EQ(spec->size(), 2); - // Support one specific kind of output transpose that splits the dimension - // originating from the split LHS non-contracting one. - CHECK(!have_batch); - CHECK(lhs_nc_split); - CHECK_EQ(spec->at(1).count, batch_size); - stride_out_batch = spec->at(1).stride; - } else if (lhs_nc_split) { - // Dimension of the output produced by the non-contracting LHS one - // is physically contiguous though the producing LHS one is split. - // Because the major part of the split is implemented using the batch - // logic stride_out_batch is populated here as the stride of the minor - // part times its size. - stride_out_batch = stride_out_m * m; - } - } const int block_m = config.block_m(); const int block_k = config.block_k(); @@ -992,125 +925,120 @@ StatusOr MatMulImpl( } return value; }; - auto convert_range = [&](Value value) -> Value { - if constexpr (std::is_same_v) { - auto type = mlir::RankedTensorType::get( - value.dyn_cast().getType().getShape(), int_ty); - return b.create(type, value); - } - return value; - }; auto pid_m = b.create(first_pid_m, b.create(pid_nc, group_size)); - auto pid_m_stride = + auto pid_m_offset = b.create(pid_m, CreateConst(b, i32_ty, block_m)); - // TODO(b/270351731): Consider regenerating range_m to reduce register - // pressure if we figure out how to make this optimization survive CSE. - auto range_m = - b.create(Splat(b, pid_m_stride, block_m), Range(b, block_m)); auto pid_n = b.create( b.create(pid_nc, CreateConst(b, i32_ty, width)), group_size); - auto pid_n_stride = + auto pid_n_offset = b.create(pid_n, CreateConst(b, i32_ty, block_n)); - auto range_n = - b.create(Splat(b, pid_n_stride, block_n), Range(b, block_n)); - - auto range_k = b.create( - Splat(b, b.create(pid_k, CreateConst(b, i32_ty, block_k)), - block_k), - Range(b, block_k)); - - SmallVector shape_m_1{block_m, 1}; - auto range_lhs_m = convert_range( - b.create(range_m, CreateConst(b, i32_ty, m, block_m))); - auto lhs_offsets_m = - b.create(b.create(range_lhs_m, 1), - CreateConst(b, int_ty, stride_lhs_m, shape_m_1)); - SmallVector shape_1_k{1, block_k}; - auto lhs_offsets_k = b.create( - b.create(convert_range(range_k), 0), - CreateConst(b, int_ty, stride_lhs_k, shape_1_k)); - SmallVector shape_m_k{block_m, block_k}; - auto lhs_offset_batch = b.create( - convert_scalar(pid_batch), CreateConst(b, int_ty, stride_lhs_batch)); - auto lhs_offsets_init = b.create( - Broadcast(b, lhs_offsets_m.getResult().template cast(), - shape_m_k), - Broadcast(b, lhs_offsets_k.getResult().template cast(), - shape_m_k)); - lhs_offsets_init = b.create( - lhs_offsets_init, Splat(b, lhs_offset_batch, shape_m_k)); - - SmallVector shape_k_1{block_k, 1}; - auto rhs_offsets_k = b.create( - b.create(convert_range(range_k), 1), - CreateConst(b, int_ty, stride_rhs_k, shape_k_1)); - SmallVector shape_1_n{1, block_n}; - auto range_rhs_n = convert_range( - b.create(range_n, CreateConst(b, i32_ty, n, block_n))); - auto rhs_offsets_n = - b.create(b.create(range_rhs_n, 0), - CreateConst(b, int_ty, stride_rhs_n, shape_1_n)); - SmallVector shape_k_n{block_k, block_n}; - auto rhs_offset_batch = b.create( - convert_scalar(pid_batch), CreateConst(b, int_ty, stride_rhs_batch)); - auto rhs_offsets_init = b.create( - Broadcast(b, rhs_offsets_k.getResult().template cast(), - shape_k_n), - Broadcast(b, rhs_offsets_n.getResult().template cast(), - shape_k_n)); - rhs_offsets_init = b.create( - rhs_offsets_init, Splat(b, rhs_offset_batch, shape_k_n)); - SmallVector shape_m_n{block_m, block_n}; - ma::ConstantOp accumulator_init = CreateConst(b, acc_ty, 0, shape_m_n); + + auto pid_k_offset = + b.create(pid_k, CreateConst(b, i32_ty, block_k)); + + ma::ConstantOp accumulator_init = + CreateConst(b, acc_ty, 0, {block_m, block_n}); + + // Numbers of dimensions of tensor pointers that need masking on loads or + // stores. + std::vector boundary_checks_lhs; + std::vector boundary_checks_rhs; + std::vector boundary_checks_out; + if (m % block_m != 0) { + boundary_checks_lhs.push_back(0); + boundary_checks_out.push_back(0); + } + if (k % (block_k * config.split_k()) != 0) { + boundary_checks_lhs.push_back(1); + boundary_checks_rhs.push_back(0); + } + if (n % block_n != 0) { + boundary_checks_rhs.push_back(1); + boundary_checks_out.push_back(1); + } + + // Parameters are passed to the loop in non-trivial order, this map helps + // finding them. + absl::flat_hash_map iter_args_to_parameters; auto body_builder = [&](mlir::OpBuilder&, mlir::Location, Value ki, - mlir::ValueRange iterArgs) { - Value lhs_offsets = iterArgs[0]; - Value rhs_offsets = iterArgs[1]; - Value accumulator = iterArgs[2]; - Value lhs_mask = nullptr; - Value rhs_mask = nullptr; + mlir::ValueRange iter_args) { + SmallVector iter_args_next; + iter_args_next.reserve(iter_args.size()); + absl::flat_hash_map values_lhs; + absl::flat_hash_map values_rhs; + // Load tiles of all parameters of LHS and RHS scopes and advance pointers. + for (int i = 0; i < iter_args.size() - 1; ++i) { + const bool is_lhs = + i < analysis.ScopeParameters(DotFusionAnalysis::Scope::LHS).size(); + const int increment_dim0 = block_k * config.split_k() * (is_lhs ? 0 : 1); + const int increment_dim1 = block_k * config.split_k() * (is_lhs ? 1 : 0); + absl::flat_hash_map& values = + is_lhs ? values_lhs : values_rhs; + CHECK(values + .insert({iter_args_to_parameters[i], + EmitParameterLoad(b, iter_args[i], + is_lhs ? boundary_checks_lhs + : boundary_checks_rhs)}) + .second); + iter_args_next.push_back(b.create( + iter_args[i].getType(), iter_args[i], + mlir::ValueRange{CreateConst(b, i32_ty, increment_dim0), + CreateConst(b, i32_ty, increment_dim1)})); + } + // TODO(b/269726484): Peel the loop instead of inserting a masked load in // every iteration, even the ones that do not need it. const bool need_masking = k % (block_k * config.split_k()) > 0; + Value lhs_mask; + Value rhs_mask; if (need_masking) { auto elements_in_tile = b.create(CreateConst(b, i32_ty, k), ki); - lhs_mask = - Broadcast(b, - b.create(ma::CmpIPredicate::slt, - b.create(range_k, 0), - Splat(b, elements_in_tile, shape_1_k)) - .getResult() - .template cast(), - shape_m_k); - rhs_mask = - Broadcast(b, - b.create(ma::CmpIPredicate::slt, - b.create(range_k, 1), - Splat(b, elements_in_tile, shape_k_1)) - .getResult() - .template cast(), - shape_k_n); + auto range_k = b.create( + Splat(b, b.create(pid_k, CreateConst(b, i32_ty, block_k)), + block_k), + Range(b, block_k)); + lhs_mask = Broadcast( + b, + b.create(ma::CmpIPredicate::slt, + b.create(range_k, 0), + Splat(b, elements_in_tile, {1, block_k})) + .getResult() + .template cast(), + {block_m, block_k}); + rhs_mask = Broadcast( + b, + b.create(ma::CmpIPredicate::slt, + b.create(range_k, 1), + Splat(b, elements_in_tile, {block_k, 1})) + .getResult() + .template cast(), + {block_k, block_n}); } - // For now use one shape for LHS inputs and one for RHS. - absl::flat_hash_map values_lhs; - Value dot_input_lhs = - EmitScope(b, libdevice_path, fn, + // Emit all operations of LHS and RHS scopes. + TF_ASSIGN_OR_RETURN( + Value dot_input_lhs, + EmitScope(b, libdevice_path, dot_instr->parent()->MakeInstructionPostOrderFrom( const_cast(*dot_instr->operand(0))), - values_lhs, lhs_offsets, lhs_mask); - absl::flat_hash_map values_rhs; - Value dot_input_rhs = - EmitScope(b, libdevice_path, fn, + values_lhs, {block_m, block_k}, lhs_mask)); + TF_ASSIGN_OR_RETURN( + Value dot_input_rhs, + EmitScope(b, libdevice_path, dot_instr->parent()->MakeInstructionPostOrderFrom( const_cast(*dot_instr->operand(1))), - values_rhs, rhs_offsets, rhs_mask); + values_rhs, {block_k, block_n}, rhs_mask)); + // Operation in the fusion before the dot can alter the elements of the + // tiles that were zero masked during loads. These have to be zeroed here + // again just before the dot so that they do not affect the output. + // Only the K dimension needs masking here because unnecessary elements in + // the other two get discarded by the masked store at the end. if (need_masking) { dot_input_lhs = b.create(lhs_mask, dot_input_lhs, ZerosLike(b, dot_input_lhs)); @@ -1118,22 +1046,90 @@ StatusOr MatMulImpl( ZerosLike(b, dot_input_rhs)); } - auto accumulator_next = b.create( - dot_input_lhs, dot_input_rhs, accumulator, + // Execute matrix multiplication of input tiles and pass the accumulator. + Value accumulator_next = b.create( + dot_input_lhs, dot_input_rhs, iter_args.back(), /*allowTF32=*/tsl::tensor_float_32_execution_enabled()); + iter_args_next.push_back(accumulator_next); - Value lhs_offsets_next = b.create( - lhs_offsets, - CreateConst(b, int_ty, block_k * config.split_k() * stride_lhs_k, - shape_m_k)); - Value rhs_offsets_next = b.create( - rhs_offsets, - CreateConst(b, int_ty, block_k * config.split_k() * stride_rhs_k, - shape_k_n)); - - b.create( - mlir::ValueRange{lhs_offsets_next, rhs_offsets_next, accumulator_next}); + b.create(iter_args_next); + return OkStatus(); }; + + // Pointers to parameters of LHS scope, then RHS, then the accumulator + // that change with every loop iteration and are passed between them. + // LHS and RHS can use same HLO computation parameters, but because they use + // different pointers they have to be stored separately for each scope. + SmallVector iter_args; + iter_args.reserve( + analysis.ScopeParameters(DotFusionAnalysis::Scope::LHS).size() + + analysis.ScopeParameters(DotFusionAnalysis::Scope::RHS).size() + 1); + + Value lhs_offset_batch = b.create( + convert_scalar(pid_batch), CreateConst(b, int_ty, stride_lhs_batch)); + for (const HloInstruction* parameter : + analysis.ScopeParameters(DotFusionAnalysis::Scope::LHS)) { + Value base = fn.getArgument(parameter->parameter_number()); + const int64_t stride_lhs_m = + analysis + .IterSpec(DotFusionAnalysis::Scope::LHS, parameter, + lhs_noncontracting_dim_idx) + ->at(0) + .stride; + const int64_t stride_lhs_k = + analysis + .IterSpec(DotFusionAnalysis::Scope::LHS, parameter, + dims.lhs_contracting_dimensions(0)) + ->at(0) + .stride; + Value ptrs = b.create( + /*base=*/AddPtr(b, base, lhs_offset_batch), + /*shape=*/ + mlir::ValueRange{CreateConst(b, i64_ty, m), CreateConst(b, i64_ty, k)}, + /*strides=*/ + mlir::ValueRange{CreateConst(b, i64_ty, stride_lhs_m), + CreateConst(b, i64_ty, stride_lhs_k)}, + /*offsets=*/mlir::ValueRange{pid_m_offset, pid_k_offset}, + /*tensorShape=*/std::vector{block_m, block_k}, + /*order=*/std::vector{1, 0}); + CHECK(iter_args_to_parameters.insert({iter_args.size(), parameter}).second) + << parameter->ToString(); + iter_args.push_back(ptrs); + } + + Value rhs_offset_batch = b.create( + convert_scalar(pid_batch), CreateConst(b, int_ty, stride_rhs_batch)); + for (const HloInstruction* parameter : + analysis.ScopeParameters(DotFusionAnalysis::Scope::RHS)) { + Value base = fn.getArgument(parameter->parameter_number()); + const IndexT stride_rhs_k = + analysis + .IterSpec(DotFusionAnalysis::Scope::RHS, parameter, + dims.rhs_contracting_dimensions(0)) + ->at(0) + .stride; + const IndexT stride_rhs_n = + analysis + .IterSpec(DotFusionAnalysis::Scope::RHS, parameter, + rhs_noncontracting_dim_idx) + ->at(0) + .stride; + Value ptrs = b.create( + /*base=*/AddPtr(b, base, rhs_offset_batch), + /*shape=*/ + mlir::ValueRange{CreateConst(b, i64_ty, k), CreateConst(b, i64_ty, n)}, + /*strides=*/ + mlir::ValueRange{CreateConst(b, i64_ty, stride_rhs_k), + CreateConst(b, i64_ty, stride_rhs_n)}, + /*offsets=*/mlir::ValueRange{pid_k_offset, pid_n_offset}, + /*tensorShape=*/std::vector{block_k, block_n}, + /*order=*/std::vector{1, 0}); + CHECK(iter_args_to_parameters.insert({iter_args.size(), parameter}).second) + << parameter->ToString(); + iter_args.push_back(ptrs); + } + + iter_args.push_back(accumulator_init); Value acc_final = b.create( /*lowerBound=*/b.create(0, /*width=*/32), @@ -1141,43 +1137,85 @@ StatusOr MatMulImpl( /*step=*/ b.create(block_k * config.split_k(), /*width=*/32), - /*iterArgs=*/ - mlir::ValueRange{lhs_offsets_init, rhs_offsets_init, - accumulator_init}, - body_builder) - .getResult(2); + /*iterArgs=*/iter_args, body_builder) + .getResult(iter_args.size() - 1); absl::flat_hash_map values_out; values_out[dot_instr] = Cast(b, acc_final, TritonType(b, dot_instr->shape().element_type())); - // Output tile offsets. - auto out_offset_batch = b.create( - convert_scalar(pid_batch), CreateConst(b, int_ty, stride_out_batch)); - auto out_offsets_m = b.create( - b.create(convert_range(range_m), 1), - CreateConst(b, int_ty, stride_out_m, shape_m_1)); - - auto out_offsets_n = b.create( - b.create(convert_range(range_n), 0), - CreateConst(b, int_ty, stride_out_n, shape_1_n)); - auto out_offsets = b.create(Splat(b, out_offset_batch, shape_m_1), - out_offsets_m); - out_offsets = b.create( - Broadcast(b, out_offsets.getResult().template cast(), - shape_m_n), - Broadcast(b, out_offsets_n.getResult().template cast(), - shape_m_n)); - - // Output tile mask: check that the indices are within [M, N]. - auto rm_cmp = b.create(ma::CmpIPredicate::slt, - b.create(range_m, 1), - CreateConst(b, i32_ty, m, shape_m_1)); - auto rn_cmp = b.create(ma::CmpIPredicate::slt, - b.create(range_n, 0), - CreateConst(b, i32_ty, n, shape_1_n)); - auto out_mask = b.create( - Broadcast(b, rm_cmp.getResult().template cast(), shape_m_n), - Broadcast(b, rn_cmp.getResult().template cast(), shape_m_n)); + // Generate tensor pointer for a parameter load or output store within the + // dot's output scope. + auto output_scope_tensor_pointer = [&](const HloInstruction* hlo, Value base, + bool add_split_k_offset) { + const IndexT stride_m = + analysis + .IterSpec(DotFusionAnalysis::Scope::OUTPUT, hlo, lhs_nc_out_idx) + ->at(0) + .stride; + { + IndexT stride_batch = 0; + if (have_batch) { + stride_batch = + analysis + .IterSpec(DotFusionAnalysis::Scope::OUTPUT, hlo, batch_out_idx) + ->at(0) + .stride; + CHECK_GE(stride_batch, 1); + } + { + const TensorIterationSpec::DimIterationSpec* spec = analysis.IterSpec( + DotFusionAnalysis::Scope::OUTPUT, hlo, lhs_nc_out_idx); + if (spec->size() > 1) { + CHECK_EQ(spec->size(), 2); + // Support one specific kind of output transpose that splits the + // dimension originating from the split LHS non-contracting one. + CHECK(!have_batch); + CHECK(lhs_nc_split); + CHECK_EQ(spec->at(1).count, batch_size); + stride_batch = spec->at(1).stride; + } else if (lhs_nc_split) { + // Dimension of the output produced by the non-contracting LHS one + // is physically contiguous though the producing LHS one is split. + // Because the major part of the split is implemented using the batch + // logic stride_out_batch is populated here as the stride of the minor + // part times its size. + stride_batch = stride_m * m; + } + } + Value offset_batch = b.create( + convert_scalar(pid_batch), CreateConst(b, int_ty, stride_batch)); + base = AddPtr(b, base, offset_batch); + } + if (add_split_k_offset) { + IndexT stride_split_k = 0; + if (have_split_k) { + stride_split_k = analysis + .IterSpec(DotFusionAnalysis::Scope::OUTPUT, hlo, + split_k_out_idx) + ->at(0) + .stride; + CHECK_GE(stride_split_k, 1); + } + Value offset_split_k = b.create( + convert_scalar(pid_k), CreateConst(b, int_ty, stride_split_k)); + base = AddPtr(b, base, offset_split_k); + } + const IndexT stride_n = + analysis + .IterSpec(DotFusionAnalysis::Scope::OUTPUT, hlo, rhs_nc_out_idx) + ->at(0) + .stride; + return b.create( + /*base=*/base, + /*shape=*/ + mlir::ValueRange{CreateConst(b, i64_ty, m), CreateConst(b, i64_ty, n)}, + /*strides=*/ + mlir::ValueRange{CreateConst(b, i64_ty, stride_m), + CreateConst(b, i64_ty, stride_n)}, + /*offsets=*/mlir::ValueRange{pid_m_offset, pid_n_offset}, + /*tensorShape=*/std::vector{block_m, block_n}, + /*order=*/std::vector{1, 0}); + }; // Collect all instructions of the dot's output scope. absl::flat_hash_set to_order; @@ -1207,23 +1245,34 @@ StatusOr MatMulImpl( to_emit.push_back(hlo); } } + // Emit the output scope. if (!to_emit.empty()) { - EmitScope(b, libdevice_path, fn, to_emit, values_out, out_offsets, - out_mask); + for (const HloInstruction* parameter : + analysis.ScopeParameters(DotFusionAnalysis::Scope::OUTPUT)) { + Value tensor_pointer = output_scope_tensor_pointer( + parameter, fn.getArgument(parameter->parameter_number()), + /*add_split_k_offset=*/false); + CHECK(values_out + .insert({parameter, EmitParameterLoad(b, tensor_pointer, + boundary_checks_out)}) + .second); + } + TF_RETURN_IF_ERROR(EmitScope(b, libdevice_path, to_emit, values_out, + {block_m, block_n}, /*tile_mask=*/{}) + .status()); } - auto out_offset_split_k = b.create( - convert_scalar(pid_k), CreateConst(b, int_ty, stride_out_split_k)); - out_offsets = b.create(out_offsets, - Splat(b, out_offset_split_k, shape_m_n)); + // Emit tensor store operations for all outputs. for (int i = 0; i < fn.getNumArguments() - dot_instr->parent()->num_parameters(); ++i) { - Value out = fn.getArgument(i + dot_instr->parent()->num_parameters()); const HloInstruction* producer = root->shape().IsTuple() ? root->operand(i) : root; - b.create(AddPtr(b, Splat(b, out, shape_m_n), out_offsets), - values_out[producer], out_mask, - mt::CacheModifier::NONE, mt::EvictionPolicy::NORMAL); + Value tensor_pointer = output_scope_tensor_pointer( + producer, fn.getArgument(i + dot_instr->parent()->num_parameters()), + /*add_split_k_offset=*/true); + b.create(tensor_pointer, values_out[producer], + boundary_checks_out, mt::CacheModifier::NONE, + mt::EvictionPolicy::NORMAL); } return launch_dimensions; } @@ -1301,35 +1350,37 @@ StatusOr SoftMax(mlir::OpBuilder builder, for (int minor_axis = 1; minor_axis < reduce_input_shape.rank(); ++minor_axis) num_rows *= reduce_input_shape.dimensions_minor(minor_axis); - // softmax_kernel(input_ptr, output_ptr, num_rows, row_len, block_row) { - // row_index = tl.program_id(0) - // row_stride = row_len - // offset = row_index * row_stride Value row_index = b.create(mt::ProgramIDDim::X); - Value row_stride = b.create(row_len, /*width=*/32); - Value offset = b.create(row_index, row_stride); - - // row_tile = tl.arange(0, block_row) + offset - Value splat_offsets = Splat(b, offset, block_row); - Value row_tile = b.create(splat_offsets, Range(b, block_row)); - - // mask = row_tile < row_stride - Value splat_row_stride = Splat(b, row_stride, block_row); - Value mask = b.create(ma::CmpIPredicate::slt, Range(b, block_row), - splat_row_stride); + Value row_stride = CreateConst(b, b.getI32Type(), row_len); absl::flat_hash_map values_out; - Value result = - EmitScope(b, libdevice_path, fn, computation->MakeInstructionPostOrder(), - values_out, row_tile, mask); - - // tl.store(output_ptr + row_tile, result, mask=mask) - Value splat_output_ptr = Splat(b, fn.getArgument(1), block_row); - Value store_ptrs = AddPtr(b, splat_output_ptr, row_tile); + auto make_tensor_pointer = [&](Value base) { + Value offset = b.create(row_index, row_stride); + return b.create( + /*base=*/AddPtr(b, base, offset), + /*shape=*/mlir::ValueRange{CreateConst(b, b.getI64Type(), row_len)}, + /*strides=*/mlir::ValueRange{CreateConst(b, b.getI64Type(), 1)}, + /*offsets=*/mlir::ValueRange{CreateConst(b, b.getI32Type(), 0)}, + /*tensorShape=*/std::vector{block_row}, + /*order=*/std::vector{0}); + }; - b.create(store_ptrs, result, mask, mt::CacheModifier::NONE, + std::vector boundary_checks; + if (block_row != row_len) { + boundary_checks.push_back(0); + } + values_out[computation->parameter_instruction(0)] = EmitParameterLoad( + b, make_tensor_pointer(fn.getArgument(0)), boundary_checks); + Value mask = b.create(ma::CmpIPredicate::slt, Range(b, block_row), + Splat(b, row_stride, block_row)); + TF_ASSIGN_OR_RETURN( + Value result, + EmitScope(b, libdevice_path, computation->MakeInstructionPostOrder(), + values_out, {block_row}, mask)); + + b.create(make_tensor_pointer(fn.getArgument(1)), result, + std::vector{0}, mt::CacheModifier::NONE, mt::EvictionPolicy::NORMAL); - // } const LaunchDimensions launch_dimensions{ {num_rows, 1, 1}, {config.num_warps() * WarpSize(), 1, 1}}; From 76d0af9cfd48198082af6f3cdec26520abc753a1 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Mon, 17 Jul 2023 03:36:48 -0700 Subject: [PATCH 368/376] [XLA:GPU] Prevent matching converts from/to bf16 in Triton Softmax rewriter if the CUDA compute capability is older than Ampere, since they result in unsupported PTX instructions. PiperOrigin-RevId: 548638356 --- .../xla/service/gpu/ir_emitter_triton_test.cc | 50 +++++++++++ .../service/gpu/softmax_rewriter_triton.cc | 86 +++++++++++-------- .../gpu/softmax_rewriter_triton_test.cc | 41 +++++++++ 3 files changed, 143 insertions(+), 34 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_triton_test.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_triton_test.cc index 51183398f4bef9..272b3520ae7af7 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_triton_test.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_triton_test.cc @@ -1970,6 +1970,13 @@ class TritonSoftmaxTest : public GpuCodegenTest { debug_options.set_xla_gpu_enable_triton_softmax_fusion(true); return debug_options; } + + se::CudaComputeCapability GetCudaComputeCapability() { + return backend() + .default_stream_executor() + ->GetDeviceDescription() + .cuda_compute_capability(); + } }; TEST_F(TritonSoftmaxTest, CanFuseAndEmitExactSoftmaxF32) { @@ -2328,6 +2335,49 @@ ENTRY main { EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec(1e-6, 1e-6))); } +TEST_F( + TritonSoftmaxTest, + CanFuseAndEmitConvertInvolvingBF16InputIntoSoftmaxDiamondCorrectlyForAmpereAndVoltaComputeCapability) { // NOLINT(whitespace/line_length) + const std::string hlo_text = R"( +HloModule softmax +max_computation { + arg_0 = f32[] parameter(0) + arg_1 = f32[] parameter(1) + ROOT maximum = f32[] maximum(arg_0, arg_1) +} +ENTRY main { + param_0 = bf16[127,125]{1,0} parameter(0) + param_0_f32 = f32[127,125]{1,0} convert(param_0) + constant_neg_inf = f32[] constant(-inf) + reduce = f32[127]{0} reduce(param_0_f32, constant_neg_inf), dimensions={1}, to_apply=max_computation + broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0} + ROOT subtract = f32[127,125]{1,0} subtract(param_0_f32, broadcast) +} +)"; + + if (GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::AMPERE)) { + MatchOptimizedHlo(hlo_text, R"( +; CHECK: ENTRY +; CHECK: %[[P0:.*]] = bf16[127,125]{1,0} parameter(0) +; CHECK: ROOT +; CHECK-SAME: fusion(%[[P0]]) +; CHECK-SAME: kind=kCustom +; CHECK-SAME: __triton_softmax +)"); + } else { + MatchOptimizedHlo(hlo_text, R"( +; CHECK: ENTRY +; CHECK: %[[P0:.*]] = bf16[127,125]{1,0} parameter(0) +; CHECK: %[[CONVERT:.*]] = f32[127,125]{1,0} convert(%[[P0]]) +; CHECK: ROOT +; CHECK-SAME: fusion(%[[CONVERT]]) +; CHECK-SAME: kind=kCustom +; CHECK-SAME: __triton_softmax +)"); + } + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec(1e-6, 1e-6))); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/softmax_rewriter_triton.cc b/tensorflow/compiler/xla/service/gpu/softmax_rewriter_triton.cc index 0c6bc5c8997182..69c77cb250c60c 100644 --- a/tensorflow/compiler/xla/service/gpu/softmax_rewriter_triton.cc +++ b/tensorflow/compiler/xla/service/gpu/softmax_rewriter_triton.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_types.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" @@ -42,16 +43,23 @@ bool HasDefaultLayout(const Shape& shape) { LayoutUtil::IsMonotonicWithDim0Major(shape.layout()); } -bool IsTritonSupportedInstruction(const HloInstruction* instr) { +bool IsTritonSupportedInstruction(const HloInstruction* instr, + const GpuVersion& gpu_version) { // TODO(bchetioui): expand with non-trivial instructions. if (instr->IsElementwise()) { + if (instr->opcode() == HloOpcode::kConvert && + (instr->operand(0)->shape().element_type() == BF16 || + instr->shape().element_type() == BF16) && + !std::get(gpu_version) + .IsAtLeast(stream_executor::CudaComputeCapability::AMPERE)) { + return false; + } return IsTritonSupportedElementwise(instr->opcode(), instr->shape().element_type()); } switch (instr->opcode()) { case HloOpcode::kBitcast: - case HloOpcode::kConvert: case HloOpcode::kParameter: return true; default: @@ -64,9 +72,10 @@ bool IsTritonSupportedInstruction(const HloInstruction* instr) { // set to it. The definition of "trivial" operations is as given in // 'IsTriviallyFusible'. bool TrivialEdge(HloInstruction** producer, HloInstruction* consumer, - HloOpcode opcode); + HloOpcode opcode, const GpuVersion& gpu_version); -bool BitcastIsTilingNoop(HloInstruction* bitcast) { +bool BitcastIsTilingNoop(HloInstruction* bitcast, + const GpuVersion& gpu_version) { CHECK_EQ(bitcast->opcode(), HloOpcode::kBitcast); if (bitcast->shape().rank() == 0) { @@ -88,7 +97,8 @@ bool BitcastIsTilingNoop(HloInstruction* bitcast) { }; HloInstruction* reduce = nullptr; - TrivialEdge(&reduce, bitcast->mutable_operand(0), HloOpcode::kReduce); + TrivialEdge(&reduce, bitcast->mutable_operand(0), HloOpcode::kReduce, + gpu_version); return (HasDefaultLayout(bitcast->shape()) && HasDefaultLayout(bitcast->operand(0)->shape()) && @@ -96,7 +106,8 @@ bool BitcastIsTilingNoop(HloInstruction* bitcast) { last_dimension(bitcast->operand(0)) == last_dimension(bitcast))); } -bool IsTriviallyFusible(HloInstruction* instr, int num_allowed_users = 1) { +bool IsTriviallyFusible(HloInstruction* instr, const GpuVersion& gpu_version, + int num_allowed_users = 1) { // Checks whether an op is trivially fusible. An op is said to be trivially // fusible if it does not increase the amount of memory read/written by the // resulting fusion, is compatible with any chosen tiling, and can be @@ -107,21 +118,22 @@ bool IsTriviallyFusible(HloInstruction* instr, int num_allowed_users = 1) { return false; } - if (instr->opcode() == HloOpcode::kBitcast && BitcastIsTilingNoop(instr)) { + if (instr->opcode() == HloOpcode::kBitcast && + BitcastIsTilingNoop(instr, gpu_version)) { return true; } if (instr->IsElementwise() && instr->operand_count() == 1) { - return IsTritonSupportedInstruction(instr); + return IsTritonSupportedInstruction(instr, gpu_version); } return false; } bool TrivialEdge(HloInstruction** producer, HloInstruction* consumer, - HloOpcode opcode) { + HloOpcode opcode, const GpuVersion& gpu_version) { while (consumer->opcode() != opcode) { - if (IsTriviallyFusible(consumer)) { + if (IsTriviallyFusible(consumer, gpu_version)) { consumer = consumer->mutable_operand(0); } else { return false; @@ -133,18 +145,20 @@ bool TrivialEdge(HloInstruction** producer, HloInstruction* consumer, } bool IsTriviallyConnectedProducerOf(HloInstruction* producer, - HloInstruction* consumer) { + HloInstruction* consumer, + const GpuVersion& gpu_version) { if (producer == consumer) { return true; } HloInstruction* found_producer = consumer; - while (TrivialEdge(&found_producer, consumer, producer->opcode())) { + while ( + TrivialEdge(&found_producer, consumer, producer->opcode(), gpu_version)) { if (found_producer == producer) { return true; } - if (!IsTriviallyFusible(found_producer)) { + if (!IsTriviallyFusible(found_producer, gpu_version)) { return false; } @@ -158,9 +172,10 @@ inline bool HasOneUse(const HloInstruction* instr) { return instr->user_count() == 1; } -bool IsTritonSupportedComputation(const HloComputation* computation) { +bool IsTritonSupportedComputation(const HloComputation* computation, + const GpuVersion& gpu_version) { for (const HloInstruction* instr : computation->instructions()) { - if (!IsTritonSupportedInstruction(instr)) { + if (!IsTritonSupportedInstruction(instr, gpu_version)) { return false; } } @@ -168,7 +183,7 @@ bool IsTritonSupportedComputation(const HloComputation* computation) { } std::optional MatchesTritonCompatibleClosedReductionDiamond( - HloInstruction* instr) { + HloInstruction* instr, const GpuVersion& gpu_version) { // Return the producer of the following pattern: // // producer @@ -188,7 +203,8 @@ std::optional MatchesTritonCompatibleClosedReductionDiamond( // array. std::optional match_failure = std::nullopt; - if (!instr->IsElementwiseBinary() || !IsTritonSupportedInstruction(instr)) { + if (!instr->IsElementwiseBinary() || + !IsTritonSupportedInstruction(instr, gpu_version)) { return match_failure; } @@ -197,13 +213,13 @@ std::optional MatchesTritonCompatibleClosedReductionDiamond( HloInstruction* reduce; if (!(TrivialEdge(&broadcast, instr->mutable_operand(1), - HloOpcode::kBroadcast) && - TrivialEdge(&reduce, broadcast->mutable_operand(0), - HloOpcode::kReduce) && + HloOpcode::kBroadcast, gpu_version) && + TrivialEdge(&reduce, broadcast->mutable_operand(0), HloOpcode::kReduce, + gpu_version) && HasDefaultLayout(broadcast->shape()) && HasDefaultLayout(reduce->shape()) && reduce->operand_count() == 2 && reduce->operand(1)->opcode() == HloOpcode::kConstant && - IsTritonSupportedComputation(reduce->to_apply()))) { + IsTritonSupportedComputation(reduce->to_apply(), gpu_version))) { return match_failure; } @@ -226,12 +242,13 @@ std::optional MatchesTritonCompatibleClosedReductionDiamond( return match_failure; } - while (IsTriviallyFusible(producer)) { + while (IsTriviallyFusible(producer, gpu_version)) { producer = producer->mutable_operand(0); } if (!HasDefaultLayout(producer->shape()) || - !IsTriviallyConnectedProducerOf(producer, instr->mutable_operand(0)) || + !IsTriviallyConnectedProducerOf(producer, instr->mutable_operand(0), + gpu_version) || !(producer == instr->operand(0) || instr->operand(0)->user_count() == 1)) { return match_failure; @@ -247,10 +264,11 @@ std::optional MatchesTritonCompatibleClosedReductionDiamond( // that instruction is used more than once, and/or is not trivially // fusible. HloInstruction* FindFirstNonFusibleDiamondProducer( - HloInstruction* diamond_producer) { - if (IsTriviallyFusible(diamond_producer, /*num_allowed_users=*/2)) { + HloInstruction* diamond_producer, const GpuVersion& gpu_version) { + if (IsTriviallyFusible(diamond_producer, gpu_version, + /*num_allowed_users=*/2)) { diamond_producer = diamond_producer->mutable_operand(0); - while (IsTriviallyFusible(diamond_producer)) { + while (IsTriviallyFusible(diamond_producer, gpu_version)) { diamond_producer = diamond_producer->mutable_operand(0); } } @@ -340,8 +358,8 @@ StatusOr SoftmaxRewriterTriton::Run( continue; } - if (auto producer = - MatchesTritonCompatibleClosedReductionDiamond(instr)) { + if (auto producer = MatchesTritonCompatibleClosedReductionDiamond( + instr, gpu_version_)) { matched_diamonds.push_back(DiamondDescriptor{instr, producer.value()}); } } @@ -364,9 +382,9 @@ StatusOr SoftmaxRewriterTriton::Run( return instr->operand(0)->shape().dimensions(operand_rank - 1); }; - auto last_trivially_fusible_user = [](HloInstruction* instr) { + auto last_trivially_fusible_user = [&](HloInstruction* instr) { while (HasOneUse(instr) && !instr->IsRoot() && - IsTriviallyFusible(instr->users().front())) { + IsTriviallyFusible(instr->users().front(), gpu_version_)) { instr = instr->users().front(); } @@ -375,7 +393,7 @@ StatusOr SoftmaxRewriterTriton::Run( // restriction. if (HasOneUse(instr) && !instr->IsRoot() && IsTriviallyFusible( - instr->users().front(), + instr->users().front(), gpu_version_, /*num_allowed_users=*/instr->users().front()->user_count())) { instr = instr->users().front(); } @@ -399,8 +417,8 @@ StatusOr SoftmaxRewriterTriton::Run( // Crucially, this approach relies on a diamond root never being considered a // trivially fusible operation. std::vector diamond_chains; - HloInstruction* current_fusion_producer = - FindFirstNonFusibleDiamondProducer(matched_diamonds.front().producer); + HloInstruction* current_fusion_producer = FindFirstNonFusibleDiamondProducer( + matched_diamonds.front().producer, gpu_version_); int current_reduce_dimension_size = reduction_dimension_size_from_diamond_root(matched_diamonds.front().root); @@ -411,7 +429,7 @@ StatusOr SoftmaxRewriterTriton::Run( matched_diamonds[diamond_idx - 1].root; HloInstruction* first_non_fusible_diamond_producer = - FindFirstNonFusibleDiamondProducer(diamond_producer); + FindFirstNonFusibleDiamondProducer(diamond_producer, gpu_version_); int diamond_reduce_dimension_size = reduction_dimension_size_from_diamond_root(diamond_root); diff --git a/tensorflow/compiler/xla/service/gpu/softmax_rewriter_triton_test.cc b/tensorflow/compiler/xla/service/gpu/softmax_rewriter_triton_test.cc index ab4022d0bd038d..4c55e6908734ae 100644 --- a/tensorflow/compiler/xla/service/gpu/softmax_rewriter_triton_test.cc +++ b/tensorflow/compiler/xla/service/gpu/softmax_rewriter_triton_test.cc @@ -810,6 +810,47 @@ ENTRY main { EXPECT_FALSE(fusion_rewriter.Run(module.get()).value()); } +TEST_F( + SoftmaxRewriterTritonTest, + CanOnlyFuseConvertInvolvingBF16InputIntoSoftmaxDiamondWithAtLeastAmpereComputeCapability) { // NOLINT(whitespace/line_length) + const std::string hlo_string = R"( +HloModule softmax +max_computation { + arg_0 = f32[] parameter(0) + arg_1 = f32[] parameter(1) + ROOT maximum = f32[] maximum(arg_0, arg_1) +} +ENTRY main { + param_0 = bf16[127,125]{1,0} parameter(0) + param_0_f32 = f32[127,125]{1,0} convert(param_0) + constant_neg_inf = f32[] constant(-inf) + reduce = f32[127]{0} reduce(param_0_f32, constant_neg_inf), dimensions={1}, to_apply=max_computation + broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0} + ROOT subtract = f32[127,125]{1,0} subtract(param_0_f32, broadcast) +} +)"; + auto ampere_module = ParseAndReturnVerifiedModule(hlo_string).value(); + auto volta_module = ampere_module->Clone(); + + // Ampere + SoftmaxRewriterTriton fusion_rewriter_ampere( + se::CudaComputeCapability{se::CudaComputeCapability::AMPERE, 0}); + EXPECT_TRUE(fusion_rewriter_ampere.Run(ampere_module.get()).value()); + EXPECT_TRUE(verifier().Run(ampere_module.get()).status().ok()); + VLOG(2) << ampere_module->ToString(); + EXPECT_THAT(ampere_module->entry_computation()->root_instruction(), + GmockMatch(m::Fusion(m::Parameter()))); + + // Volta (pre-Ampere) + SoftmaxRewriterTriton fusion_rewriter_volta( + se::CudaComputeCapability{se::CudaComputeCapability::VOLTA, 0}); + EXPECT_TRUE(fusion_rewriter_volta.Run(volta_module.get()).value()); + EXPECT_TRUE(verifier().Run(volta_module.get()).status().ok()); + VLOG(2) << volta_module->ToString(); + EXPECT_THAT(volta_module->entry_computation()->root_instruction(), + GmockMatch(m::Fusion(m::Convert(m::Parameter())))); +} + } // anonymous namespace } // namespace gpu } // namespace xla From e1ad3b74ad44b883c7b3fdc3a19adcea1d28bfbc Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Mon, 17 Jul 2023 04:12:08 -0700 Subject: [PATCH 369/376] [XLA:GPU] Handle edge case in Triton Softmax rewriter where bitcast is an effective scalar. This short-circuit avoids crashing within last_dimension when attempting to match and either the operand or the result of the bitcast has a shape with rank 0. PiperOrigin-RevId: 548645429 --- tensorflow/compiler/xla/service/gpu/softmax_rewriter_triton.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/compiler/xla/service/gpu/softmax_rewriter_triton.cc b/tensorflow/compiler/xla/service/gpu/softmax_rewriter_triton.cc index 69c77cb250c60c..3b686af91a2631 100644 --- a/tensorflow/compiler/xla/service/gpu/softmax_rewriter_triton.cc +++ b/tensorflow/compiler/xla/service/gpu/softmax_rewriter_triton.cc @@ -78,7 +78,7 @@ bool BitcastIsTilingNoop(HloInstruction* bitcast, const GpuVersion& gpu_version) { CHECK_EQ(bitcast->opcode(), HloOpcode::kBitcast); - if (bitcast->shape().rank() == 0) { + if (ShapeUtil::IsEffectiveScalar(bitcast->shape())) { return true; } From 8c1827c1a7a3048599d3606b9f4bee8c8a1c808b Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Mon, 17 Jul 2023 05:24:41 -0700 Subject: [PATCH 370/376] [XLA:GPU] Flip default for --xla_gpu_enable_triton_softmax_fusion flag. PiperOrigin-RevId: 548658687 --- tensorflow/compiler/xla/debug_options_flags.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/compiler/xla/debug_options_flags.cc b/tensorflow/compiler/xla/debug_options_flags.cc index facc9f0cc482f0..5ae440522cd384 100644 --- a/tensorflow/compiler/xla/debug_options_flags.cc +++ b/tensorflow/compiler/xla/debug_options_flags.cc @@ -160,7 +160,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_enable_triton_gemm(true); opts.set_xla_gpu_enable_cudnn_int8x32_convolution_reordering(true); opts.set_xla_gpu_triton_gemm_any(false); - opts.set_xla_gpu_enable_triton_softmax_fusion(false); + opts.set_xla_gpu_enable_triton_softmax_fusion(true); opts.set_xla_gpu_triton_fusion_level(1); // Moving reduce-scatter out of while loops can increase memory footprint, so From 90513cc320cb098a43edd266d062c6dd57e3d688 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 17 Jul 2023 05:58:16 -0700 Subject: [PATCH 371/376] [XLA:GPU] Only use nvlink for linking if it is at least as new as the ptxas we are using. This fixes a failure in the case that the user installs a new ptxas from the CUDA pip packages, but has an older nvlink installed system-wide that cannot understand the output of ptxas. Fixes https://github.com/google/jax/issues/16586 PiperOrigin-RevId: 548664235 --- tensorflow/compiler/xla/service/gpu/BUILD | 1 + .../xla/service/gpu/nvptx_compiler.cc | 100 ++++++++++++------ .../compiler/xla/service/gpu/nvptx_compiler.h | 20 ++-- 3 files changed, 77 insertions(+), 44 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index c3f4bf4ddaeef4..6816bff32b3f4d 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -2579,6 +2579,7 @@ cc_library( ":triangular_solve_rewriter", ":triton_autotuner", "@com_google_absl//absl/base", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/strings:str_format", "@llvm-project//llvm:IRReader", diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index fda453bbe47e89..2bccdceadb15c9 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/nvptx_compiler.h" +#include #include #include #include @@ -623,7 +624,8 @@ std::vector NVPTXCompiler::CompileGpuAsmOrGetCachedResult( return cache_value->cubin_data; } -static bool UseNvlink(const std::string& preferred_cuda_dir) { +static std::optional> GetNvLinkVersion( + const std::string& preferred_cuda_dir) { const bool use_nvlink_by_default = #ifdef TF_DISABLE_NVLINK_BY_DEFAULT false; @@ -636,13 +638,66 @@ static bool UseNvlink(const std::string& preferred_cuda_dir) { use_nvlink_by_default, &use_nvlink)); if (!use_nvlink) { - return false; + return std::nullopt; } // Make sure nvlink exists and is executable. const std::string bin_path = se::FindCudaExecutable("nvlink", preferred_cuda_dir); - return se::GetToolVersion(bin_path).ok(); + auto version = se::GetToolVersion(bin_path); + if (!version.ok()) { + return std::nullopt; + } + return *version; +} + +StatusOr NVPTXCompiler::ChooseLinkingMethod( + const std::string& preferred_cuda_dir) { + { + absl::MutexLock lock(&mutex_); + auto it = linking_methods_.find(preferred_cuda_dir); + if (it != linking_methods_.end()) { + return it->second; + } + } + + LinkingMethod linking_method = LinkingMethod::kNone; + TF_ASSIGN_OR_RETURN(auto ptxas_version_tuple, + se::GetAsmCompilerVersion(preferred_cuda_dir)); + + static const std::optional> nvlink_version = + GetNvLinkVersion(preferred_cuda_dir); + if (nvlink_version && *nvlink_version >= ptxas_version_tuple) { + linking_method = LinkingMethod::kNvLink; + } else { + int ptxas_version = std::get<0>(ptxas_version_tuple) * 1000 + + std::get<1>(ptxas_version_tuple) * 10; + int driver_version; + if (!se::gpu::GpuDriver::GetDriverVersion(&driver_version)) { + return FailedPrecondition("Unable to get CUDA driver version"); + } + bool ok = driver_version >= ptxas_version; + if (!ok) { + LOG_FIRST_N(WARNING, 1) + << "The NVIDIA driver's CUDA version is " + << absl::StrFormat("%d.%d", driver_version / 1000, + (driver_version % 1000) / 10) + << " which is older than the ptxas CUDA version " + << absl::StrFormat("(%d.%d.%d)", std::get<0>(ptxas_version_tuple), + std::get<1>(ptxas_version_tuple), + std::get<2>(ptxas_version_tuple)) + << ". Because the driver is older than the ptxas version, XLA is " + "disabling parallel compilation, which may slow down compilation. " + "You should update your NVIDIA driver or use the NVIDIA-provided " + "CUDA forward compatibility packages."; + } + linking_method = LinkingMethod::kDriver; + } + { + absl::MutexLock lock(&mutex_); + linking_methods_[preferred_cuda_dir] = linking_method; + } + return linking_method; } StatusOr NVPTXCompiler::CanUseLinkModules( @@ -651,37 +706,9 @@ StatusOr NVPTXCompiler::CanUseLinkModules( // robust if we simply tried to link something the first time we compile. auto ptxas_config = PtxOptsFromDebugOptions(hlo_module_config.debug_options()); - - static const bool use_nvlink = UseNvlink(ptxas_config.preferred_cuda_dir); - if (use_nvlink) { - return true; - } - - TF_ASSIGN_OR_RETURN( - auto ptxas_version_tuple, - se::GetAsmCompilerVersion(ptxas_config.preferred_cuda_dir)); - int ptxas_version = std::get<0>(ptxas_version_tuple) * 1000 + - std::get<1>(ptxas_version_tuple) * 10; - int driver_version; - if (!se::gpu::GpuDriver::GetDriverVersion(&driver_version)) { - return FailedPrecondition("Unable to get CUDA driver version"); - } - bool ok = driver_version >= ptxas_version; - if (!ok) { - LOG_FIRST_N(WARNING, 1) - << "The NVIDIA driver's CUDA version is " - << absl::StrFormat("%d.%d", driver_version / 1000, - (driver_version % 1000) / 10) - << " which is older than the ptxas CUDA version " - << absl::StrFormat("(%d.%d.%d)", std::get<0>(ptxas_version_tuple), - std::get<1>(ptxas_version_tuple), - std::get<2>(ptxas_version_tuple)) - << ". Because the driver is older than the ptxas version, XLA is " - "disabling parallel compilation, which may slow down compilation. " - "You should update your NVIDIA driver or use the NVIDIA-provided " - "CUDA forward compatibility packages."; - } - return ok; + TF_ASSIGN_OR_RETURN(LinkingMethod linking_method, + ChooseLinkingMethod(ptxas_config.preferred_cuda_dir)); + return linking_method != LinkingMethod::kNone; } StatusOr> NVPTXCompiler::LinkModules( @@ -696,7 +723,10 @@ StatusOr> NVPTXCompiler::LinkModules( } auto context = static_cast( stream_exec->implementation()->GpuContextHack()); - if (UseNvlink(ptxas_config.preferred_cuda_dir)) { + + TF_ASSIGN_OR_RETURN(LinkingMethod linking_method, + ChooseLinkingMethod(ptxas_config.preferred_cuda_dir)); + if (linking_method == LinkingMethod::kNvLink) { return LinkUsingNvlink(debug_options.xla_gpu_cuda_data_dir(), context, images); } diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h index 3b8906bbe46d41..500037d083ec35 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/container/node_hash_map.h" #include "tensorflow/compiler/xla/service/gpu/gpu_compiler.h" #include "tensorflow/compiler/xla/statusor.h" @@ -85,15 +86,16 @@ class NVPTXCompiler : public GpuCompiler { absl::Mutex mutex_; - // When compiling an HLO module, we need to find a path to the nvvm libdevice - // files. We search in the module's config.debug_options().cuda_data_dir() - // and in tensorflow::LibdeviceRoot(), the latter of which is a constant. - // - // We cache the cuda_data_dir() and the result of our search, so that if the - // next module we have to compile has the same cuda_data_dir(), we can skip - // the search. - std::string cached_cuda_data_dir_ ABSL_GUARDED_BY(mutex_); - std::string cached_libdevice_dir_ ABSL_GUARDED_BY(mutex_); + enum class LinkingMethod { + kNone, + kNvLink, + kDriver, + }; + absl::flat_hash_map linking_methods_ + ABSL_GUARDED_BY(mutex_); + + StatusOr ChooseLinkingMethod( + const std::string& preferred_cuda_dir); // Tries to compile the given ptx string to cubin. Returns a vector with the // compiled cubin. If compilation was unsuccessful, returns an empty vector. From 26ad97df1b2410e128b2b421d39f7f75b001d778 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Mon, 17 Jul 2023 06:57:43 -0700 Subject: [PATCH 372/376] Compute MakeEmbeddedComputationsList() iteratively. This avoids that we run out of stack space if the number of stacked computations is huge. PiperOrigin-RevId: 548678213 --- .../compiler/xla/hlo/ir/hlo_computation.cc | 90 +++++++++++-------- .../compiler/xla/hlo/ir/hlo_computation.h | 7 ++ 2 files changed, 62 insertions(+), 35 deletions(-) diff --git a/tensorflow/compiler/xla/hlo/ir/hlo_computation.cc b/tensorflow/compiler/xla/hlo/ir/hlo_computation.cc index 4c096e58b5ed1a..7fd242ca42c32f 100644 --- a/tensorflow/compiler/xla/hlo/ir/hlo_computation.cc +++ b/tensorflow/compiler/xla/hlo/ir/hlo_computation.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -42,6 +43,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/tsl/platform/errors.h" +#include "tensorflow/tsl/platform/logging.h" #include "tensorflow/tsl/platform/status.h" namespace xla { @@ -375,6 +377,22 @@ Status HloComputation::RemoveInstructionImpl(HloInstruction* instruction, return OkStatus(); } +HloInstruction* HloComputation::NextInstruction(HloInstruction* current) { + InstructionList::iterator instructions_it; + if (current == nullptr) { + instructions_it = instructions_.begin(); + } else { + auto it = instruction_iterators_.find(current); + CHECK(it != instruction_iterators_.end()); + instructions_it = it->second; + ++instructions_it; + } + if (instructions_it == instructions_.end()) { + return nullptr; + } + return instructions_it->get(); +} + void HloComputation::set_root_instruction(HloInstruction* new_root_instruction, bool accept_different_shape) { // The shape of the root (ignoring layout) is an invariant of the computation @@ -407,25 +425,6 @@ void HloComputation::set_root_instruction(HloInstruction* new_root_instruction, root_instruction_ = new_root_instruction; } -namespace { - -// Helper which builds a post order of the HLO call graph. -void ComputeComputationPostOrder(HloComputation* computation, - absl::flat_hash_set* visited, - std::vector* post_order) { - if (visited->insert(computation).second) { - for (auto* instruction : computation->instructions()) { - for (HloComputation* called_computation : - instruction->called_computations()) { - ComputeComputationPostOrder(called_computation, visited, post_order); - } - } - post_order->push_back(computation); - } -} - -} // namespace - void HloComputation::ComputeInstructionPostOrder( HloInstruction* root, const ChannelDependencies& channel_dependencies, absl::flat_hash_map& visited, @@ -583,21 +582,43 @@ std::vector HloComputation::MakeEmbeddedComputationsList() const { absl::flat_hash_set visited; std::vector post_order; - - // To avoid special handling of this computation, cast away const of - // 'this'. 'this' is immediately removed from the post order after - // construction. - // - // TODO(b/78350259): This violates const-correctness, since while the original - // computation is not returned, we still retrieve non-const computations from - // a const one. Consider also avoiding const for HloComputation, or review XLA - // for const-correctness of non-HloInstruction* types like this. - ComputeComputationPostOrder(const_cast(this), &visited, - &post_order); - - // We don't want to include this computation in the post order. - CHECK_EQ(this, post_order.back()); - post_order.pop_back(); + // The first element of the pair is the currently processed computation, the + // second is the instruction within the computation that is currently being + // processed. 'nullptr' for the instruction indicates that no instruction has + // been processed so far. + std::stack> st; + + // We cannot directly push (this, nullptr) to the stack, as the stack should + // contain only mutable computations. Also, we don't want to include the + // computation itself in the list of embedded computations. + for (auto* instruction : instructions()) { + auto process_called_computations = + [&](std::vector called_computations) { + // Put the called computations in reverse order onto the stack. + // Otherwise we don't match the recursive enumeration of + // computations, which processes the first called computation first. + absl::c_reverse(called_computations); + for (HloComputation* called_computation : called_computations) { + if (visited.insert(called_computation).second) { + st.emplace(called_computation, nullptr); + } + } + }; + process_called_computations(instruction->called_computations()); + while (!st.empty()) { + auto cur = st.top(); + st.pop(); + HloComputation* computation = cur.first; + HloInstruction* next_instruction = + computation->NextInstruction(cur.second); + if (next_instruction == nullptr) { + post_order.push_back(computation); + } else { + st.emplace(computation, next_instruction); + process_called_computations(next_instruction->called_computations()); + } + } + } return post_order; } @@ -1279,7 +1300,6 @@ void SortClonedInstructionUsersAndControlLists( const HloCloneContext& context, absl::FunctionRef replace, const HloComputation::InstructionList& sorted_instructions) { - using InstructionSorter = MappedPtrContainerSorter; auto instruction_mapper = [&context, replace](const HloInstruction* i) { return context.FindInstruction(replace(i)); }; diff --git a/tensorflow/compiler/xla/hlo/ir/hlo_computation.h b/tensorflow/compiler/xla/hlo/ir/hlo_computation.h index d3aeedc3c7c133..1da5ee5109e37e 100644 --- a/tensorflow/compiler/xla/hlo/ir/hlo_computation.h +++ b/tensorflow/compiler/xla/hlo/ir/hlo_computation.h @@ -782,6 +782,13 @@ class HloComputation { Status RemoveInstructionImpl(HloInstruction* instruction, bool ignore_safety_check); + // Finds the next instruction in the 'instructions_' list after 'current'. + // 'current' must either be nullptr or an instruction that is part of this + // computation. If it is nullptr, next_instruction returns the first + // instruction of the computation. Returns nullptr if there is no next + // instruction. + HloInstruction* NextInstruction(HloInstruction* current); + std::string name_; int64_t unique_id_; HloInstruction* root_instruction_; From d8033773e2f546acd8520aa989b66f00281109be Mon Sep 17 00:00:00 2001 From: weihanmines Date: Mon, 17 Jul 2023 14:38:18 +0000 Subject: [PATCH 373/376] weekly sync 230717 after solving conflicts --- .../xla/python/pjrt_ifrt/xla_sharding_test.cc | 24 ---------- tensorflow/compiler/xla/service/gpu/BUILD | 11 +---- .../xla/service/gpu/cublas_lt_matmul_thunk.cc | 10 +---- .../xla/service/gpu/cublas_lt_matmul_thunk.h | 6 --- .../compiler/xla/service/gpu/matmul_utils.h | 45 ------------------- 5 files changed, 2 insertions(+), 94 deletions(-) diff --git a/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_test.cc b/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_test.cc index 5185bb38eb7070..304b22247a0b3c 100644 --- a/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_test.cc +++ b/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_test.cc @@ -77,12 +77,8 @@ TEST_P(HloShardingTest, DisassembleWithReplication) { TEST_P(HloShardingTest, IndexDomainsWithTile) { auto device_list = GetDevices({0, 1}); // 2-way sharded along axis 0, 1-way sharded along axis 1. -<<<<<<< HEAD - auto xla_hlo_sharding = xla::HloSharding::Tile(xla::TileAssignment((absl::Span){2, 1})); -======= auto xla_hlo_sharding = xla::HloSharding::Tile( xla::TileAssignment((absl::Span){2, 1})); ->>>>>>> upstream/master std::shared_ptr sharding = HloSharding::Create(device_list, xla_hlo_sharding); @@ -100,12 +96,8 @@ TEST_P(HloShardingTest, IndexDomainsWithTile) { TEST_P(HloShardingTest, DisassembleWithTile) { auto device_list = GetDevices({0, 1}); // 2-way sharded along axis 0, 1-way sharded along axis 1. -<<<<<<< HEAD - auto xla_hlo_sharding = xla::HloSharding::Tile(xla::TileAssignment((absl::Span){2, 1})); -======= auto xla_hlo_sharding = xla::HloSharding::Tile( xla::TileAssignment((absl::Span){2, 1})); ->>>>>>> upstream/master std::shared_ptr sharding = HloSharding::Create(device_list, xla_hlo_sharding); @@ -125,12 +117,8 @@ TEST_P(HloShardingTest, DisassembleWithTile) { TEST_P(HloShardingTest, IndexDomainsWithUnevenTile) { auto device_list = GetDevices({0, 1}); // 2-way sharded along axis 0, 1-way sharded along axis 1. -<<<<<<< HEAD - auto xla_hlo_sharding = xla::HloSharding::Tile(xla::TileAssignment((absl::Span){2, 1})); -======= auto xla_hlo_sharding = xla::HloSharding::Tile( xla::TileAssignment((absl::Span){2, 1})); ->>>>>>> upstream/master std::shared_ptr sharding = HloSharding::Create(device_list, xla_hlo_sharding); @@ -148,12 +136,8 @@ TEST_P(HloShardingTest, IndexDomainsWithUnevenTile) { TEST_P(HloShardingTest, DisassembleWithUnevenTile) { auto device_list = GetDevices({0, 1}); // 2-way sharded along axis 0, 1-way sharded along axis 1. -<<<<<<< HEAD - auto xla_hlo_sharding = xla::HloSharding::Tile(xla::TileAssignment((absl::Span){2, 1})); -======= auto xla_hlo_sharding = xla::HloSharding::Tile( xla::TileAssignment((absl::Span){2, 1})); ->>>>>>> upstream/master std::shared_ptr sharding = HloSharding::Create(device_list, xla_hlo_sharding); @@ -315,12 +299,8 @@ TEST_P(HloShardingTest, DisassembleWithSubgroupMaximalSlowPath) { TEST_P(HloShardingTest, DisassembleFailsWithInvalidDeviceCount) { auto device_list = GetDevices({0}); // 2-way sharded along axis 0, 1-way sharded along axis 1. -<<<<<<< HEAD - auto xla_hlo_sharding = xla::HloSharding::Tile(xla::TileAssignment((absl::Span){2, 1})); -======= auto xla_hlo_sharding = xla::HloSharding::Tile( xla::TileAssignment((absl::Span){2, 1})); ->>>>>>> upstream/master std::shared_ptr sharding = HloSharding::Create(device_list, xla_hlo_sharding); @@ -334,12 +314,8 @@ TEST_P(HloShardingTest, DisassembleFailsWithInvalidDeviceCount) { TEST_P(HloShardingTest, DisassembleFailsWithMismatchingShapeDimsSize) { auto device_list = GetDevices({0, 1}); // 2-way sharded along axis 0, 1-way sharded along axis 1. -<<<<<<< HEAD - auto xla_hlo_sharding = xla::HloSharding::Tile(xla::TileAssignment((absl::Span){2, 1})); -======= auto xla_hlo_sharding = xla::HloSharding::Tile( xla::TileAssignment((absl::Span){2, 1})); ->>>>>>> upstream/master std::shared_ptr sharding = HloSharding::Create(device_list, xla_hlo_sharding); diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 0446ce64d9b3c4..37085399ce383b 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -1253,14 +1253,10 @@ cc_library( ]) + if_cuda_is_configured([ "//tensorflow/compiler/xla/stream_executor/cuda:cublas_lt_header", "//tensorflow/compiler/xla/stream_executor/cuda:cublas_plugin", -<<<<<<< HEAD + "//tensorflow/tsl/platform:statusor", ]) + if_rocm_is_configured([ "//tensorflow/compiler/xla/stream_executor/rocm:hipblas_lt_header", - ]), -======= - "//tensorflow/tsl/platform:statusor", ]) + ["//tensorflow/tsl/platform:logging"], ->>>>>>> upstream/master ) cc_library( @@ -1385,13 +1381,8 @@ cc_library( name = "matmul_utils", srcs = ["matmul_utils.cc"], hdrs = ["matmul_utils.h"], -<<<<<<< HEAD compatible_with = get_compatible_with_cloud(), defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured(["TENSORFLOW_USE_ROCM=1"]), -======= - compatible_with = get_compatible_with_portable(), - defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), ->>>>>>> upstream/master deps = [ ":backend_configs_cc", ":ir_emission_utils", diff --git a/tensorflow/compiler/xla/service/gpu/cublas_lt_matmul_thunk.cc b/tensorflow/compiler/xla/service/gpu/cublas_lt_matmul_thunk.cc index f191e485d718e2..d0b794ded531cb 100644 --- a/tensorflow/compiler/xla/service/gpu/cublas_lt_matmul_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/cublas_lt_matmul_thunk.cc @@ -14,12 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/gpu/cublas_lt_matmul_thunk.h" -<<<<<<< HEAD #if GOOGLE_CUDA || TF_HIPBLASLT -======= #include ->>>>>>> upstream/master #include #include "tensorflow/compiler/xla/service/gpu/matmul_utils.h" @@ -66,13 +63,8 @@ Status CublasLtMatmulThunk::ExecuteOnStream(const ExecuteParams& params) { GetMatmulPlan(params.stream)); if (!algorithm_) { TF_ASSIGN_OR_RETURN( -<<<<<<< HEAD - std::vector algorithms, - plan_.GetAlgorithms(params.stream)); -======= std::vector algorithms, plan->GetAlgorithms(params.stream)); ->>>>>>> upstream/master TF_RET_CHECK(algorithm_idx_ >= 0 && algorithm_idx_ < algorithms.size()); algorithm_ = algorithms[algorithm_idx_]; } @@ -132,4 +124,4 @@ StatusOr CublasLtMatmulThunk::GetMatmulPlan( } // namespace gpu } // namespace xla -#endif \ No newline at end of file +#endif diff --git a/tensorflow/compiler/xla/service/gpu/cublas_lt_matmul_thunk.h b/tensorflow/compiler/xla/service/gpu/cublas_lt_matmul_thunk.h index bf5a9176675318..45bb4bcaf5b831 100644 --- a/tensorflow/compiler/xla/service/gpu/cublas_lt_matmul_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/cublas_lt_matmul_thunk.h @@ -16,15 +16,12 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUBLAS_LT_MATMUL_THUNK_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUBLAS_LT_MATMUL_THUNK_H_ -<<<<<<< HEAD #if TENSORFLOW_USE_ROCM #include "rocm/rocm_config.h" #endif #if GOOGLE_CUDA || TF_HIPBLASLT -======= #include ->>>>>>> upstream/master #include #include @@ -34,13 +31,10 @@ limitations under the License. #include "tensorflow/compiler/xla/status.h" #if GOOGLE_CUDA #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_blas_lt.h" -<<<<<<< HEAD #else #include "tensorflow/compiler/xla/stream_executor/rocm/hip_blas_lt.h" #endif -======= #include "tensorflow/tsl/platform/statusor.h" ->>>>>>> upstream/master namespace xla { namespace gpu { diff --git a/tensorflow/compiler/xla/service/gpu/matmul_utils.h b/tensorflow/compiler/xla/service/gpu/matmul_utils.h index 91789478952c42..52b12ecb665f04 100644 --- a/tensorflow/compiler/xla/service/gpu/matmul_utils.h +++ b/tensorflow/compiler/xla/service/gpu/matmul_utils.h @@ -207,51 +207,6 @@ StatusOr AsBlasLtEpilogue( class MatmulPlan { public: -<<<<<<< HEAD - template ::value || - std::is_same::value>> - static StatusOr For(CublasLtMatmulMaybeF8Op op) { - mlir::mhlo::DotDimensionNumbersAttr dot_dims = op.getDotDimensionNumbers(); - - int64_t compute_precision = 0; // Default - if (op.getPrecisionConfig().has_value()) { - auto precision_config = op.getPrecisionConfig(); - for (auto attr : precision_config.value()) { - int64_t value = static_cast( - attr.template cast().getValue()); - if (value > compute_precision) { - compute_precision = value; - } - } - } - - Shape bias_shape; - if (op.getBias() != nullptr) { - bias_shape = GetShape(op.getBias()); - } - TF_ASSIGN_OR_RETURN( - GemmConfig config, - GemmConfig::For( - GetShape(op.getA()), dot_dims.getLhsBatchingDimensions(), - dot_dims.getLhsContractingDimensions(), GetShape(op.getB()), - dot_dims.getRhsBatchingDimensions(), - dot_dims.getRhsContractingDimensions(), GetShape(op.getC()), - op.getBias() == nullptr ? nullptr : &bias_shape, - GetShape(op.getD()), op.getAlphaReal().convertToDouble(), - op.getAlphaImag().convertToDouble(), op.getBeta().convertToDouble(), - op.getAlgorithm(), compute_precision)); - - TF_ASSIGN_OR_RETURN(se::gpu::BlasLt::Epilogue epilogue, - AsBlasLtEpilogue(op.getEpilogue())); - return From(config, epilogue); - } - -======= ->>>>>>> upstream/master static StatusOr From(const GemmConfig& config, se::gpu::BlasLt::Epilogue epilogue); From 44ec8a332ca4844f6c8c4fef36b98bdeda0a5fbc Mon Sep 17 00:00:00 2001 From: weihanmines Date: Mon, 17 Jul 2023 19:37:51 +0000 Subject: [PATCH 374/376] fixed API changes in a few places --- tensorflow/compiler/xla/debug_options_flags.cc | 2 +- tensorflow/compiler/xla/service/gpu/BUILD | 6 +++--- .../compiler/xla/service/gpu/cublas_lt_matmul_thunk.cc | 4 ++-- .../compiler/xla/service/gpu/cublas_lt_matmul_thunk.h | 4 ++-- .../compiler/xla/service/gpu/ir_emitter_unnested.cc | 2 +- tensorflow/core/framework/kernel_shape_util.cc | 8 ++++---- tensorflow/core/kernels/gpu_fusion_ops_convbiasactv.cc | 4 ++-- 7 files changed, 15 insertions(+), 15 deletions(-) diff --git a/tensorflow/compiler/xla/debug_options_flags.cc b/tensorflow/compiler/xla/debug_options_flags.cc index 556ea3239d0bf2..e9d6e29ba31fb1 100644 --- a/tensorflow/compiler/xla/debug_options_flags.cc +++ b/tensorflow/compiler/xla/debug_options_flags.cc @@ -104,7 +104,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { // TODO(b/258036887): Enable cuda_graph_level=2. Currently blocked by CUDA 12 // integration. - opts.set_xla_gpu_cuda_graph_level(1); + opts.set_xla_gpu_cuda_graph_level(0); opts.set_xla_gpu_cuda_graph_num_runs_to_instantiate(-1); opts.set_xla_gpu_enable_persistent_temp_buffers(false); opts.set_xla_gpu_cuda_graph_min_graph_size(5); diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 37085399ce383b..f39efc59ba4397 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -1250,13 +1250,13 @@ cc_library( "//tensorflow/compiler/xla/stream_executor:device_memory", "//tensorflow/compiler/xla/stream_executor:stream_executor_headers", "//tensorflow/tsl/platform:logging", + "//tensorflow/tsl/platform:statusor", ]) + if_cuda_is_configured([ "//tensorflow/compiler/xla/stream_executor/cuda:cublas_lt_header", "//tensorflow/compiler/xla/stream_executor/cuda:cublas_plugin", - "//tensorflow/tsl/platform:statusor", ]) + if_rocm_is_configured([ "//tensorflow/compiler/xla/stream_executor/rocm:hipblas_lt_header", - ]) + ["//tensorflow/tsl/platform:logging"], + ]), ) cc_library( @@ -1381,7 +1381,7 @@ cc_library( name = "matmul_utils", srcs = ["matmul_utils.cc"], hdrs = ["matmul_utils.h"], - compatible_with = get_compatible_with_cloud(), + # compatible_with = get_compatible_with_cloud(), defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured(["TENSORFLOW_USE_ROCM=1"]), deps = [ ":backend_configs_cc", diff --git a/tensorflow/compiler/xla/service/gpu/cublas_lt_matmul_thunk.cc b/tensorflow/compiler/xla/service/gpu/cublas_lt_matmul_thunk.cc index d0b794ded531cb..78ea68419d5367 100644 --- a/tensorflow/compiler/xla/service/gpu/cublas_lt_matmul_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/cublas_lt_matmul_thunk.cc @@ -35,7 +35,7 @@ namespace gpu { CublasLtMatmulThunk::CublasLtMatmulThunk( ThunkInfo thunk_info, GemmConfig gemm_config, - se::cuda::BlasLt::Epilogue epilogue, int64_t algorithm_idx, + se::gpu::BlasLt::Epilogue epilogue, int64_t algorithm_idx, BufferAllocation::Slice a_buffer, BufferAllocation::Slice b_buffer, BufferAllocation::Slice c_buffer, BufferAllocation::Slice d_buffer, BufferAllocation::Slice bias_buffer, BufferAllocation::Slice aux_buffer, @@ -63,7 +63,7 @@ Status CublasLtMatmulThunk::ExecuteOnStream(const ExecuteParams& params) { GetMatmulPlan(params.stream)); if (!algorithm_) { TF_ASSIGN_OR_RETURN( - std::vector algorithms, + std::vector algorithms, plan->GetAlgorithms(params.stream)); TF_RET_CHECK(algorithm_idx_ >= 0 && algorithm_idx_ < algorithms.size()); algorithm_ = algorithms[algorithm_idx_]; diff --git a/tensorflow/compiler/xla/service/gpu/cublas_lt_matmul_thunk.h b/tensorflow/compiler/xla/service/gpu/cublas_lt_matmul_thunk.h index 45bb4bcaf5b831..662e677426a54d 100644 --- a/tensorflow/compiler/xla/service/gpu/cublas_lt_matmul_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/cublas_lt_matmul_thunk.h @@ -42,7 +42,7 @@ namespace gpu { class CublasLtMatmulThunk : public Thunk { public: CublasLtMatmulThunk(ThunkInfo thunk_info, GemmConfig gemm_config, - se::cuda::BlasLt::Epilogue epilogue, + se::gpu::BlasLt::Epilogue epilogue, int64_t algorithm_idx, BufferAllocation::Slice a_buffer, BufferAllocation::Slice b_buffer, BufferAllocation::Slice c_buffer, @@ -67,7 +67,7 @@ class CublasLtMatmulThunk : public Thunk { matmul_plans_cache_ ABSL_GUARDED_BY(matmul_plans_cache_mutex_); GemmConfig gemm_config_; - se::cuda::BlasLt::Epilogue epilogue_; + se::gpu::BlasLt::Epilogue epilogue_; int64_t algorithm_idx_; BufferAllocation::Slice a_buffer_; BufferAllocation::Slice b_buffer_; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 2bd5c805f20a73..c66438cb26eff3 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -1141,7 +1141,7 @@ Status IrEmitterUnnested::EmitCublasLtMatmulThunk(mlir::Operation* op) { } TF_ASSIGN_OR_RETURN(GemmConfig gemm_config, GemmConfig::For(matmul)); - TF_ASSIGN_OR_RETURN(se::cuda::BlasLt::Epilogue epilogue, + TF_ASSIGN_OR_RETURN(se::gpu::BlasLt::Epilogue epilogue, cublas_lt::AsBlasLtEpilogue(matmul.getEpilogue())); auto thunk = std::make_unique( GetThunkInfo(op), std::move(gemm_config), epilogue, matmul.getAlgorithm(), diff --git a/tensorflow/core/framework/kernel_shape_util.cc b/tensorflow/core/framework/kernel_shape_util.cc index 071821ce4a56d6..c03540a9dad034 100644 --- a/tensorflow/core/framework/kernel_shape_util.cc +++ b/tensorflow/core/framework/kernel_shape_util.cc @@ -21,10 +21,10 @@ limitations under the License. namespace tensorflow { Status GetWindowedOutputSizeVerbose(int64_t input_size, int64_t filter_size, - int64_t dilation_rate, int64_t stride, - Padding padding_type, int64_t* output_size, - int64_t* padding_before, - int64_t* padding_after) { + int64_t dilation_rate, int64_t stride, + Padding padding_type, int64_t* output_size, + int64_t* padding_before, + int64_t* padding_after) { if (stride <= 0) { return errors::InvalidArgument("Stride must be > 0, but got ", stride); } diff --git a/tensorflow/core/kernels/gpu_fusion_ops_convbiasactv.cc b/tensorflow/core/kernels/gpu_fusion_ops_convbiasactv.cc index 787e10009c96da..ea17709afeba3d 100644 --- a/tensorflow/core/kernels/gpu_fusion_ops_convbiasactv.cc +++ b/tensorflow/core/kernels/gpu_fusion_ops_convbiasactv.cc @@ -93,14 +93,14 @@ class ROCmFusionKernelConvolutionBiasActivation : public OpKernel { int64 output_rows = 0, padding_left = 0, padding_right = 0; OP_REQUIRES_OK( - ctx, GetWindowedOutputSizeVerboseV2( + ctx, GetWindowedOutputSizeVerbose( input_rows, filter_rows, dilation_rows, stride_rows, padding_type_, &output_rows, &padding_left, &padding_right)); int64 padding_rows = padding_left + padding_right; int64 output_cols = 0, padding_top = 0, padding_bottom = 0; OP_REQUIRES_OK( - ctx, GetWindowedOutputSizeVerboseV2( + ctx, GetWindowedOutputSizeVerbose( input_cols, filter_cols, dilation_cols, stride_cols, padding_type_, &output_cols, &padding_top, &padding_bottom)); int64 padding_cols = padding_top + padding_bottom; From 4d7ba9ddeb22daefabd4423f2eb201becd602a9d Mon Sep 17 00:00:00 2001 From: weihanmines Date: Tue, 18 Jul 2023 00:23:10 +0000 Subject: [PATCH 375/376] attemp to fix xla_sharding_serdes_test failure --- .../compiler/xla/python/pjrt_ifrt/xla_sharding_serdes_test.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_serdes_test.cc b/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_serdes_test.cc index a98a0271a03a41..1a5bc8ef7b5804 100644 --- a/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_serdes_test.cc +++ b/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_serdes_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include #include "absl/functional/bind_front.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_sharding.h" #include "tensorflow/compiler/xla/python/ifrt/sharding_serdes.h" #include "tensorflow/compiler/xla/python/ifrt/sharding_test_util.h" @@ -34,7 +35,7 @@ class XlaShardingSerDesTest : public test_util::ShardingTest {}; TEST_P(XlaShardingSerDesTest, HloShardingRoundTrip) { auto device_list = GetDevices({0, 1}); - auto xla_hlo_sharding = xla::HloSharding::Tile(xla::TileAssignment({2, 1})); + auto xla_hlo_sharding = xla::HloSharding::Tile(xla::TileAssignment(absl::Span({2, 1}))); auto sharding = HloSharding::Create(device_list, /*xla_hlo_sharding=*/xla_hlo_sharding); From 36719ed81ee27959b2e8664c17bf4609435ca652 Mon Sep 17 00:00:00 2001 From: weihanmines Date: Tue, 18 Jul 2023 14:26:04 +0000 Subject: [PATCH 376/376] disable hlo-llvm ir tests and unaray op gpu test --- tensorflow/compiler/tests/BUILD | 1 + tensorflow/compiler/xla/service/gpu/tests/BUILD | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 213cb96046ec8b..9af16abb6f5618 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -1707,6 +1707,7 @@ tf_xla_py_strict_test( python_version = "PY3", shard_count = 50, tags = [ + "no_rocm", "no_cuda_asan", # times out "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip "noasan", #times out diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD index 2158e556a850d5..4249b66b6c8d66 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD @@ -783,9 +783,9 @@ glob_lit_tests( "calling_convention_amdgcn.hlo": ["no_cuda_asan"], "copy_amdgcn.hlo": ["no_cuda_asan"], "copy_nested_amdgcn.hlo": ["no_cuda_asan", "no_rocm"], - "dynamic_update_slice_inplace_amdgcn.hlo": ["no_cuda_asan"], + "dynamic_update_slice_inplace_amdgcn.hlo": ["no_cuda_asan", "no_rocm"], "fused_scatter_amdgcn.hlo": ["no_cuda_asan", "no_rocm"], - "fused_slice_amdgcn.hlo": ["no_cuda_asan"], + "fused_slice_amdgcn.hlo": ["no_cuda_asan", "no_rocm"], "fused_slice_different_operands_amdgcn.hlo": ["no_cuda_asan"], "fusion_amdgcn.hlo": ["no_cuda_asan"], "launch_dimensions_amdgcn.hlo": ["no_cuda_asan", "no_rocm"],